diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000000..294863ce8ac --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch +*.ipch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +# vim tags +tags +.tags +.*.swp + +# Editors +.vscode + +# build-in-source directory +build* + +# emacs temporary/backup files +.\#* +\#*\# +*~ + +# GDB temporary files +.gdb_history \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 306e6ca6491..e5903f3747f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,25 @@ -cmake_minimum_required(VERSION 3.5) +cmake_minimum_required(VERSION 3.14) + +# Check support for CUDA/HIP in Cmake project(composable_kernel) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") +enable_testing() + +find_package(ROCM REQUIRED PATHS /opt/rocm) + +include(ROCMInstallTargets) +include(ROCMPackageConfigHelpers) +include(ROCMSetupVersion) +include(ROCMInstallSymlinks) +include(ROCMCreatePackage) include(CheckCXXCompilerFlag) +rocm_setup_version(VERSION 1.0.0) +include(TargetFlags) +list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) + ## C++ enable_language(CXX) set(CMAKE_CXX_STANDARD 17) @@ -30,35 +45,54 @@ message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}") message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}") message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}") -set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") link_libraries(${OpenMP_gomp_LIBRARY}) link_libraries(${OpenMP_pthread_LIBRARY}) ## HIP find_package(HIP REQUIRED) -message(STATUS "Build with HIP ${hip_VERSION}") +# Override HIP version in config.h, if necessary. +# The variables set by find_package() can't be overwritten, +# therefore let's use intermediate variables. +set(CK_HIP_VERSION_MAJOR "${HIP_VERSION_MAJOR}") +set(CK_HIP_VERSION_MINOR "${HIP_VERSION_MINOR}") +set(CK_HIP_VERSION_PATCH "${HIP_VERSION_PATCH}") +if( DEFINED CK_OVERRIDE_HIP_VERSION_MAJOR ) + set(CK_HIP_VERSION_MAJOR "${CK_OVERRIDE_HIP_VERSION_MAJOR}") + message(STATUS "CK_HIP_VERSION_MAJOR overriden with ${CK_OVERRIDE_HIP_VERSION_MAJOR}") +endif() +if( DEFINED CK_OVERRIDE_HIP_VERSION_MINOR ) + set(CK_HIP_VERSION_MINOR "${CK_OVERRIDE_HIP_VERSION_MINOR}") + message(STATUS "CK_HIP_VERSION_MINOR overriden with ${CK_OVERRIDE_HIP_VERSION_MINOR}") +endif() +if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH ) + set(CK_HIP_VERSION_PATCH "${CK_OVERRIDE_HIP_VERSION_PATCH}") + message(STATUS "CK_HIP_VERSION_PATCH overriden with ${CK_OVERRIDE_HIP_VERSION_PATCH}") +endif() +message(STATUS "Build with HIP ${HIP_VERSION}") + + +rocm_create_package( + NAME composablekernel + DESCRIPTION "High Performance Composable Kernel for AMD GPUs" + MAINTAINER "MIOpen Kernels Dev Team " + LDCONFIG +) ## half -#find_path(HALF_INCLUDE_DIR half.hpp) +set(HALF_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/external/include/half") message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}") -# CMAKE_CXX_FLAGS -SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") -if(BUILD_DEV) - string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything") -endif() -message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") - ## tidy include(EnableCompilerWarnings) -set(MIOPEN_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name) +set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name) if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+") - set(MIOPEN_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) + set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) # Enable tidy on hip -elseif(MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU") - set(MIOPEN_TIDY_ERRORS ALL) +elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU") + set(CK_TIDY_ERRORS ALL) endif() + include(ClangTidy) enable_clang_tidy( CHECKS @@ -150,13 +184,12 @@ enable_clang_tidy( -cppcoreguidelines-narrowing-conversions -altera-struct-pack-align -cppcoreguidelines-prefer-member-initializer - - ${MIOPEN_TIDY_CHECKS} - ${MIOPEN_TIDY_ERRORS} + ${CK_TIDY_CHECKS} + ${CK_TIDY_ERRORS} HEADER_FILTER "\.hpp$" EXTRA_ARGS - -DMIOPEN_USE_CLANG_TIDY + -DCK_USE_CLANG_TIDY ) include(CppCheck) @@ -180,19 +213,59 @@ enable_cppcheck( unmatchedSuppression FORCE SOURCES - host/host_tensor/src - host/driver_offline/src - composable_kernel/src/kernel_wrapper + library/src INCLUDE - host/host_tensor/include - host/solver/include - host/driver_offline/include - composable_kernel/include/* ${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_CURRENT_BINARY_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/library/include DEFINE CPPCHECK=1 __linux__=1 ) -add_subdirectory(host) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) + +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_BINARY_DIR}/include + ${PROJECT_SOURCE_DIR}/library/include +) + + +SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") +if(BUILD_DEV) + add_compile_options(-Werror) + add_compile_options(-Weverything) +endif() +message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + +add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) + +add_subdirectory(library) +add_subdirectory(example) +add_subdirectory(test) +add_subdirectory(profiler) + +#Create an interface target for the include only files and call it "composablekernels" +include(CMakePackageConfigHelpers) + +set(version 1.0.0) +write_basic_package_version_file( + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake" + VERSION "${version}" + COMPATIBILITY AnyNewerVersion +) + +configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel + NO_CHECK_REQUIRED_COMPONENTS_MACRO +) + +install(FILES + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake" + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel +) diff --git a/Config.cmake.in b/Config.cmake.in new file mode 100644 index 00000000000..12b5c331aeb --- /dev/null +++ b/Config.cmake.in @@ -0,0 +1,11 @@ +@PACKAGE_INIT@ + +set(_composable_kernel_supported_components device_operations host_tensor) + +foreach(_comp ${composable_kernel_FIND_COMPONENTS}) + if(NOT _comp IN_LIST _composable_kernel_supported_components) + set(composable_kernel_FOUND False) + set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}") + endif() + include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake") +endforeach() diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000000..79c961144a3 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,90 @@ +FROM ubuntu:18.04 + +ARG ROCMVERSION=5.1 +ARG OSDB_BKC_VERSION + +RUN set -xe + +ARG BUILD_THREADS=8 +ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ +# Add rocm repository +RUN apt-get update +RUN apt-get install -y wget gnupg +RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - +RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list" +RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add - +RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/apt/sources.list" + +# ADD requirements.txt requirements.txt +# Install dependencies +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ + apt-utils \ + build-essential \ + cmake-data=3.15.1-0kitware1 \ + cmake=3.15.1-0kitware1 \ + curl \ + g++ \ + gdb \ + git \ + hip-rocclr \ + jq \ + libelf-dev \ + libncurses5-dev \ + libnuma-dev \ + libpthread-stubs0-dev \ + llvm-amdgpu \ + pkg-config \ + python \ + python3.8 \ + python-dev \ + python3-dev \ + python-pip \ + python3-pip \ + software-properties-common \ + wget \ + rocm-dev \ + rocm-device-libs \ + rocm-cmake \ + vim \ + zlib1g-dev \ + openssh-server \ + clang-format-10 \ + kmod && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Setup ubsan environment to printstacktrace +RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer +ENV UBSAN_OPTIONS=print_stacktrace=1 + +# Install an init system +RUN wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb +RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb + +# Install cget +RUN pip install cget + +# Install rclone +RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz + +ARG PREFIX=/opt/rocm +# Install dependencies +RUN cget install pfultz2/rocm-recipes +# Install rbuild +RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/6d78a0553babdaea8d2da5de15cbda7e869594b8.tar.gz +# Install packages for processing the performance results +RUN pip3 install --upgrade pip +RUN pip3 install sqlalchemy +RUN pip3 install pymysql +RUN pip3 install pandas +RUN pip3 install setuptools-rust +RUN pip3 install sshtunnel +# Setup ubsan environment to printstacktrace +ENV UBSAN_OPTIONS=print_stacktrace=1 + +ENV LC_ALL=C.UTF-8 +ENV LANG=C.UTF-8 +ADD rbuild.ini /rbuild.ini +ADD dev-requirements.txt dev-requirements.txt +RUN rbuild prepare -s develop -d $PREFIX +RUN groupadd -f render diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 00000000000..b912062e647 --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,412 @@ +def rocmnode(name) { + return 'rocmtest && miopen && ' + name +} + +def show_node_info() { + sh """ + echo "NODE_NAME = \$NODE_NAME" + lsb_release -sd + uname -r + cat /sys/module/amdgpu/version + ls /opt/ -la + """ +} + +def cmake_build(Map conf=[:]){ + + def compiler = conf.get("compiler","/opt/rocm/bin/hipcc") + def config_targets = conf.get("config_targets","check") + def debug_flags = "-g -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined " + conf.get("extradebugflags", "") + def build_envs = "CTEST_PARALLEL_LEVEL=4 " + conf.get("build_env","") + def prefixpath = conf.get("prefixpath","/opt/rocm") + def setup_args = conf.get("setup_args","") + + if (prefixpath != "/usr/local"){ + setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} " + } + + def build_type_debug = (conf.get("build_type",'release') == 'debug') + + //cmake_env can overwrite default CXX variables. + def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","") + + def package_build = (conf.get("package_build","") == "true") + + if (package_build == true) { + config_targets = "package" + } + + if(conf.get("build_install","") == "true") + { + config_targets = 'install ' + config_targets + setup_args = ' -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install' + setup_args + } else{ + setup_args = ' -DBUILD_DEV=On' + setup_args + } + + if(build_type_debug){ + setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args + }else{ + setup_args = " -DCMAKE_BUILD_TYPE=release" + setup_args + } + + def pre_setup_cmd = """ + echo \$HSA_ENABLE_SDMA + ulimit -c unlimited + rm -rf build + mkdir build + rm -rf install + mkdir install + cd build + """ + def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") + // reduce parallelism when compiling, clang uses too much memory + def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j\$(( \$(nproc) / 1 )) ${config_targets}") + def execute_cmd = conf.get("execute_cmd", "") + + def cmd = conf.get("cmd", """ + ${pre_setup_cmd} + ${setup_cmd} + ${build_cmd} + ${execute_cmd} + """) + + echo cmd + sh cmd + + // Only archive from master or develop + if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "master")) { + archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true + } +} + +def buildHipClangJob(Map conf=[:]){ + show_node_info() + + env.HSA_ENABLE_SDMA=0 + checkout scm + + def image = "composable_kernels" + def prefixpath = conf.get("prefixpath", "/opt/rocm") + def gpu_arch = conf.get("gpu_arch", "gfx908") + + // Jenkins is complaining about the render group + // def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + if (conf.get("enforce_xnack_on", false)) { + dockerOpts = dockerOpts + " --env HSA_XNACK=1" + } + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' " + + def variant = env.STAGE_NAME + + + def retimage + gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + try { + retimage = docker.build("${image}", dockerArgs + '.') + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + } + } + } + catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ + echo "The job was cancelled or aborted" + throw e + } + catch(Exception ex) { + retimage = docker.build("${image}", dockerArgs + "--no-cache .") + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + } + } + } + + withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { + timeout(time: 5, unit: 'HOURS') + { + cmake_build(conf) + } + } + } + return retimage +} + +def reboot(){ + build job: 'reboot-slaves', propagate: false , parameters: [string(name: 'server', value: "${env.NODE_NAME}"),] +} + + + + + +def buildHipClangJobAndReboot(Map conf=[:]){ + try{ + buildHipClangJob(conf) + } + catch(e){ + echo "throwing error exception for the stage" + echo 'Exception occurred: ' + e.toString() + throw e + } + finally{ + if (!conf.get("no_reboot", false)) { + reboot() + } + } +} + + +def runCKProfiler(Map conf=[:]){ + show_node_info() + + env.HSA_ENABLE_SDMA=0 + checkout scm + + def image = "composable_kernels" + def prefixpath = conf.get("prefixpath", "/opt/rocm") + def gpu_arch = conf.get("gpu_arch", "gfx908") + + // Jenkins is complaining about the render group + // def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + if (conf.get("enforce_xnack_on", false)) { + dockerOpts = dockerOpts + " --env HSA_XNACK=1" + } + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' " + + def variant = env.STAGE_NAME + + + def retimage + gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + try { + retimage = docker.build("${image}", dockerArgs + '.') + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + } + } + } + catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ + echo "The job was cancelled or aborted" + throw e + } + catch(Exception ex) { + retimage = docker.build("${image}", dockerArgs + "--no-cache .") + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + } + } + } + + withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { + timeout(time: 5, unit: 'HOURS') + { + cmake_build(conf) + dir("script"){ + def perf_log = "perf_gemm_${gpu_arch}.log" + sh "rm -f ${perf_log}" + sh "echo Branch name: ${env.BRANCH_NAME} > ${perf_log}" + sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${perf_log}" + sh "./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${perf_log}" + //results will be parsed, stored, and analyzed within the python script + //the script will return 0 if the performance criteria are met + //or return 1 if the criteria are not met + archiveArtifacts "${perf_log}" + sh "python3 parse_perf_data.py ${perf_log} " + } + } + } + } + return retimage +} + + +def runPerfTest(Map conf=[:]){ + try{ + runCKProfiler(conf) + } + catch(e){ + echo "throwing error exception in performance tests" + echo 'Exception occurred: ' + e.toString() + throw e + } + finally{ + if (!conf.get("no_reboot", false)) { + reboot() + } + } +} + +pipeline { + agent none + options { + parallelsAlwaysFailFast() + } + // environment{ + // variable = value + // } + stages{ + stage("Static checks") { + parallel{ + // enable after we move from hipcc to hip-clang + // stage('Tidy') { + // agent{ label rocmnode("nogpu") } + // environment{ + // // setup_cmd = "CXX='/opt/rocm/bin/hipcc' cmake -DBUILD_DEV=On .. " + // build_cmd = "make -j\$(nproc) -k analyze" + // } + // steps{ + // buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug') + // } + // } + // we will build and run ckProfiler release version later, during the performance test stage + //stage('Build Profiler: Release, gfx908') + //{ + // agent { label rocmnode("nogpu")} + // environment{ + // setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ + // } + // steps{ + // buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') + // } + //} + //stage('Build Profiler: Debug, gfx908') + //{ + // agent { label rocmnode("nogpu")} + // environment{ + // setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ + // } + // steps{ + // // until we stabilize debug build due to compiler crashes + // catchError(buildResult: 'SUCCESS', stageResult: 'FAILURE') { + // buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Debug') + // } + // } + //} + stage('Clang Format') { + agent{ label rocmnode("nogpu") } + environment{ + execute_cmd = "find .. -iname \'*.h\' \ + -o -iname \'*.hpp\' \ + -o -iname \'*.cpp\' \ + -o -iname \'*.h.in\' \ + -o -iname \'*.hpp.in\' \ + -o -iname \'*.cpp.in\' \ + -o -iname \'*.cl\' \ + | grep -v 'build/' \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-10 -style=file {} | diff - {}\'" + } + steps{ + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) + } + } + } + } + stage("Tests") + { + parallel + { + stage("Run Tests: gfx908") + { + agent{ label rocmnode("gfx908")} + environment{ + setup_args = """ -D CMAKE_CXX_FLAGS=" --offload-arch=gfx900 --offload-arch=gfx906 --offload-arch=gfx908 --offload-arch=gfx90a -O3 " -DBUILD_DEV=On """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release') + } + + } + stage("Run Tests: gfx90a") + { + agent{ label rocmnode("gfx90a")} + environment{ + setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release') + } + + } + + } + } + stage("Client App") + { + parallel + { + stage("Run Client App") + { + agent{ label rocmnode("gfx908")} + environment{ + setup_args = """ -D -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """ + execute_args = """ cd ../test/client_app && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" .. && make """ + } + steps{ + buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + } + } + } + } + stage("Performance Tests") + { + parallel + { + stage("Run ckProfiler: gfx908") + { + agent{ label rocmnode("gfx908")} + environment{ + setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ + dbuser = "${dbuser}" + dbpassword = "${dbpassword}" + dbsship = "${dbsship}" + dbsshport = "${dbsshport}" + dbsshuser = "${dbsshuser}" + dbsshpassword = "${dbsshpassword}" + } + steps{ + runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') + } + } + } + } + + // enable after the cmake file supports packaging + // stage("Packages") { + // when { + // expression { params.BUILD_PACKAGES && params.TARGET_NOGPU && params.DATATYPE_NA } + // } + // parallel { + // stage("Package /opt/rocm") { + // agent{ label rocmnode("nogpu") } + // steps{ + // buildHipClangJobAndReboot( package_build: "true", prefixpath: '/opt/rocm', gpu_arch: "gfx906;gfx908;gfx90a") + // } + // } + // } + // } + } +} diff --git a/README.md b/README.md index 4f071d5896c..9d7b578046a 100644 --- a/README.md +++ b/README.md @@ -1,177 +1,55 @@ -# How to build and run - -# Docker -``` -docker run \ --it \ ---rm \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ -rocm/tensorflow:rocm4.2-tf2.4-dev \ +## Docker script +```bash +docker run \ +-it \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ /bin/bash ``` -# Install Boost for online compilation -https://www.boost.org/doc/libs/1_66_0/more/getting_started/unix-variants.html#easy-build-and-install - - -# Build -Add path of Boost -``` - export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH -``` - -``` +## Build +```bash mkdir build && cd build ``` -cmake cmd. Need to Specify target ID, example below is gfx908 -``` -cmake \ --D CMAKE_BUILD_TYPE=Release \ --D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ --D HIP_ONLINE_COMPILER_FLAGS="-DCK_AMD_GPU_GFX908" \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_PREFIX_PATH=/opt/rocm \ --D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ +```bash +# Need to specify target ID, example below is gfx908 and gfx90a +cmake \ +-D BUILD_DEV=OFF \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3" \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ .. ``` -Build drivers: \ -``conv_fwd_driver_offline`` is (offline compilation) driver for forward convolution, \ -``conv_bwd_driver_offline`` is (offline compilation) driver for backward-data convolution \ -``conv_fwd_driver_online`` is (online compilation) driver for forward convolution -``` - make -j conv_fwd_driver_offline - make -j conv_bwd_driver_offline - make -j conv_fwd_driver_online +### Build and Run Examples +```bash + make -j examples ``` +Instructions for running each individual examples are under ```example/``` -# Run -* layout: 0 = NCHW; 1 = NHWC -* algo: algorithm -* verify: 0 = no verification; 1 = do verification -* init: 0 ~ 5. initialization method -* log: 0 = no log; 1 = do log -* repeat: number of time kernel being launched -``` -######################################################## layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads - ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 - ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 - ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 - ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 - ./host/driver_offline/conv_bwd_driver_offline 1 5 0 0 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1 +## Tests +```bash + make -j tests + make test ``` -# Result -Forward convoltuion, FP16, NCHW +## Build ckProfiler +```bash + make -j ckProfiler ``` -./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +Instructions for running ckProfiler are under ```profiler/``` -layout: 0 -in: dim 4, lengths {128, 192, 71, 71}, strides {967872, 5041, 71, 1} -wei: dim 4, lengths {256, 192, 3, 3}, strides {1728, 9, 3, 1} -out: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1296, 36, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {2, 2, } -ConvDilations size 2, {1, 1, } -device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw -a_k0_m_k1_grid_desc{216, 256, 8} -b_k0_n_k1_grid_desc{216, 165888, 8} -c_m_n_grid_desc{ 256, 165888} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Average time : 1.4155 ms, 103.686 TFlop/s -``` -Forward convoltuion, FP16, NCHW -``` - ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 - - layout: 0 -in: dim 4, lengths {256, 256, 14, 14}, strides {50176, 196, 14, 1} -wei: dim 4, lengths {1024, 256, 3, 3}, strides {2304, 9, 3, 1} -out: dim 4, lengths {256, 1024, 14, 14}, strides {200704, 196, 14, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {1, 1, } -ConvDilations size 2, {1, 1, } -device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw -a_k0_m_k1_grid_desc{288, 1024, 8} -b_k0_n_k1_grid_desc{288, 50176, 8} -c_m_n_grid_desc{ 1024, 50176} -launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Average time : 2.21357 ms, 106.959 TFlop/s - ``` - - Forward convolution, FP16, NHWC - ``` - ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 - - layout: 1 -in: dim 4, lengths {128, 71, 71, 192}, strides {967872, 13632, 192, 1} -wei: dim 4, lengths {256, 3, 3, 192}, strides {1728, 576, 192, 1} -out: dim 4, lengths {128, 36, 36, 256}, strides {331776, 9216, 256, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {2, 2, } -ConvDilations size 2, {1, 1, } -device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk -a_k0_m_k1_grid_desc{216, 165888, 8} -b_k0_n_k1_grid_desc{216, 256, 8} -c_m_n_grid_desc{ 165888, 256} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Average time : 1.12014 ms, 131.025 TFlop/s - ``` - - Forward convolution, FP16, NHWC - ``` - ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 - - layout: 1 -in: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1} -wei: dim 4, lengths {1024, 3, 3, 256}, strides {2304, 768, 256, 1} -out: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {1, 1, } -ConvDilations size 2, {1, 1, } -device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk -a_k0_m_k1_grid_desc{288, 50176, 8} -b_k0_n_k1_grid_desc{288, 1024, 8} -c_m_n_grid_desc{ 50176, 1024} -launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Average time : 1.86877 ms, 126.693 TFlop/s - ``` - - Backward data convolution, FP16, NHWC - ``` - ./host/driver_offline/conv_bwd_driver_offline 1 1 0 3 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1 - - layout: 1 -in: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1} -wei: dim 4, lengths {256, 3, 3, 1024}, strides {9216, 3072, 1024, 1} -out: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {1, 1, } -ConvDilations size 2, {1, 1, } -device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk -a_k0_m_k1_grid_desc{288, 50176, 8} -b_k0_n_k1_grid_desc{288, 1024, 8} -c_m_n_grid_desc{ 50176, 1024} -launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Average time : 2.22461 ms, 106.428 TFlop/s -``` +## Caveat +### Kernel Timing and Verification +CK's own kernel timer will warn up kernel once, and then run it multiple times +to get average kernel time. For some kernels that use atomic add, this will cause +output buffer to be accumulated multiple times, causing verfication failure. +To work around it, do not use CK's own timer and do verification at the same time. +CK's own timer and verification in each example and ckProfiler can be enabled or +disabled from command line. diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 9f193b20904..78133af0315 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,7 @@ else() -Wunreachable-code -Wunused - -Wno-sign-compare + -Wsign-compare -Wno-extra-semi-stmt ) if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang") diff --git a/cmake/TargetFlags.cmake b/cmake/TargetFlags.cmake new file mode 100644 index 00000000000..4f83fb5d396 --- /dev/null +++ b/cmake/TargetFlags.cmake @@ -0,0 +1,50 @@ + +function(get_target_property2 VAR TARGET PROPERTY) + get_target_property(_pflags ${TARGET} ${PROPERTY}) + if(_pflags) + set(${VAR} ${_pflags} PARENT_SCOPE) + else() + set(${VAR} "" PARENT_SCOPE) + endif() +endfunction() + + +macro(append_flags FLAGS TARGET PROPERTY PREFIX) + get_target_property2(_pflags ${TARGET} ${PROPERTY}) + foreach(FLAG ${_pflags}) + if(TARGET ${FLAG}) + target_flags(_pflags2 ${FLAG}) + string(APPEND ${FLAGS} " ${_pflags2}") + else() + string(APPEND ${FLAGS} " ${PREFIX}${FLAG}") + endif() + endforeach() +endmacro() + +macro(append_link_flags FLAGS TARGET PROPERTY) + get_target_property2(_pflags ${TARGET} ${PROPERTY}) + foreach(FLAG ${_pflags}) + if(TARGET ${FLAG}) + target_flags(_pflags2 ${FLAG}) + string(APPEND ${FLAGS} " ${_pflags2}") + elseif(FLAG MATCHES "^-.*") + string(APPEND ${FLAGS} " ${FLAG}") + elseif(EXISTS ${FLAG}) + string(APPEND ${FLAGS} " ${FLAG}") + else() + string(APPEND ${FLAGS} " -l${FLAG}") + endif() + endforeach() +endmacro() + +function(target_flags FLAGS TARGET) + set(_flags) + append_flags(_flags ${TARGET} "INTERFACE_COMPILE_OPTIONS" "") + append_flags(_flags ${TARGET} "INTERFACE_COMPILE_DEFINITIONS" "-D") + append_flags(_flags ${TARGET} "INTERFACE_INCLUDE_DIRECTORIES" "-isystem ") + append_flags(_flags ${TARGET} "INTERFACE_LINK_DIRECTORIES" "-L ") + append_flags(_flags ${TARGET} "INTERFACE_LINK_OPTIONS" "") + append_link_flags(_flags ${TARGET} "INTERFACE_LINK_LIBRARIES" "") + # message("_flags: ${_flags}") + set(${FLAGS} ${_flags} PARENT_SCOPE) +endfunction() diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake new file mode 100644 index 00000000000..959bc4f4b0e --- /dev/null +++ b/cmake/googletest.cmake @@ -0,0 +1,39 @@ +include(FetchContent) + +set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") + +if(GOOGLETEST_DIR) + set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") +endif() + +message(STATUS "Fetching GoogleTest") + +list(APPEND GTEST_CMAKE_CXX_FLAGS + -Wno-undef + -Wno-reserved-identifier + -Wno-global-constructors + -Wno-missing-noreturn + -Wno-disabled-macro-expansion + -Wno-used-but-marked-unused + -Wno-switch-enum + -Wno-zero-as-null-pointer-constant + -Wno-unused-member-function + -Wno-comma + -Wno-old-style-cast +) +message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}") + +FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG b85864c64758dec007208e56af933fc3f52044ee +) + +# Will be necessary for windows build +# set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) + +target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) +target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) +target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) +target_compile_options(gmock_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) diff --git a/composable_kernel/include/gridwise_operation_wrapper.hpp b/composable_kernel/include/gridwise_operation_wrapper.hpp deleted file mode 100644 index 0a1e07ec571..00000000000 --- a/composable_kernel/include/gridwise_operation_wrapper.hpp +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef CK_GRIDWISE_OPERATION_KERNEL_WRAPPER -#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - run_gridwise_operation(Xs... xs) -{ - GridwiseOp{}.Run(xs...); -} - -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp deleted file mode 100644 index 5cc2f2393ee..00000000000 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp +++ /dev/null @@ -1,183 +0,0 @@ -#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP -#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP - -#include "common_header.hpp" -#include "threadwise_gemm_dlops_v3.hpp" - -namespace ck { - -template -struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 -{ - struct MatrixIndex - { - index_t k; - index_t h; - index_t w; - }; - - // HACK: fix this @Jing Zhang - static constexpr index_t KPerThreadSubC = 4; - - static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - - static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); - - static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1>, - 1, - ThreadGemmADataPerRead_K, - 1>; - - __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() - : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, - a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)} - { - static_assert(BlockMatrixA::IsKnownAtCompileTime() && - BlockMatrixB::IsKnownAtCompileTime() && - ThreadMatrixC::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0), - "wrong! K dimension not consistent\n"); - - constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed - constexpr index_t H = BlockMatrixB{}.GetLength(I2); - constexpr index_t W = BlockMatrixB{}.GetLength(I3); - - static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0, - "wrong! Cannot evenly divide work among\n"); - - constexpr auto KThreadCluster = K / KPerThread; - constexpr auto HThreadCluster = H / HPerThread; - constexpr auto WThreadCluster = W / WPerThread; - - static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster, - "wrong! wrong blocksize\n"); - } - - __device__ static constexpr auto GetThreadMatrixCLengths() - { - return Sequence{}; - } - - __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) - { - constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{}); - constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{}); - - constexpr auto num_w_threads = W / WPerThread; - constexpr auto num_h_threads = H / HPerThread; - constexpr auto num_hw_threads = num_w_threads * num_h_threads; - - index_t k_thread_id = thread_id / num_hw_threads; - index_t hw_thread_id = thread_id % num_hw_threads; - - index_t h_thread_id = hw_thread_id / num_w_threads; - index_t w_thread_id = hw_thread_id % num_w_threads; - - return MatrixIndex{k_thread_id, h_thread_id, w_thread_id}; - } - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BThreadBuffer& b_thread_buf, - CThreadBuffer& c_thread_buf) const - { - static_assert( - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - "wrong! inconsistent type"); - - constexpr auto I0 = Number<0>{}; - - constexpr auto a_block_mtx = BlockMatrixA{}; - - constexpr auto EPerBlock = a_block_mtx.GetLength(I0); - - // HACK: fix this @Jing Zhang - constexpr auto HoPerThreadSubC = 2; - constexpr auto WoPerThreadSubC = 2; - - static_assert(KPerThread % KPerThreadSubC == 0, ""); - static_assert(HPerThread % HoPerThreadSubC == 0, ""); - static_assert(WPerThread % WoPerThreadSubC == 0, ""); - - // thread A buffer for GEMM - StaticBuffer - a_thread_buf; - - constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3{}; - - static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) { - static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) { - a_thread_copy_.Run(a_block_mtx, - make_tuple(e_begin, k_begin), - a_block_buf, - a_thread_mtx_, - make_tuple(I0, I0), - a_thread_buf); - - static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) { - static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) { - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0), - b_thread_buf, - make_tuple(e_begin, I0, h_begin, w_begin), - c_thread_buf, - make_tuple(k_begin, I0, h_begin, w_begin)); - }); - }); - }); - }); - } - - template - __device__ void MoveASliceWindow(const BlockMatrixA&, - const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx) - { - a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx); - } - - private: - MatrixIndex c_thread_begin_mtx_idx_; - - AThreadCopy a_thread_copy_; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp deleted file mode 100644 index 36c67832042..00000000000 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ /dev/null @@ -1,282 +0,0 @@ -#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP -#define CK_BLOCKWISE_GEMM_XDLOPS_HPP - -#include "common_header.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "xdlops_gemm.hpp" -#include "tensor_adaptor.hpp" - -namespace ck { - -template -struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr index_t WaveSize = 64; - - static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); - static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - - static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0); - - static constexpr auto xdlops_gemm = XdlopsGemm{}; - - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - - StaticBufferV2, MRepeat * NRepeat, true> - c_thread_buf_; - - __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } - - __device__ static auto GetWaveIdx() - { - const index_t thread_id = get_thread_local_1d_id(); - - constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); - } - - __device__ static auto CalculateAThreadOriginDataIndex() - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - - const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); - - return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0); - } - - __device__ static auto CalculateBThreadOriginDataIndex() - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_n = wave_idx[I1]; - - const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); - - return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0); - } - - template - __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - const auto waveId_n = wave_idx[I1]; - - const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); - - constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); - - constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); - - const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( - make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; - const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( - make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; - - return make_tuple(c_thread_m, c_thread_n); - } - - __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() - { - static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && - BK0NK1BlockDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0), - "wrong! K0 dimension not consistent"); - - static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2), - "wrong! K1 dimension not consistent"); - - static_assert(BlockSize == MWaves * NWaves * WaveSize, - "BlockSize != MWaves * NWaves * WaveSize\n"); - - static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, - "wrong!"); - } - - __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor() - { - constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); - - constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; - constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; - constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; - constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; - - return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N)); - } - - __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor() - { - constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - Number{})); - - return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc); - } - - template - __host__ __device__ static constexpr auto - MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) - { - const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)), - make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); - - return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc); - } - - __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor() - { - return transform_tensor_descriptor( - AK0MK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - - __host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor() - { - return transform_tensor_descriptor( - BK0NK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - - static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor(); - static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor(); - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc, - make_tuple(I0, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0, I0), - a_thread_buf); - - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, - make_tuple(I0, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0, I0), - b_thread_buf); - - static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, K1, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0)); - - xdlops_gemm.template Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVector(Number{})); - }); - }); - }); - } - - private: - // A[K, M] - static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); - - // B[K, N] - static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); - - static constexpr auto c_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - K1, - K1>; - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - K1, - K1>; - - AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; - BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp deleted file mode 100644 index 0214b713522..00000000000 --- a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp +++ /dev/null @@ -1,170 +0,0 @@ -#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP -#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "cluster_descriptor.hpp" -#include "threadwise_tensor_slice_transfer.hpp" - -namespace ck { - -// this version does following things to avoid scratch memory issue -// 1. Use StaticallyIndexedArray instead of C array for thread buffer -// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor -// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate -template -struct BlockwiseTensorSliceTransfer_v4 -{ - static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); - - using Index = MultiIndex; - - __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc, - const Index& src_block_slice_origin, - const DstDesc& dst_desc, - const Index& dst_block_slice_origin) - : threadwise_transfer_( - src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) - - { - static_assert(nDim == remove_reference_t>::GetNumOfDimension() && - nDim == remove_reference_t>::GetNumOfDimension() && - nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && - nDim == ThreadClusterLengths::Size() && - nDim == ThreadClusterArrangeOrder::Size() && - nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), - "wrong! nDim not consistent"); - - static_assert( - is_same{}, - "wrong! threads should be mapped to cover entire slicing window"); - - static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), - "wrong! BlockSize too small"); - - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index(get_thread_local_1d_id())); - - const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; - - threadwise_transfer_.SetSrcSliceOrigin(src_desc, - src_block_slice_origin + thread_data_idx_begin); - threadwise_transfer_.SetDstSliceOrigin(dst_desc, - dst_block_slice_origin + thread_data_idx_begin); - } - } - - template - __device__ void - RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); - } - } - - template - __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.RunRead(src_desc, src_buf); - } - } - - template - __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.RunWrite(dst_desc, dst_buf); - } - } - - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); - } - } - - // SrcMoveSliceWindowStepHack to control index calculation move slice window - template - __device__ void - MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& step, - const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.MoveSrcSliceWindow( - src_desc, step, src_move_slice_window_step_hack); - } - } - - __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); - } - } - - private: - static constexpr auto thread_cluster_desc_ = - make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); - - using ThreadwiseTransfer = - ThreadwiseTensorSliceTransfer_v3; - - ThreadwiseTransfer threadwise_transfer_; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp deleted file mode 100644 index 2653dd43401..00000000000 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp +++ /dev/null @@ -1,650 +0,0 @@ -#ifndef CK_GRIDWISE_GEMM_V1R3_HPP -#define CK_GRIDWISE_GEMM_V1R3_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_dlops_v2r3.hpp" -#include "blockwise_tensor_slice_transfer_v2.hpp" -#include "threadwise_tensor_slice_transfer_v2.hpp" -#include "threadwise_tensor_slice_set.hpp" - -namespace ck { - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v1r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc, - const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc, - const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -// pass tensor descriptor by CONSTANT void pointer -// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to -// non-modifiable parameter address space, so compiler can enable corresponding optimization -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc, - const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc, - const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - // first cast void CONSTANT void* to void* - // second cast void* to Desc* - // the copy constructor of tensor descriptor doesn't take address_space(4) - const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc)); - const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc)); - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc)); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor)); - - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} -#endif - -template -struct GridwiseGemmDlops_km_kn_mn_v1r3 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - // K1 should be Number<...> - static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // TODO: change this. I think it needs multi-dimensional alignment - constexpr auto max_lds_align = K1; - - // TODO: check alignment - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - - // TODO: check alignment - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - - // TODO: check alignment - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = - math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = - math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); - - return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); - } - - __host__ __device__ static constexpr bool - CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - - return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_k0_n_k1_grid_desc.GetLength(I0) && - K1 == a_k0_m_k1_grid_desc.GetLength(I2) && - K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && - (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); - } - - __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) - { - const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) - { - const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1; - - return has_main_k_block_loop; - } - - __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) - { - const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0; - - return has_double_tail_k_block_loop; - } - - __host__ __device__ static constexpr auto - MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc) - { - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - - const auto M1 = Number{}; - const auto M0 = M / M1; - - const auto a_k0_m0_m1_k1_grid_desc = - transform_tensor_descriptor(a_k0_m_k1_grid_desc, - make_tuple(make_pass_through_transform(K0), - make_unmerge_transform(make_tuple(M0, M1)), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - return a_k0_m0_m1_k1_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc) - { - const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - - const auto N1 = Number{}; - const auto N0 = N / N1; - - const auto b_k0_n0_n1_k1_grid_desc = - transform_tensor_descriptor(b_k0_n_k1_grid_desc, - make_tuple(make_pass_through_transform(K0), - make_unmerge_transform(make_tuple(N0, N1)), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - return b_k0_n0_n1_k1_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - constexpr auto M11 = - Number{}; - constexpr auto N11 = - Number{}; - - constexpr auto M10 = M1 / M11; - constexpr auto N10 = N1 / N11; - - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), - make_unmerge_transform(make_tuple(N0, N10, N11))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - - return c_m0_m10_m11_n0_n10_n11_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); - - return c_blockid_to_m0_n0_block_cluster_adaptor; - } - - using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{})); - using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{})); - using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, - const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc, - const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc, - const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor, - integral_constant, - integral_constant) - { - const auto a_global_buf = make_dynamic_buffer( - p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); - - // divide block work by [M, N] - const auto c_m0_n0_block_cluster_idx = - c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( - make_multi_index(get_block_1d_id())); - - // HACK: this force index data into SGPR - const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); - const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); - - // TODO: change this. I think it needs multi-dimensional alignment - constexpr auto max_lds_align = K1; - - // TODO: check alignment - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}, K1), max_lds_align); - - // TODO: check alignment - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}, K1), max_lds_align); - - // TODO: check alignment - // A matrix in LDS memory, for blockwise GEMM - constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - - // TODO: check alignment - // B matrix in LDS memory, for blockwise GEMM - constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - - static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == - a_k0_m_k1_block_desc.GetElementSpaceSize() && - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() == - b_k0_n_k1_block_desc.GetElementSpaceSize() && - "wrong!"); - - // A matrix blockwise copy - auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< - BlockSize, - InMemoryDataOperationEnum_t::Set, - Sequence, - ABlockTransferThreadSliceLengths_K0_M0_M1_K1, - ABlockTransferThreadClusterLengths_K0_M0_M1_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_k0_m0_m1_k1_grid_desc), - decltype(a_k0_m0_m1_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2, 3>, - ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths - ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths - ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder - Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder - false, - true>(a_k0_m0_m1_k1_grid_desc, - make_multi_index(0, im0, 0, 0), - a_k0_m0_m1_k1_block_desc, - make_multi_index(0, 0, 0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< - BlockSize, - InMemoryDataOperationEnum_t::Set, - Sequence, - BBlockTransferThreadSliceLengths_K0_N0_N1_K1, - BBlockTransferThreadClusterLengths_K0_N0_N1_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_k0_n0_n1_k1_grid_desc), - decltype(b_k0_n0_n1_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2, 3>, - BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths - BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths - BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder - Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder - false, - true>(b_k0_n0_n1_k1_grid_desc, - make_multi_index(0, in0, 0, 0), - b_k0_n0_n1_k1_block_desc, - make_multi_index(0, 0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlockM1] is in LDS - // b_mtx[KPerBlocl, NPerBlockN1] is in LDS - // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in - // register - const auto blockwise_gemm = - BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - decltype(a_k0_m_k1_block_desc), - decltype(b_k0_n_k1_block_desc), - M1PerThreadM111, - N1PerThreadN111, - KPerThread, - M11N11ThreadClusterM110Xs, - M11N11ThreadClusterN110Xs, - M1PerThreadM111, - N1PerThreadN111>{}; - - constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = - decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); - - constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( - sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = math::integer_least_multiple( - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block_double = p_shared_block; - FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; - - // register allocation for output - auto c_thread_buf = make_static_buffer( - c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); - - ThreadwiseTensorSliceSet_v1{} - .Run(c_m10_m11_n10_n11_thread_desc, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - FloatAcc{0}); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); - - auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); - - auto a_block_odd_buf = make_dynamic_buffer( - p_a_block_double + a_block_aligned_space_size, - a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( - p_b_block_double + b_block_aligned_space_size, - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); - - // LDS double buffer: preload data into LDS - { - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); - - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); - } - - if constexpr(HasMainKBlockLoop) - { - const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0); - - index_t k_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, - a_block_even_buf, - b_block_even_buf, - c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); - - // odd iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); - - k_block_data_begin += 2 * KPerBlock; - } while(k_block_data_begin < K0 - 2 * KPerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left - { - a_blockwise_copy.MoveSrcSliceWindow( - a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow( - b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); - - // LDS double buffer: store last data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - Number{}, - I1, - Number{}, - Number{})); - - const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = - blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( - get_thread_local_1d_id()); - - ThreadwiseTensorSliceTransfer_v1r3< - FloatAcc, - FloatC, - decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), - decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), - Sequence<1, - c_m10_m11_n10_n11_thread_tensor_lengths[I0], - c_m10_m11_n10_n11_thread_tensor_lengths[I1], - 1, - c_m10_m11_n10_n11_thread_tensor_lengths[I2], - c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{c_m0_m10_m11_n0_n10_n11_grid_desc, - make_multi_index(im0, - c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], - c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], - in0, - c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], - c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} - .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_grid_buf, - CGridStepHacks{}); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp deleted file mode 100644 index 86e047c965a..00000000000 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ /dev/null @@ -1,639 +0,0 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" - -namespace ck { - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AK0MK1GridDesc a_k0_m_k1_grid_desc, - const BK0NK1GridDesc b_k0_n_k1_grid_desc, - const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const CBlockClusterAdaptor c_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); -} -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k0_m_k1_grid_desc, - const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const void CONSTANT* p_c_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - const auto a_k0_m_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc)); - const auto b_k0_n_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc)); - const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); - const auto c_block_cluster_adaptor = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); -} -#endif - -template -struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto K1 = Number{}; - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k0_m_k1_block_desc = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k0_n_k1_block_desc = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ __device__ static constexpr bool - CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - index_t M01, - index_t N01) - { - static_assert(is_known_at_compile_time>::value, - "wrong! K1 need to be known at compile-time"); - - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, - "Invalid tuning param!"); - - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) && - K1 == b_k0_n_k1_grid_desc.GetLength(I2))) - return false; - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) - return false; - - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) - return false; - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) - { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; - - return has_main_k0_block_loop; - } - - __host__ __device__ static constexpr auto - MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k0_m_k1_block_desc = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k0_n_k1_block_desc = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - using BlockwiseGemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - - return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); - } - - // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return c_blockid_to_m0_n0_block_cluster_adaptor; - } - - using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1)); - - template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, - const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const CBlockClusterAdaptor& c_block_cluster_adaptor) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); - - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - // divide block work by [M, N] - const auto block_work_idx = - c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k0_m_k1_block_desc = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k0_n_k1_block_desc = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K0_M_K1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_k0_m_k1_grid_desc), - decltype(a_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>(a_k0_m_k1_grid_desc, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_k0_m_k1_block_desc, - make_multi_index(0, 0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_k0_n_k1_grid_desc), - decltype(b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_k0_n_k1_grid_desc, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_k0_n_k1_block_desc, - make_multi_index(0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[K0PerBlock, MPerBlock] is in LDS - // b_mtx[K0PerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - - auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block = p_shared_block; - FloatAB* p_b_block = p_shared_block + a_block_space_size; - - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; - constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; - - auto a_block_buf = make_dynamic_buffer( - p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( - p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); - - // preload data into LDS - { - a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); - b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); - - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); - } - - // main body - index_t k0_block_data_begin = 0; - - if constexpr(HasMainKBlockLoop) - { - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, - a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, - b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_step_hack); - - a_blockwise_copy.RunRead( - a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); - - block_sync_lds(); - - b_blockwise_copy.RunRead( - b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = - blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); - - constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); - constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); - constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); - constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); - constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); - constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); - constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); - constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; - - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_grid_idx = - m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_grid)); - - const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_grid_idx = - n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_grid)); - - auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r3, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - make_multi_index(m_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2])}; - - c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); - } - } -}; // namespace ck - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp deleted file mode 100644 index c635da57f4d..00000000000 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp +++ /dev/null @@ -1,625 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_BLOCKWISE_HPP -#define CK_GRIDWISE_GENERIC_2D_REDUCTION_BLOCKWISE_HPP - -#include "data_type.hpp" -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_blockwise.hpp" - -#include "blockwise_tensor_slice_transfer.hpp" - -namespace ck { - -template -struct GridwiseReduction_xy_to_x_blockwise -{ - using opReduce = typename reduce_binary_operator::opType; - using preUnaryOpType = - typename reduce_unary_operator::preUnaryOp; - using posUnaryOpType = - typename reduce_unary_operator::posUnaryOp; - - static constexpr auto buffer2dDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - using blockwise_reduce = - BlockwiseReduction_2d_block_buffer; - - static constexpr index_t BlockBufferSize = buffer2dDesc.GetElementSize(); - - static constexpr auto I0 = Number<0>{}; - - template - __device__ static void Run(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global); - - template <> - __device__ static void Run<1>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)ws_indices_global; - (void)indices_global; - - // LDS - __shared__ compType p_in_block_buffer[BlockBufferSize]; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto dst_global_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - - auto in_block_buf = - make_dynamic_buffer(p_in_block_buffer, BlockBufferSize); - StaticBuffer accuValue_buf; - - accuValue_buf(I0) = zeroVal; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - const posUnaryOpType posUnaryOp(divider); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_1d_id = get_block_1d_id(); - - constexpr auto in_block_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number{})); - - using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>; - using ThreadClusterLengths = Sequence<1, BlockSize>; - - auto blockwise_src_load = - BlockwiseTensorSliceTransfer_v4, - ThreadSliceLengths, - ThreadClusterLengths, - Sequence<0, 1>, - srcDataType, - compType, - src2dDescType, - decltype(in_block_desc), - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - 1, - 1, - 1, - 1, - false, - true>(src2dDesc, - make_multi_index(block_global_1d_id, 0), - in_block_desc, - make_multi_index(0, 0)); - - constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize); - - const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize; - - for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks; - reducedBlocks += GredAccessesPerThreadInBlock) - { - blockwise_src_load.RunRead(src2dDesc, src_global_buf); - blockwise_src_load.RunWrite(in_block_desc, in_block_buf); - - __syncthreads(); - - // do element-wise pre-reduction operation - blockwise_reduce::operate_on_elements(preUnaryOp, in_block_buf); - - index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock) - ? GredAccessesPerThreadInBlock - : toReduceBlocks - reducedBlocks; - blockwise_reduce::Reduce(in_block_buf, BlocksInOneOp, accuValue_buf(I0)); - - blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step); - } - - accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]); - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - // The first thread in the block stores the reduced result to the global location - // representing the block - if(thread_local_id == 0) - { - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - false>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run( - dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - threadwise_dst_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf); - } - }; - - template <> - __device__ static void Run<2>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)ws_indices_global; - - // LDS - __shared__ compType p_in_block_buffer[BlockBufferSize]; - __shared__ int block_indices_buffer[BlockBufferSize]; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto dst_global_val_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - auto dst_global_idx_buf = make_dynamic_buffer( - indices_global, dst1dDesc.GetElementSpaceSize()); - - auto in_block_val_buf = - make_dynamic_buffer(p_in_block_buffer, BlockBufferSize); - auto in_block_idx_buf = - make_dynamic_buffer(block_indices_buffer, BlockBufferSize); - - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_1d_id = get_block_1d_id(); - - constexpr auto in_block_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number{})); - - using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>; - using ThreadClusterLengths = Sequence<1, BlockSize>; - - auto blockwise_src_load = - BlockwiseTensorSliceTransfer_v4, - ThreadSliceLengths, - ThreadClusterLengths, - Sequence<0, 1>, - srcDataType, - compType, - src2dDescType, - decltype(in_block_desc), - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - 1, - 1, - 1, - 1, - false, - true>(src2dDesc, - make_multi_index(block_global_1d_id, 0), - in_block_desc, - make_multi_index(0, 0)); - - constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize); - - const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize; - - int indexOffset = 0; - - for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks; - reducedBlocks += GredAccessesPerThreadInBlock) - { - // load block data from global to LDS, no use of double buffers (to be improved) - blockwise_src_load.RunRead(src2dDesc, src_global_buf); - blockwise_src_load.RunWrite(in_block_desc, in_block_val_buf); - - __syncthreads(); - - // construct the indices for the current toReduce blocks - blockwise_reduce::init_buffer_indices(in_block_idx_buf, indexOffset); - - // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually - // done here - blockwise_reduce::operate_on_elements(preUnaryOp, in_block_val_buf); - - index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock) - ? GredAccessesPerThreadInBlock - : toReduceBlocks - reducedBlocks; - - blockwise_reduce::Reduce2(in_block_val_buf, - in_block_idx_buf, - BlocksInOneOp, - accuValue_buf(I0), - accuIndex_buf(I0)); - - indexOffset += BlockBufferSize; - - blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - // The first thread in the block stores the reduced result to the global location - // representing the block - if(thread_local_id == 0) - { - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - false>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run(dst1dDesc, - dst_global_val_buf, - ReducedDataDesc, - make_tuple(I0), - priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - threadwise_dst_val_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf); - threadwise_dst_idx_store.Run( - ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); - } - }; - - template <> - __device__ static void Run<3>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ ws_values_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)origReduceLen; - - // LDS - __shared__ compType p_in_block_buffer[BlockBufferSize]; - __shared__ int block_indices_buffer[BlockBufferSize]; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_val_buf = - make_dynamic_buffer(ws_values_global, - src2dDesc.GetElementSpaceSize(), - type_convert{}(zeroVal)); - const auto src_global_idx_buf = make_dynamic_buffer( - ws_indices_global, src2dDesc.GetElementSpaceSize()); - auto dst_global_val_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - auto dst_global_idx_buf = make_dynamic_buffer( - indices_global, dst1dDesc.GetElementSpaceSize()); - - auto in_block_val_buf = - make_dynamic_buffer(p_in_block_buffer, BlockBufferSize); - auto in_block_idx_buf = - make_dynamic_buffer(block_indices_buffer, BlockBufferSize); - - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_1d_id = get_block_1d_id(); - - constexpr auto in_block_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number{})); - - using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>; - using ThreadClusterLengths = Sequence<1, BlockSize>; - - auto blockwise_src_val_load = - BlockwiseTensorSliceTransfer_v4, - ThreadSliceLengths, - ThreadClusterLengths, - Sequence<0, 1>, - srcDataType, - compType, - src2dDescType, - decltype(in_block_desc), - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - 1, - 1, - 1, - 1, - false, - true>(src2dDesc, - make_multi_index(block_global_1d_id, 0), - in_block_desc, - make_multi_index(0, 0)); - - auto blockwise_src_idx_load = - BlockwiseTensorSliceTransfer_v4, - ThreadSliceLengths, - ThreadClusterLengths, - Sequence<0, 1>, - int, - int, - src2dDescType, - decltype(in_block_desc), - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - 1, - 1, - 1, - 1, - false, - true>(src2dDesc, - make_multi_index(block_global_1d_id, 0), - in_block_desc, - make_multi_index(0, 0)); - - constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize); - - const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize; - - for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks; - reducedBlocks += GredAccessesPerThreadInBlock) - { - // load block data from global to LDS, no use of double buffers (to be improved) - blockwise_src_val_load.RunRead(src2dDesc, src_global_val_buf); - blockwise_src_idx_load.RunRead(src2dDesc, src_global_idx_buf); - blockwise_src_val_load.RunWrite(in_block_desc, in_block_val_buf); - blockwise_src_idx_load.RunWrite(in_block_desc, in_block_idx_buf); - - __syncthreads(); - - index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock) - ? GredAccessesPerThreadInBlock - : toReduceBlocks - reducedBlocks; - - blockwise_reduce::Reduce2(in_block_val_buf, - in_block_idx_buf, - BlocksInOneOp, - accuValue_buf(I0), - accuIndex_buf(I0)); - - blockwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step); - blockwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - // The first thread in the block stores the reduced result to the global location - // representing the block - if(thread_local_id == 0) - { - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - true>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run(dst1dDesc, - dst_global_val_buf, - ReducedDataDesc, - make_tuple(I0), - priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - threadwise_dst_val_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf); - threadwise_dst_idx_store.Run( - ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); - } - }; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp deleted file mode 100644 index adfeacc0374..00000000000 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp +++ /dev/null @@ -1,503 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_THREADWISE_HPP -#define CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_THREADWISE_HPP - -#include "data_type.hpp" -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_threadwise.hpp" - -#include "threadwise_tensor_slice_transfer.hpp" - -namespace ck { - -template -struct GridwiseReduction_xy_to_x_direct_threadwise -{ - using opReduce = typename reduce_binary_operator::opType; - using preUnaryOpType = - typename reduce_unary_operator::preUnaryOp; - using posUnaryOpType = - typename reduce_unary_operator::posUnaryOp; - - static constexpr auto I0 = Number<0>{}; - - template - __device__ static void Run(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global); - - template <> - __device__ static void Run<1>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)ws_indices_global; - (void)indices_global; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto dst_global_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - - StaticBuffer - in_thread_buf; - - using threadwise_reduce = ThreadReduce; - - StaticBuffer accuValue_buf; - - accuValue_buf(I0) = zeroVal; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - const posUnaryOpType posUnaryOp(divider); - - using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>; - constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, make_multi_index(thread_global_1d_id, 0)); - - constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength); - - for(index_t reducedLength = 0; reducedLength < toReduceLength; - reducedLength += GredThreadBufferLength) - { - threadwise_src_load.Run( - src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf); - - // do element-wise pre-reduction operation - threadwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf); - - // do the reduction on the Thread Buffer - threadwise_reduce::Reduce(in_thread_buf, accuValue_buf(I0)); - - threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - } - - accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]); - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - true>( - dst1dDesc, make_multi_index(thread_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run( - dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(thread_global_1d_id)); - - threadwise_dst_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf); - }; - - template <> - __device__ static void Run<2>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)ws_indices_global; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto dst_global_val_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - auto dst_global_idx_buf = make_dynamic_buffer( - indices_global, dst1dDesc.GetElementSpaceSize()); - - StaticBuffer - in_thread_buf; - - using threadwise_reduce = ThreadReduce; - - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - - using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>; - constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, make_multi_index(thread_global_1d_id, 0)); - - constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength); - - index_t indexStart = 0; - for(index_t reducedLength = 0; reducedLength < toReduceLength; - reducedLength += GredThreadBufferLength) - { - threadwise_src_load.Run( - src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf); - - // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually - // done here - threadwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf); - - // do the reduction on the Thread Buffer - threadwise_reduce::Reduce2( - in_thread_buf, accuValue_buf(I0), accuIndex_buf(I0), indexStart); - - indexStart += GredThreadBufferLength; - - threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - false>( - dst1dDesc, make_multi_index(thread_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run( - dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(thread_global_1d_id)); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(thread_global_1d_id)); - - threadwise_dst_val_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf); - threadwise_dst_idx_store.Run( - ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); - }; - - template <> - __device__ static void Run<3>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ ws_values_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)origReduceLen; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_val_buf = - make_dynamic_buffer(ws_values_global, - src2dDesc.GetElementSpaceSize(), - type_convert{}(zeroVal)); - const auto src_global_idx_buf = make_dynamic_buffer( - ws_indices_global, src2dDesc.GetElementSpaceSize()); - auto dst_global_val_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - auto dst_global_idx_buf = make_dynamic_buffer( - indices_global, dst1dDesc.GetElementSpaceSize()); - - StaticBuffer - in_thread_val_buf; - StaticBuffer in_thread_idx_buf; - - using threadwise_reduce = ThreadReduceWithIndicesInput; - - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - - using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>; - constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - - auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, make_multi_index(thread_global_1d_id, 0)); - - auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, make_multi_index(thread_global_1d_id, 0)); - - constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength); - - for(index_t reducedLength = 0; reducedLength < toReduceLength; - reducedLength += GredThreadBufferLength) - { - threadwise_src_val_load.Run(src2dDesc, - src_global_val_buf, - ThreadBufferDesc, - make_tuple(I0, I0), - in_thread_val_buf); - threadwise_src_idx_load.Run(src2dDesc, - src_global_idx_buf, - ThreadBufferDesc, - make_tuple(I0, I0), - in_thread_idx_buf); - - // do the reduction on the Thread Buffer - threadwise_reduce::Reduce( - in_thread_val_buf, in_thread_idx_buf, accuValue_buf(I0), accuIndex_buf(I0)); - - threadwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - threadwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - false>( - dst1dDesc, make_multi_index(thread_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run( - dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(thread_global_1d_id)); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(thread_global_1d_id)); - - threadwise_dst_val_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf); - threadwise_dst_idx_store.Run( - ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); - }; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp deleted file mode 100644 index 4136dae75ff..00000000000 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp +++ /dev/null @@ -1,544 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_WARPWISE_HPP -#define CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_WARPWISE_HPP - -#include "data_type.hpp" -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_warpwise.hpp" - -#include "threadwise_tensor_slice_transfer.hpp" - -namespace ck { - -template -struct GridwiseReduction_xy_to_x_direct_warpwise -{ - using opReduce = typename reduce_binary_operator::opType; - using preUnaryOpType = - typename reduce_unary_operator::preUnaryOp; - using posUnaryOpType = - typename reduce_unary_operator::posUnaryOp; - - static constexpr auto I0 = Number<0>{}; - - template - __device__ static void Run(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global); - - template <> - __device__ static void Run<1>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)ws_indices_global; - (void)indices_global; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto dst_global_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - - StaticBuffer - in_thread_buf; - - using warpwise_reduce = - WarpReduce; - - StaticBuffer accuValue_buf; - - accuValue_buf(I0) = zeroVal; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - const posUnaryOpType posUnaryOp(divider); - - using ThreadBufferLengths = Sequence<1, GredAccessesPerThreadInWarp>; - constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - index_t warp_global_1d_id = thread_global_1d_id / warpSize; - index_t thread_inwarp_id = thread_global_1d_id % warpSize; - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, - make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp)); - - constexpr auto in_thread_copy_step = - make_multi_index(0, warpSize * GredAccessesPerThreadInWarp); - - for(index_t reducedLength = 0; reducedLength < toReduceLength; - reducedLength += warpSize * GredAccessesPerThreadInWarp) - { - threadwise_src_load.Run( - src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf); - - // do element-wise pre-reduction operation - warpwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf); - - // do the warp-wise reduction on data of all thread buffers - warpwise_reduce::Reduce(in_thread_buf, accuValue_buf(I0)); - - threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - } - - accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]); - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - // The first thread in the warp stores the reduced result to the global location - // representing the Warp - if(thread_inwarp_id == 0) - { - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run( - dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf(I0) * beta; - } - - auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - threadwise_dst_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf); - } - }; - - template <> - __device__ static void Run<2>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)ws_indices_global; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto dst_global_val_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - auto dst_global_idx_buf = make_dynamic_buffer( - indices_global, dst1dDesc.GetElementSpaceSize()); - - StaticBuffer - in_thread_buf; - - using warpwise_reduce = - WarpReduce; - - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - - using ThreadBufferLengths = Sequence<1, GredAccessesPerThreadInWarp>; - constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - index_t warp_global_1d_id = thread_global_1d_id / warpSize; - index_t thread_inwarp_id = thread_global_1d_id % warpSize; - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, - make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp)); - - constexpr auto in_thread_copy_step = - make_multi_index(0, warpSize * GredAccessesPerThreadInWarp); - - index_t indexOffset = 0; - for(index_t reducedLength = 0; reducedLength < toReduceLength; - reducedLength += warpSize * GredAccessesPerThreadInWarp) - { - threadwise_src_load.Run( - src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf); - - // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually - // done here - warpwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf); - - // do the warp-wise reduction on data of all thread buffers - warpwise_reduce::Reduce2( - in_thread_buf, accuValue_buf(I0), accuIndex_buf(I0), indexOffset); - - indexOffset += warpSize * GredAccessesPerThreadInWarp; - - threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - // The first thread in the warp stores the reduced result to the global location - // representing the Warp - if(thread_inwarp_id == 0) - { - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run(dst1dDesc, - dst_global_val_buf, - ReducedDataDesc, - make_tuple(I0), - priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - threadwise_dst_val_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf); - threadwise_dst_idx_store.Run( - ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); - } - }; - - template <> - __device__ static void Run<3>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ ws_values_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)origReduceLen; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_val_buf = - make_dynamic_buffer(ws_values_global, - src2dDesc.GetElementSpaceSize(), - type_convert{}(zeroVal)); - const auto src_global_idx_buf = make_dynamic_buffer( - ws_indices_global, src2dDesc.GetElementSpaceSize()); - auto dst_global_val_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - auto dst_global_idx_buf = make_dynamic_buffer( - indices_global, dst1dDesc.GetElementSpaceSize()); - - StaticBuffer - in_thread_val_buf; - StaticBuffer - in_thread_idx_buf; - - using warpwise_reduce = WarpReduceWithIndicesInput; - - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - - using ThreadBufferLengths = Sequence<1, GredAccessesPerThreadInWarp>; - constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - index_t warp_global_1d_id = thread_global_1d_id / warpSize; - index_t thread_inwarp_id = thread_global_1d_id % warpSize; - - auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, - make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp)); - - auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, - make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp)); - - constexpr auto in_thread_copy_step = - make_multi_index(0, warpSize * GredAccessesPerThreadInWarp); - - for(index_t reducedLength = 0; reducedLength < toReduceLength; - reducedLength += warpSize * GredAccessesPerThreadInWarp) - { - threadwise_src_val_load.Run(src2dDesc, - src_global_val_buf, - ThreadBufferDesc, - make_tuple(I0, I0), - in_thread_val_buf); - threadwise_src_idx_load.Run(src2dDesc, - src_global_idx_buf, - ThreadBufferDesc, - make_tuple(I0, I0), - in_thread_idx_buf); - - // do the warp-wise reduction on data of all thread buffers - warpwise_reduce::Reduce( - in_thread_val_buf, in_thread_idx_buf, accuValue_buf(I0), accuIndex_buf(I0)); - - threadwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - threadwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - // The first thread in the warp stores the reduced result to the global location - // representing the Warp - if(thread_inwarp_id == 0) - { - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run(dst1dDesc, - dst_global_val_buf, - ReducedDataDesc, - make_tuple(I0), - priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - threadwise_dst_val_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf); - threadwise_dst_idx_store.Run( - ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); - } - }; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp deleted file mode 100644 index feee2b594a3..00000000000 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp +++ /dev/null @@ -1,376 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_MULTIBLOCK_HPP -#define CK_GRIDWISE_GENERIC_2D_REDUCTION_MULTIBLOCK_HPP - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_blockwise.hpp" - -#include "blockwise_tensor_slice_transfer.hpp" - -namespace ck { - -template -struct GridwiseReduction_xy_to_x_multiblock -{ - using opReduce = typename reduce_binary_operator::opType; - using preUnaryOpType = typename reduce_unary_operator::preUnaryOp; - using posUnaryOpType = typename reduce_unary_operator::posUnaryOp; - - static constexpr auto buffer2dDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - using blockwise_reduce = - BlockwiseReduction_2d_block_buffer; - - static constexpr index_t BlockBufferSize = buffer2dDesc.GetElementSize(); - - static constexpr auto I0 = Number<0>{}; - - template - __device__ static void Run(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - int BlkGroupSize, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - srcDataType* const __restrict__ ws_values_global, - int* const __restrict__ ws_indices_global); - - template <> - __device__ static void Run<1>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - int BlkGroupSize, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - srcDataType* const __restrict__ ws_values_global, - int* const __restrict__ ws_indices_global) - { - (void)ws_indices_global; - - (void)alpha; // unused - (void)beta; // unused - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - // LDS - __shared__ compType p_in_block_buffer[BlockBufferSize]; - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto workspace_global_buf = make_dynamic_buffer( - ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize); - - auto in_block_buf = - make_dynamic_buffer(p_in_block_buffer, BlockBufferSize); - StaticBuffer accuValue_buf; - - accuValue_buf(I0) = zeroVal; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_id = get_block_1d_id(); - const index_t blkgroup_id = block_global_id / BlkGroupSize; - const index_t block_local_id = block_global_id % BlkGroupSize; - - const index_t reduceSizePerBlock = - (((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) / - BlockBufferSize) * - BlockBufferSize; - - constexpr auto in_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>; - using ThreadClusterLengths = Sequence<1, BlockSize>; - - auto blockwise_src_load = BlockwiseTensorSliceTransfer_v4, - ThreadSliceLengths, - ThreadClusterLengths, - Sequence<0, 1>, - srcDataType, - compType, - src2dDescType, - decltype(in_block_desc), - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - 1, - 1, - 1, - 1, - false, - true>( - src2dDesc, - make_multi_index(blkgroup_id, block_local_id * reduceSizePerBlock), - in_block_desc, - make_multi_index(0, 0)); - - constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize); - - const index_t toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize; - - for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks; - reducedBlocks += GredAccessesPerThreadInBlock) - { - blockwise_src_load.RunRead(src2dDesc, src_global_buf); - blockwise_src_load.RunWrite(in_block_desc, in_block_buf); - __syncthreads(); - - // do element-wise pre-reduction operation - blockwise_reduce::operate_on_elements(preUnaryOp, in_block_buf); - - index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock) - ? GredAccessesPerThreadInBlock - : toReduceBlocks - reducedBlocks; - blockwise_reduce::Reduce(in_block_buf, BlocksInOneOp, accuValue_buf(I0)); - - blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - const auto workspace_desc = - make_naive_tensor_descriptor_packed(make_tuple(dst1dDesc.GetLength(I0) * BlkGroupSize)); - - // The first thread in the block stores the reduced result to the global location - // representing the block - if(thread_local_id == 0) - { - auto threadwise_workspace_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(workspace_desc, - make_multi_index(block_global_id)); - - threadwise_workspace_store.Run(ReducedDataDesc, - make_tuple(I0), - accuValue_buf, - workspace_desc, - workspace_global_buf); - } - }; - - template <> - __device__ static void Run<2>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - int BlkGroupSize, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - srcDataType* const __restrict__ ws_values_global, - int* const __restrict__ ws_indices_global) - { - (void)alpha; // unused - (void)beta; // unused - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - // LDS - __shared__ compType p_in_block_values_buffer[BlockBufferSize]; - __shared__ int p_in_block_indices_buffer[BlockBufferSize]; - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto workspace_global_val_buf = make_dynamic_buffer( - ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize); - auto workspace_global_idx_buf = make_dynamic_buffer( - ws_indices_global, dst1dDesc.GetLength(I0) * BlkGroupSize); - - auto in_block_val_buf = - make_dynamic_buffer(p_in_block_values_buffer, BlockBufferSize); - auto in_block_idx_buf = make_dynamic_buffer( - p_in_block_indices_buffer, BlockBufferSize); - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_id = get_block_1d_id(); - const index_t blkgroup_id = block_global_id / BlkGroupSize; - const index_t block_local_id = block_global_id % BlkGroupSize; - - const index_t reduceSizePerBlock = - (((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) / - BlockBufferSize) * - BlockBufferSize; - - constexpr auto in_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>; - using ThreadClusterLengths = Sequence<1, BlockSize>; - - auto blockwise_src_load = BlockwiseTensorSliceTransfer_v4, - ThreadSliceLengths, - ThreadClusterLengths, - Sequence<0, 1>, - srcDataType, - compType, - src2dDescType, - decltype(in_block_desc), - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - 1, - 1, - 1, - 1, - false, - true>( - src2dDesc, - make_multi_index(blkgroup_id, block_local_id * reduceSizePerBlock), - in_block_desc, - make_multi_index(0, 0)); - - constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize); - - const index_t toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize; - - int indexOffset = block_local_id * reduceSizePerBlock; - - for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks; - reducedBlocks += GredAccessesPerThreadInBlock) - { - blockwise_reduce::init_buffer_indices(in_block_idx_buf, indexOffset); - - blockwise_src_load.RunRead(src2dDesc, src_global_buf); - blockwise_src_load.RunWrite(in_block_desc, in_block_val_buf); - - __syncthreads(); - - // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually - // done here - blockwise_reduce::operate_on_elements(preUnaryOp, in_block_val_buf); - - index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock) - ? GredAccessesPerThreadInBlock - : toReduceBlocks - reducedBlocks; - - blockwise_reduce::Reduce2(in_block_val_buf, - in_block_idx_buf, - BlocksInOneOp, - accuValue_buf(I0), - accuIndex_buf(I0)); - - indexOffset += BlockBufferSize; - - blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - const auto workspace_desc = - make_naive_tensor_descriptor_packed(make_tuple(dst1dDesc.GetLength(I0) * BlkGroupSize)); - - // The first thread in the block stores the reduced result to the global location - // representing the block - if(thread_local_id == 0) - { - auto threadwise_workspace_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(workspace_desc, - make_multi_index(block_global_id)); - - auto threadwise_workspace_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(workspace_desc, - make_multi_index(block_global_id)); - - threadwise_workspace_val_store.Run(ReducedDataDesc, - make_tuple(I0), - accuValue_buf, - workspace_desc, - workspace_global_val_buf); - threadwise_workspace_idx_store.Run(ReducedDataDesc, - make_tuple(I0), - accuIndex_buf, - workspace_desc, - workspace_global_idx_buf); - } - }; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp b/composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp deleted file mode 100644 index 046d3311aa7..00000000000 --- a/composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp +++ /dev/null @@ -1,271 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP -#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP - -#include "data_type.hpp" - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_binop.hpp" - -namespace ck { - -template -struct BlockwiseReduction_2d_block_buffer -{ - using compType = typename opReduce::dataType; - - static constexpr auto buffer2dDesc = buffer2dDescType{}; - - static constexpr index_t BlockSize = - blockIsOneRow ? buffer2dDesc.GetLength(Number<1>{}) : buffer2dDesc.GetLength(Number<0>{}); - static constexpr index_t NumBlocks = - blockIsOneRow ? buffer2dDesc.GetLength(Number<0>{}) : buffer2dDesc.GetLength(Number<1>{}); - using binop = detail::binop_with_nan_check; - - // This interface does not accumulate on indices - template - __device__ static void - Reduce(BufferType& block_buffer, index_t toReduceBlocks, compType& accuData) - { - const index_t thread_local_id = get_thread_local_1d_id(); - compType lAccuData = opReduce::GetReductionZeroVal(); - - index_t offset; - for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++) - { - offset = blockIsOneRow - ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd)); - compType opData = type_convert{}(block_buffer[offset]); - - binop::calculate(lAccuData, opData); - } - - offset = blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0)); - - block_buffer(offset) = lAccuData; - - __syncthreads(); - - for(index_t indOffset = BlockSize / 2; indOffset > 0; indOffset /= 2) - { - if(thread_local_id < indOffset) - { - index_t offset1 = - blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0)); - - index_t offset2 = - blockIsOneRow - ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id + indOffset)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0)); - - compType opData1 = type_convert{}(block_buffer[offset1]); - compType opData2 = type_convert{}(block_buffer[offset2]); - binop::calculate(opData1, opData2); - block_buffer(offset1) = type_convert{}(opData1); - } - - __syncthreads(); - } - - if(thread_local_id == 0) - { - compType tmpVal = type_convert{}(block_buffer[0]); - - binop::calculate(accuData, tmpVal); - } - }; - - // This interface accumulates on both data values and indices - template - __device__ static void Reduce2(BufferType& block_buffer, - IdxBufferType& block_indices_buffer, - index_t toReduceBlocks, - compType& accuData, - int& accuIndex) - { - const index_t thread_local_id = get_thread_local_1d_id(); - compType lAccuData = opReduce::GetReductionZeroVal(); - int lAccuIndex = 0; - - if constexpr(blockIsOneRow) - { - for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++) - { - for(index_t indOffset = 1; indOffset < BlockSize; indOffset *= 2) - { - if(thread_local_id % (indOffset * 2) == 0) - { - index_t offset1 = - buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id)); - index_t offset2 = buffer2dDesc.CalculateOffset( - make_tuple(otherDimInd, thread_local_id + indOffset)); - - compType currVal1 = type_convert{}(block_buffer[offset1]); - compType currVal2 = type_convert{}(block_buffer[offset2]); - int currIndex1 = block_indices_buffer[offset1]; - int currIndex2 = block_indices_buffer[offset2]; - - binop::calculate(currVal1, currVal2, currIndex1, currIndex2); - block_buffer(offset1) = type_convert{}(currVal1); - block_indices_buffer(offset1) = currIndex1; - } - __syncthreads(); - } - } - - if(thread_local_id == 0) - { - for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++) - { - index_t offset = buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, 0)); - - compType tmpVal = type_convert{}(block_buffer[offset]); - int tmpIndex = block_indices_buffer[offset]; - - binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex); - } - - binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex); - } - } - else - { - index_t offset; - - for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++) - { - offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd)); - compType currVal = type_convert{}(block_buffer[offset]); - int currIndex = block_indices_buffer[offset]; - - binop::calculate(lAccuData, currVal, lAccuIndex, currIndex); - } - - offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0)); - - block_buffer(offset) = lAccuData; - block_indices_buffer(offset) = lAccuIndex; - - __syncthreads(); - - for(index_t indOffset = 1; indOffset < BlockSize; indOffset *= 2) - { - if(thread_local_id % (indOffset * 2) == 0) - { - index_t offset1 = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0)); - index_t offset2 = - buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0)); - - compType currVal1 = type_convert{}(block_buffer[offset1]); - compType currVal2 = type_convert{}(block_buffer[offset2]); - int currIndex1 = block_indices_buffer[offset1]; - int currIndex2 = block_indices_buffer[offset2]; - - binop::calculate(currVal1, currVal2, currIndex1, currIndex2); - block_buffer(offset1) = type_convert{}(currVal1); - block_indices_buffer(offset1) = currIndex1; - } - - __syncthreads(); - } - - if(thread_local_id == 0) - { - compType tmpVal = type_convert{}(block_buffer[0]); - int tmpIndex = block_indices_buffer[0]; - - binop::calculate(accuData, tmpVal, accuIndex, tmpIndex); - } - } - }; - - template - __device__ static void set_buffer_value(BufferType& block_buffer, compType value) - { - index_t thread_id = get_thread_local_1d_id(); - - for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++) - { - index_t offset = blockIsOneRow - ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd)); - - block_buffer(offset) = value; - - __syncthreads(); - } - }; - - // Initialize the block-wise indices buffer, the index for each element in the block-wise data - // buffer - // is calculated according to its position in the buffer and the global starting index - template - __device__ static void init_buffer_indices(IdxBufferType& block_indices_buffer, int indexStart) - { - index_t thread_id = get_thread_local_1d_id(); - - for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++) - { - index_t offset = blockIsOneRow - ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd)); - - block_indices_buffer(offset) = offset + indexStart; - - __syncthreads(); - } - }; - - // Execute unary operation on the block buffer elements - template - __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& block_buffer) - { - index_t thread_id = get_thread_local_1d_id(); - - for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++) - { - index_t offset = blockIsOneRow - ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd)); - - block_buffer(offset) = unary_op(block_buffer[offset]); - - __syncthreads(); - } - }; -}; - -}; // end of namespace ck - -#endif diff --git a/composable_kernel/include/tensor_operation/reduction_functions_threadwise.hpp b/composable_kernel/include/tensor_operation/reduction_functions_threadwise.hpp deleted file mode 100644 index 2956606a6ba..00000000000 --- a/composable_kernel/include/tensor_operation/reduction_functions_threadwise.hpp +++ /dev/null @@ -1,141 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_REDUCTION_FUNCTIONS_THREADWISE_HPP -#define CK_REDUCTION_FUNCTIONS_THREADWISE_HPP - -#include "data_type.hpp" - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_binop.hpp" - -namespace ck { - -template -struct ThreadReduce -{ - using compType = typename opReduce::dataType; - - static_assert(BufferType::IsStaticBuffer(), "Thread-wise reduction needs use StaticBuffer!"); - - static_assert( - std::is_same::value, - "Data type of StaticBuffer for Thread-wise reduction should be same as the compType!"); - - static constexpr index_t ThreadBufferLen = BufferType::Size(); - - using binop = detail::binop_with_nan_check; - - // This interface does not accumulate on indices - __device__ static void Reduce(const BufferType& thread_buffer, compType& accuData) - { - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { binop::calculate(accuData, thread_buffer[I]); }); - }; - - // This interface accumulates on both data values and indices and - // is called by Direct_ThreadWise reduction method at first-time reduction - __device__ static void - Reduce2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart) - { - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { - int currIndex = I + indexStart; - binop::calculate(accuData, thread_buffer[I], accuIndex, currIndex); - }); - }; - - // Set the elements in the per-thread buffer to a specific value - // cppcheck-suppress constParameter - __device__ static void set_buffer_value(BufferType& thread_buffer, compType value) - { - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; }); - }; - - // Execute unary operation on the per-thread buffer elements - template - __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer) - { - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); }); - }; -}; - -template -struct ThreadReduceWithIndicesInput -{ - using compType = typename opReduce::dataType; - - static_assert(BufferType::IsStaticBuffer(), "Thread-wise reduction needs use StaticBuffer!"); - static_assert(IdxBufferType::IsStaticBuffer(), - "Thread-wise reduction needs use StaticBuffer for indices!"); - - static_assert( - std::is_same::value, - "Data type of StaticBuffer for Thread-wise reduction should be same as the compType!"); - static_assert(std::is_same::value, - "Indices type of StaticBuffer for Thread-wise reduction should be index_t!"); - - static_assert(BufferType::Size() == IdxBufferType::Size(), - "StaticBuffers for data and indices should have the same sizes!"); - - static constexpr index_t ThreadBufferLen = BufferType::Size(); - - using binop = detail::binop_with_nan_check; - - // This interface accumulates on both data values and indices and - // is called by Direct_ThreadWise reduction method at second-time reduction - __device__ static void Reduce(const BufferType& thread_buffer, - const IdxBufferType& thread_indices_buffer, - compType& accuData, - int& accuIndex) - { - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { - binop::calculate(accuData, thread_buffer[I], accuIndex, thread_indices_buffer[I]); - }); - }; - - // Set the elements in the per-thread buffer to a specific value - // cppcheck-suppress constParameter - __device__ static void set_buffer_value(BufferType& thread_buffer, compType value) - { - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; }); - }; - - // Execute unary operation on the per-thread buffer elements - template - __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer) - { - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); }); - }; -}; - -}; // end of namespace ck - -#endif diff --git a/composable_kernel/include/tensor_operation/reduction_functions_warpwise.hpp b/composable_kernel/include/tensor_operation/reduction_functions_warpwise.hpp deleted file mode 100644 index 9687d2d8c86..00000000000 --- a/composable_kernel/include/tensor_operation/reduction_functions_warpwise.hpp +++ /dev/null @@ -1,371 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_REDUCTION_FUNCTIONS_WARPWISE_HPP -#define CK_REDUCTION_FUNCTIONS_WARPWISE_HPP - -#include "data_type.hpp" - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_binop.hpp" - -namespace ck { - -template -struct WarpReduce -{ - using compType = typename opReduce::dataType; - using binop = detail::binop_with_nan_check; - - static_assert(BufferType::IsStaticBuffer(), - "Per-thread buffer for WarpWise reduction should be StaticBuffer!"); - static_assert(std::is_same::value, - "Data type of per-thread StaticBuffer for WarpWise reduction should be same as " - "the compType!"); - - static constexpr index_t ThreadBufferLen = BufferType::Size(); - static constexpr bool have_builtin_shuffle = - std::is_same::value || std::is_same::value; - - // This interface does not accumulate on indices - __device__ static void Reduce(const BufferType& thread_buffer, compType& accuData) - { - if constexpr(have_builtin_shuffle) - ReduceImpl1(thread_buffer, accuData); - else - ReduceImpl2(thread_buffer, accuData); - }; - - // This interface implementation uses HIP built-in device shuffling functions - __device__ static void ReduceImpl1(const BufferType& thread_buffer, compType& accuData) - { - compType lAccuData = opReduce::GetReductionZeroVal(); - - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); }); - - // synchronize among all threads in this warp - __all(1); - - for(index_t stride = warpSize / 2; stride > 0; stride /= 2) - { - compType tmpVal = __shfl_down(lAccuData, stride, warpSize); - binop::calculate(lAccuData, tmpVal); - __all(1); - } - - binop::calculate(accuData, lAccuData); - }; - - // This interface implementation does not use HIP built-in device shuffling functions - // since for fp16, built-in shuffling functions is not provided by HIP - __device__ static void ReduceImpl2(const BufferType& thread_buffer, compType& accuData) - { - compType lAccuData = opReduce::GetReductionZeroVal(); - - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); }); - - __syncthreads(); - - index_t thread_id = get_thread_local_1d_id(); - index_t warpId = thread_id / warpSize; - index_t thread_inwarp_id = thread_id % warpSize; - - __shared__ compType shuffle_buffer[BlockSize]; - - compType* myBuffer = &shuffle_buffer[warpId * warpSize]; - - myBuffer[thread_inwarp_id] = lAccuData; - - __syncthreads(); - - for(index_t stride = warpSize / 2; stride > 0; stride /= 2) - { - if(thread_inwarp_id < stride) - { - compType currVal1 = myBuffer[thread_inwarp_id]; - compType currVal2 = myBuffer[thread_inwarp_id + stride]; - - binop::calculate(currVal1, currVal2); - - myBuffer[thread_inwarp_id] = currVal1; - } - - __syncthreads(); - } - if(thread_inwarp_id == 0) - binop::calculate(accuData, myBuffer[0]); - }; - - // This interface accumulates on both data values and indices and is called by Direct_WarpWise - // reduction method at first-time reduction - __device__ static void - Reduce2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart) - { - if constexpr(have_builtin_shuffle) - Reduce2Impl1(thread_buffer, accuData, accuIndex, indexStart); - else - Reduce2Impl2(thread_buffer, accuData, accuIndex, indexStart); - }; - - // This interface implementation uses HIP built-in device shuffling functions - __device__ static void Reduce2Impl1(const BufferType& thread_buffer, - compType& accuData, - int& accuIndex, - int indexStart) - { - compType lAccuData = opReduce::GetReductionZeroVal(); - int lAccuIndex = 0; - index_t thread_inwarp_id = get_thread_local_1d_id() % warpSize; - - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { - int currIndex = thread_inwarp_id * ThreadBufferLen + I + indexStart; - binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, currIndex); - }); - - // synchronize among all threads in this warp - __all(1); - - for(index_t stride = 1; stride < warpSize; stride *= 2) - { - compType tmpVal = __shfl_down(lAccuData, stride, warpSize); - int tmpIndex = __shfl_down(lAccuIndex, stride, warpSize); - - binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex); - __all(1); - } - - if(thread_inwarp_id == 0) - binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex); - }; - - // This interface implementation does not use HIP built-in device shuffling functions since for - // fp16, built-in shuffling functions is not provided by HIP - __device__ static void Reduce2Impl2(const BufferType& thread_buffer, - compType& accuData, - int& accuIndex, - int indexStart) - { - compType lAccuData = opReduce::GetReductionZeroVal(); - int lAccuIndex = 0; - index_t thread_id = get_thread_local_1d_id(); - index_t warpId = thread_id / warpSize; - index_t thread_inwarp_id = thread_id % warpSize; - - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { - int currIndex = thread_inwarp_id * ThreadBufferLen + I + indexStart; - binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, currIndex); - }); - - __shared__ compType shuffle_data_buffer[BlockSize]; - __shared__ int shuffle_indices_buffer[BlockSize]; - - compType* myDataBuffer = &shuffle_data_buffer[warpId * warpSize]; - int* myIndicesBuffer = &shuffle_indices_buffer[warpId * warpSize]; - - myDataBuffer[thread_inwarp_id] = lAccuData; - myIndicesBuffer[thread_inwarp_id] = lAccuIndex; - - __syncthreads(); - - for(index_t stride = 1; stride < warpSize; stride *= 2) - { - compType currVal1 = myDataBuffer[thread_inwarp_id]; - compType currVal2 = myDataBuffer[thread_inwarp_id + stride]; - int currIndex1 = myIndicesBuffer[thread_inwarp_id]; - int currIndex2 = myIndicesBuffer[thread_inwarp_id + stride]; - - binop::calculate(currVal1, currVal2, currIndex1, currIndex2); - - myDataBuffer[thread_inwarp_id] = currVal1; - myIndicesBuffer[thread_inwarp_id] = currIndex1; - - __syncthreads(); - } - - if(thread_inwarp_id == 0) - binop::calculate(accuData, myDataBuffer[0], accuIndex, myIndicesBuffer[0]); - }; - - // cppcheck-suppress constParameter - __device__ static void set_buffer_value(BufferType& thread_buffer, compType value) - { - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; }); - - __all(1); - }; - - // Execute unary operation on the per-thread buffer elements - template - __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer) - { - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); }); - - __all(1); - }; -}; - -template -struct WarpReduceWithIndicesInput -{ - using compType = typename opReduce::dataType; - using binop = detail::binop_with_nan_check; - - static_assert(BufferType::IsStaticBuffer(), - "Per-thread buffer for WarpWise reduction should be StaticBuffer!"); - static_assert(IdxBufferType::IsStaticBuffer(), - "Per-thread buffer for WarpWise reduction should be StaticBuffer for indices!"); - - static_assert(std::is_same::value, - "Data type of per-thread StaticBuffer for WarpWise reduction should be same as " - "the compType!"); - static_assert( - std::is_same::value, - "Indices type per-thread of StaticBuffer for WarpWise reduction should be index_t!"); - - static_assert(BufferType::Size() == IdxBufferType::Size(), - "StaticBuffers for data and indices should have the same sizes!"); - - static constexpr index_t ThreadBufferLen = BufferType::Size(); - static constexpr bool have_builtin_shuffle = - std::is_same::value || std::is_same::value; - - // This interface accumulates on both data values and indices and is called by Direct_WarpWise - // reduction method at second-time reduction - __device__ static void Reduce(const BufferType& thread_buffer, - const IdxBufferType& thread_indices_buffer, - compType& accuData, - int& accuIndex) - { - if constexpr(have_builtin_shuffle) - ReduceImpl1(thread_buffer, thread_indices_buffer, accuData, accuIndex); - else - ReduceImpl2(thread_buffer, thread_indices_buffer, accuData, accuIndex); - }; - - // This interface implementation uses HIP built-in device shuffling functions - __device__ static void ReduceImpl1(const BufferType& thread_buffer, - const IdxBufferType& thread_indices_buffer, - compType& accuData, - int& accuIndex) - { - compType lAccuData = opReduce::GetReductionZeroVal(); - int lAccuIndex = 0; - - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { - binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, thread_indices_buffer[I]); - }); - - // synchronize among all threads in this warp - __all(1); - - for(index_t stride = 1; stride < warpSize; stride *= 2) - { - compType tmpVal = __shfl_down(lAccuData, stride, warpSize); - int tmpIndex = __shfl_down(lAccuIndex, stride, warpSize); - - binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex); - __all(1); - } - - binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex); - }; - - // This interface implementation does not use HIP built-in device shuffling functions - // since for fp16, built-in shuffling functions is not provided by HIP - __device__ static void ReduceImpl2(const BufferType& thread_buffer, - const IdxBufferType& thread_indices_buffer, - compType& accuData, - int& accuIndex) - { - compType lAccuData = opReduce::GetReductionZeroVal(); - int lAccuIndex = 0; - index_t thread_id = get_thread_local_1d_id(); - index_t warpId = thread_id / warpSize; - index_t thread_inwarp_id = thread_id % warpSize; - - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { - binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, thread_indices_buffer[I]); - }); - - __shared__ compType shuffle_data_buffer[BlockSize]; - __shared__ int shuffle_indices_buffer[BlockSize]; - - compType* myDataBuffer = &shuffle_data_buffer[warpId * warpSize]; - int* myIndicesBuffer = &shuffle_indices_buffer[warpId * warpSize]; - - myDataBuffer[thread_inwarp_id] = lAccuData; - myIndicesBuffer[thread_inwarp_id] = lAccuIndex; - - __syncthreads(); - - for(index_t stride = 1; stride < warpSize; stride *= 2) - { - compType currVal1 = myDataBuffer[thread_inwarp_id]; - compType currVal2 = myDataBuffer[thread_inwarp_id + stride]; - int currIndex1 = myIndicesBuffer[thread_inwarp_id]; - int currIndex2 = myIndicesBuffer[thread_inwarp_id + stride]; - - binop::calculate(currVal1, currVal2, currIndex1, currIndex2); - - myDataBuffer[thread_inwarp_id] = currVal1; - myIndicesBuffer[thread_inwarp_id] = currIndex1; - - __syncthreads(); - } - - if(thread_inwarp_id == 0) - binop::calculate(accuData, myDataBuffer[0], accuIndex, myIndicesBuffer[0]); - }; - - // cppcheck-suppress constParameter - __device__ static void set_buffer_value(BufferType& thread_buffer, compType value) - { - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; }); - - __all(1); - }; - - // Execute unary operation on the per-thread buffer elements - template - __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer) - { - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); }); - - __all(1); - }; -}; - -}; // end of namespace ck - -#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp deleted file mode 100644 index f6c15fd85ac..00000000000 --- a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp +++ /dev/null @@ -1,157 +0,0 @@ -#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP -#define CK_THREADWISE_GEMM_DLOPS_V3_HPP - -#include "common_header.hpp" -#include "math.hpp" - -namespace ck { - -// C[M, N] += transpose(A[K, M]) * B[K, N] -// Element of matrix can be vectorized data -// Assume: -// 1. ADesc, BDesc, CDesc are known at compile-time -// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time -template ::type = false> -struct ThreadwiseGemmDlops_km_kn_mn_v3 -{ - template - __device__ static void Run(const ABuffer& a_buf, - AOriginIdx, - const BBuffer& b_buf, - BOriginIdx, - CBuffer& c_buf, - COriginIdx) - { - static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && - CDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - "wrong! inconsistent type"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - constexpr auto E = ADesc{}.GetLength(I0); - constexpr auto K = ADesc{}.GetLength(I1); - - constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); - constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); - constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); - - static_for<0, E, 1>{}([&](auto e) { - static_for<0, K, 1>{}([&](auto k) { - constexpr index_t a_offset = - ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k)); - - if constexpr(H == 2 && W == 2) - { - constexpr index_t b_offset_0 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0)); - constexpr index_t b_offset_1 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1)); - constexpr index_t b_offset_2 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0)); - constexpr index_t b_offset_3 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1)); - - constexpr index_t c_offset_0 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0)); - constexpr index_t c_offset_1 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1)); - constexpr index_t c_offset_2 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0)); - constexpr index_t c_offset_3 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1)); - - amd_assembly_outer_product_1x4(a_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{})); - } - else if constexpr(H == 4 && W == 1) - { - constexpr index_t b_offset_0 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0)); - constexpr index_t b_offset_1 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0)); - constexpr index_t b_offset_2 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0)); - constexpr index_t b_offset_3 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0)); - - constexpr index_t c_offset_0 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0)); - constexpr index_t c_offset_1 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0)); - constexpr index_t c_offset_2 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0)); - constexpr index_t c_offset_3 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0)); - - amd_assembly_outer_product_1x4(a_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{})); - } - else - { - static_for<0, H, 1>{}([&](auto h) { - static_for<0, W, 1>{}([&](auto w) { - constexpr index_t b_offset = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, h, w)); - - constexpr index_t c_offset = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w)); - -#if 0 - c_buf(Number{}) += inner_product_with_conversion{}( - a_buf[Number{}], b_buf[Number{}]); -#else - amd_assembly_inner_product(a_buf[Number{}], - b_buf[Number{}], - c_buf(Number{})); -#endif - }); - }); - } - }); - }); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/amd_xdlops.hpp b/composable_kernel/include/utility/amd_xdlops.hpp deleted file mode 100644 index 083e47fbf1e..00000000000 --- a/composable_kernel/include/utility/amd_xdlops.hpp +++ /dev/null @@ -1,390 +0,0 @@ -#ifndef CK_AMD_XDLOPS_HPP -#define CK_AMD_XDLOPS_HPP - -#include "data_type.hpp" - -namespace ck { - -// A, B, C, cbsz, abid, blgp -extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32( - float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32( - float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32( - float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32"); - -extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16( - half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16( - half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16( - half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16"); - -extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16( - ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16( - ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16( - ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16( - ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16( - ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16"); - -template -struct intrin_mfma_f32_32x32x1f32; - -template <> -struct intrin_mfma_f32_32x32x1f32<64, 64> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); - reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); - } -}; - -template <> -struct intrin_mfma_f32_32x32x1f32<32, 64> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); - } -}; - -template -struct intrin_mfma_f32_32x32x2f32; - -template <> -struct intrin_mfma_f32_32x32x2f32<32, 32> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); - } -}; - -template -struct intrin_mfma_f32_16x16x4f32; - -template <> -struct intrin_mfma_f32_16x16x4f32<16, 16> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); - } -}; - -template -struct intrin_mfma_f32_16x16x1f32; - -template <> -struct intrin_mfma_f32_16x16x1f32<16, 64> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); - } -}; - -template -struct intrin_mfma_f32_4x4x1f32; - -template <> -struct intrin_mfma_f32_4x4x1f32<4, 64> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); - } -}; - -template <> -struct intrin_mfma_f32_4x4x1f32<8, 64> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); - reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); - } -}; - -template -struct intrin_mfma_f32_32x32x4f16; - -template <> -struct intrin_mfma_f32_32x32x4f16<64, 64> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); - reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); - } -}; - -template <> -struct intrin_mfma_f32_32x32x4f16<32, 64> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); - } -}; - -template -struct intrin_mfma_f32_32x32x8f16; - -template <> -struct intrin_mfma_f32_32x32x8f16<32, 32> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); - } -}; - -template -struct intrin_mfma_f32_16x16x16f16; - -template <> -struct intrin_mfma_f32_16x16x16f16<16, 16> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); - } -}; - -template -struct intrin_mfma_f32_16x16x4f16; - -template <> -struct intrin_mfma_f32_16x16x4f16<16, 64> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); - } -}; - -template -struct intrin_mfma_f32_4x4x4f16; - -template <> -struct intrin_mfma_f32_4x4x4f16<4, 64> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); - } -}; - -template <> -struct intrin_mfma_f32_4x4x4f16<8, 64> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); - reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); - } -}; - -#if 0 -template -struct intrin_mfma_f32_32x32x2bf16; - -template -struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride> -{ - __device__ static c_vec32_4_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); - reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); - - reg_c.s.z = - llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0); - reg_c.s.w = - llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0); - - return reg_c; - } -}; - -template -struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride> -{ - __device__ static c_vec32_4_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); - reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); - - reg_c.s.z = - llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0); - reg_c.s.w = - llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0); - - return reg_c; - } -}; - -template -struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride> -{ - __device__ static c_vec32_2_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); - reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); - - return reg_c; - } -}; - -template -struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride> -{ - __device__ static c_vec32_1_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1); - - return reg_c; - } -}; - -template -struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride> -{ - __device__ static c_vec32_1_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); - return reg_c; - } -}; - -__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec16_1_t::VecType reg_c) -{ - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); - return reg_c; -} - -__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec4_1_t::VecType reg_c) -{ - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); - return reg_c; -} - -template -__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec16_1_t::VecType reg_c); -template <> -__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec16_1_t::VecType reg_c) -{ - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0); - return reg_c; -} - -template <> -__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec16_1_t::VecType reg_c) -{ - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4); - return reg_c; -} - -template -struct intrin_mfma_f32_4x4x2bf16; - -template <> -struct intrin_mfma_f32_4x4x2bf16<4, 64> -{ - __device__ static c_vec4_1_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); - return reg_c; - } -}; - -template <> -struct intrin_mfma_f32_4x4x2bf16<8, 64> -{ - __device__ static c_vec4_2_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); - reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0); - return reg_c; - } -}; - -#endif - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp deleted file mode 100644 index 5ee4bb9c642..00000000000 --- a/composable_kernel/include/utility/config.hpp +++ /dev/null @@ -1,134 +0,0 @@ -#ifndef CK_CONFIG_AMD_HPP -#define CK_CONFIG_AMD_HPP - -#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS -#include "hip/hip_runtime.h" -#include "hip/hip_fp16.h" -#endif -#include "bfloat16_dev.hpp" - -// "Constant" address space for kernel parameter -#define CONSTANT __attribute__((address_space(4))) - -// GPU target -// should enable one and only one GPU target -#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \ - defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030)) -#error Need to define (only) one GPU target -#endif - -// launch bounds -#define CK_USE_LAUNCH_BOUNDS 1 - -#ifdef CK_USE_LAUNCH_BOUNDS -#define CK_MAX_THREAD_PER_BLOCK 256 -#define CK_MIN_BLOCK_PER_CU 2 -#endif - -// buffer resourse -#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \ - defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) -#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 -#elif defined(CK_AMD_GPU_GFX1030) -#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#endif - -// FMA instruction -#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) -#define CK_USE_AMD_V_MAC_F32 -#elif defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90a) || \ - defined(CK_AMD_GPU_GFX1030) -#define CK_USE_AMD_V_FMAC_F32 -#define CK_USE_AMD_V_DOT2_F32_F16 -#define CK_USE_AMD_V_DOT4_I32_I8 -#endif - -// multi index -#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0 - -// AMD inline asm -#ifndef CK_USE_AMD_INLINE_ASM -#define CK_USE_AMD_INLINE_ASM 1 -#endif - -// AMD inner product (DLOP) -#ifndef CK_USE_AMD_INNER_PRODUCT_INLINE_ASM -#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1 -#endif - -// AMD buffer addressing -#ifndef CK_USE_AMD_BUFFER_ADDRESSING -#define CK_USE_AMD_BUFFER_ADDRESSING 1 -#endif - -// only gfx908 support native floating point atomic add -#ifndef CK_USE_AMD_BUFFER_ATOMIC_FADD -#define CK_USE_AMD_BUFFER_ATOMIC_FADD 0 -#endif - -// AMD XDLOPS -#ifndef CK_USE_AMD_XDLOPS -#define CK_USE_AMD_XDLOPS 0 -#endif - -// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) -#ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM -#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 -#endif - -// experimental implementation -#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 -#endif - -#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK -#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 -#endif - -#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK -#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 -#endif - -// pass tensor descriptor by value or void* -#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1 -#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0 - -// merge transformation use magic number division -#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0 - -// hack: have underlying assumption that need to be satsified, otherwise it's a bug -// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be -// thread-invariant, otherwise it's a bug -// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" -#ifndef CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE -#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 -#endif - -// workaround for compiler crash when compiling recursive lambda -#ifndef CK_WORKAROUND_SWDEV_275126 -#define CK_WORKAROUND_SWDEV_275126 1 -#endif - -// workaround for compiler crash when using buffer load/store for i8 -#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE -#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1 -#endif - -// workaround for compiler crash when using buffer load/store for i8 -#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE -#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 -#endif - -namespace ck { - -enum InMemoryDataOperationEnum_t -{ - Set, - AtomicAdd -}; - -// index type -using index_t = int32_t; - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp deleted file mode 100644 index 886737efacd..00000000000 --- a/composable_kernel/include/utility/dynamic_buffer.hpp +++ /dev/null @@ -1,272 +0,0 @@ -#ifndef CK_BUFFER_HPP -#define CK_BUFFER_HPP - -#include "amd_buffer_addressing.hpp" -#include "c_style_pointer_cast.hpp" -#include "enable_if.hpp" - -namespace ck { - -template -struct DynamicBuffer -{ - using type = T; - - T* p_data_; - ElementSpaceSize element_space_size_; - T invalid_element_value_ = T{0}; - - __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) - : p_data_{p_data}, element_space_size_{element_space_size} - { - } - - __host__ __device__ constexpr DynamicBuffer(T* p_data, - ElementSpaceSize element_space_size, - T invalid_element_value) - : p_data_{p_data}, - element_space_size_{element_space_size}, - invalid_element_value_{invalid_element_value} - { - } - - __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() - { - return BufferAddressSpace; - } - - __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } - - __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } - - template >::type, - typename scalar_type>::type>::value, - bool>::type = false> - __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const - { - // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; - - static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, - "wrong! X need to be multiple T"); - -#if CK_USE_AMD_BUFFER_ADDRESSING - bool constexpr use_amd_buffer_addressing = true; -#else - bool constexpr use_amd_buffer_addressing = false; -#endif - - if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing) - { - constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - - if constexpr(InvalidElementUseNumericalZeroValue) - { - return amd_buffer_load_invalid_element_return_return_zero, - t_per_x>( - p_data_, i, is_valid_element, element_space_size_); - } - else - { - return amd_buffer_load_invalid_element_return_customized_value, - t_per_x>( - p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); - } - } - else - { - if constexpr(InvalidElementUseNumericalZeroValue) - { - return is_valid_element ? *c_style_pointer_cast(&p_data_[i]) : X{0}; - } - else - { - return is_valid_element ? *c_style_pointer_cast(&p_data_[i]) - : X{invalid_element_value_}; - } - } - } - - template >::type, - typename scalar_type>::type>::value, - bool>::type = false> - __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) - { - // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; - - static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, - "wrong! X need to be multiple T"); - - if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global) - { -#if CK_USE_AMD_BUFFER_ADDRESSING - constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - - amd_buffer_store, t_per_x>( - x, p_data_, i, is_valid_element, element_space_size_); -#else - if(is_valid_element) - { - *c_style_pointer_cast(&p_data_[i]) = x; - } -#endif - } - else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds) - { - if(is_valid_element) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE - *c_style_pointer_cast(&p_data_[i]) = x; -#else - // HACK: compiler would lower IR "store address_space(3)" into - // inefficient - // ISA, so I try to let compiler emit IR "store" which would be lower to - // ds_write_b128 - // TODO: remove this after compiler fix - if constexpr(is_same>::type, int8_t>::value) - { - static_assert((is_same, int8_t>::value && - is_same, int8_t>::value) || - (is_same, int8_t>::value && - is_same, int8x2_t>::value) || - (is_same, int8_t>::value && - is_same, int8x4_t>::value) || - (is_same, int8x4_t>::value && - is_same, int8x4_t>::value) || - (is_same, int8x8_t>::value && - is_same, int8x8_t>::value) || - (is_same, int8x16_t>::value && - is_same, int8x16_t>::value), - "wrong! not implemented for this combination, please add " - "implementation"); - - if constexpr(is_same, int8_t>::value && - is_same, int8_t>::value) - { - // HACK: cast pointer of x is bad - // TODO: remove this after compiler fix - *c_style_pointer_cast(&p_data_[i]) = - *c_style_pointer_cast(&x); - } - else if constexpr(is_same, int8_t>::value && - is_same, int8x2_t>::value) - { - // HACK: cast pointer of x is bad - // TODO: remove this after compiler fix - *c_style_pointer_cast(&p_data_[i]) = - *c_style_pointer_cast(&x); - } - else if constexpr(is_same, int8_t>::value && - is_same, int8x4_t>::value) - { - // HACK: cast pointer of x is bad - // TODO: remove this after compiler fix - *c_style_pointer_cast(&p_data_[i]) = - *c_style_pointer_cast(&x); - } - else if constexpr(is_same, int8x4_t>::value && - is_same, int8x4_t>::value) - { - // HACK: cast pointer of x is bad - // TODO: remove this after compiler fix - *c_style_pointer_cast(&p_data_[i]) = - *c_style_pointer_cast(&x); - } - else if constexpr(is_same, int8x8_t>::value && - is_same, int8x8_t>::value) - { - // HACK: cast pointer of x is bad - // TODO: remove this after compiler fix - *c_style_pointer_cast(&p_data_[i]) = - *c_style_pointer_cast(&x); - } - else if constexpr(is_same, int8x16_t>::value && - is_same, int8x16_t>::value) - { - // HACK: cast pointer of x is bad - // TODO: remove this after compiler fix - *c_style_pointer_cast(&p_data_[i]) = - *c_style_pointer_cast(&x); - } - } - else - { - *c_style_pointer_cast(&p_data_[i]) = x; - } -#endif - } - } - else - { - if(is_valid_element) - { - *c_style_pointer_cast(&p_data_[i]) = x; - } - } - } - - template >::type, - typename scalar_type>::type>::value, - bool>::type = false> - __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x) - { - // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; - - static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, - "wrong! X need to be multiple T"); - - static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem"); - -#if CK_USE_AMD_BUFFER_ADDRESSING - constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - - amd_buffer_atomic_add, t_per_x>( - x, p_data_, i, is_valid_element, element_space_size_); -#else - if(is_valid_element) - { - atomicAdd(&p_data_[i], x); - } -#endif - } - - __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } - - __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } -}; - -template -__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) -{ - return DynamicBuffer{p, element_space_size}; -} - -template < - AddressSpaceEnum_t BufferAddressSpace, - typename T, - typename ElementSpaceSize, - typename X, - typename enable_if, remove_cvref_t>::value, bool>::type = false> -__host__ __device__ constexpr auto -make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) -{ - return DynamicBuffer{ - p, element_space_size, invalid_element_value}; -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/integral_constant.hpp b/composable_kernel/include/utility/integral_constant.hpp deleted file mode 100644 index 14f3df894be..00000000000 --- a/composable_kernel/include/utility/integral_constant.hpp +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef CK_INTEGRAL_CONSTANT_HPP -#define CK_INTEGRAL_CONSTANT_HPP - -namespace ck { - -template -struct integral_constant -{ - static constexpr T value = v; - typedef T value_type; - typedef integral_constant type; - __host__ __device__ constexpr operator value_type() const noexcept { return value; } - __host__ __device__ constexpr value_type operator()() const noexcept { return value; } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/number.hpp b/composable_kernel/include/utility/number.hpp deleted file mode 100644 index f8c56436940..00000000000 --- a/composable_kernel/include/utility/number.hpp +++ /dev/null @@ -1,44 +0,0 @@ -#ifndef CK_NUMBER_HPP -#define CK_NUMBER_HPP - -#include "integral_constant.hpp" - -namespace ck { - -template -using Number = integral_constant; - -template -__host__ __device__ constexpr auto operator+(Number, Number) -{ - return Number{}; -} - -template -__host__ __device__ constexpr auto operator-(Number, Number) -{ - static_assert(Y <= X, "wrong!"); - return Number{}; -} - -template -__host__ __device__ constexpr auto operator*(Number, Number) -{ - return Number{}; -} - -template -__host__ __device__ constexpr auto operator/(Number, Number) -{ - static_assert(Y > 0, "wrong!"); - return Number{}; -} - -template -__host__ __device__ constexpr auto operator%(Number, Number) -{ - static_assert(Y > 0, "wrong!"); - return Number{}; -} -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/reduction_operator.hpp b/composable_kernel/include/utility/reduction_operator.hpp deleted file mode 100644 index c0afbec8695..00000000000 --- a/composable_kernel/include/utility/reduction_operator.hpp +++ /dev/null @@ -1,419 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_REDUCTION_OPERATOR_HPP -#define CK_REDUCTION_OPERATOR_HPP - -#include "reduction_common.hpp" - -namespace ck { - -namespace reduce { - -// Every binary operator used in reduction is represented by a templated functor class. Each functor -// class must provide at least -// three members: -// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary -// operator, "identity element" is the unique -// element in the algebraic space that doesn't affect the value of other elements -// when operated against them, and the concept is similar to zero vector in -// vector space -// (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf). -// 2) indexable -- boolean value indicating whether indices of the operated elements could be -// recorded. Usually, Min/Max operator could -// need to record the indices of elements. For operator like Add/Mul, no need to -// record the indices. -// 3) operator() -- the first argument of the operator must be both an input & output, and the -// corresponding variable usually stores -// the accumulated result of many operator() calls; the second argument is only an -// input. For indexable binary -// operator, the second version of operator() has third argument (which is an -// output) to indicate whether the -// accumulated value (the first argument) has changed, in which case the recorded -// accumulated index also need be -// changed. - -template -struct Add -{ - using dataType = T; - - __device__ static constexpr T GetReductionZeroVal() { return static_cast(0.0f); }; - - __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; } - - static constexpr bool indexable = false; -}; - -template -struct Mul -{ - using dataType = T; - - __device__ static constexpr T GetReductionZeroVal() { return static_cast(1.0f); }; - - __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; } - - static constexpr bool indexable = false; -}; - -template -struct Max -{ - using dataType = T; - - __device__ static constexpr T GetReductionZeroVal() { return NumericLimits::Lowest(); }; - - __device__ inline constexpr void operator()(T& a, T b) const - { - if(a < b) - a = b; - } - - __device__ inline constexpr void operator()(T& a, T b, bool& changed) const - { - if(a < b) - { - a = b; - changed = true; - } - } - - static constexpr bool indexable = true; -}; - -template -struct Min -{ - using dataType = T; - - __device__ static constexpr T GetReductionZeroVal() { return NumericLimits::Max(); }; - - __device__ inline constexpr void operator()(T& a, T b) const - { - if(a > b) - a = b; - } - - __device__ inline constexpr void operator()(T& a, T b, bool& changed) const - { - if(a > b) - { - a = b; - changed = true; - } - } - - static constexpr bool indexable = true; -}; - -template -struct AMax -{ - using dataType = T; - - __device__ static constexpr T GetReductionZeroVal() { return static_cast(0.0f); }; - - __device__ inline constexpr void operator()(T& a, T b) const - { - if(a < b) - a = b; - } - - __device__ inline constexpr void operator()(T& a, T b, bool& changed) const - { - if(a < b) - { - a = b; - changed = true; - } - } - - static constexpr bool indexable = true; -}; - -// Unary operators are usually called element-wisely before the reduction is executed on the -// elements. -// They are needed for easy implementation of reduction types of AVG, NRM1, NRM2 -template -struct unary_identic -{ - __device__ unary_identic(const int divider = 1) - { - scaler = 1.0f / static_cast(divider); - }; - - __device__ inline constexpr T operator()(T a) const { return a * type_convert{}(scaler); }; - - float scaler = 1.0f; -}; - -template -struct unary_identic -{ - __device__ unary_identic(const int divider = 1) { (void)divider; }; - - __device__ inline constexpr T operator()(T a) const { return a; }; -}; - -template -struct unary_square -{ - __device__ unary_square(const int divider = 1) { scaler = 1.0f / static_cast(divider); }; - - __device__ inline constexpr T operator()(T a) const - { - a = a * a; - - return a * type_convert{}(scaler); - }; - - float scaler = 1.0f; -}; - -template -struct unary_square -{ - __device__ unary_square(const int divider = 1) { (void)divider; }; - - __device__ inline constexpr T operator()(T a) const { return a * a; }; -}; - -template -struct unary_abs -{ - __device__ unary_abs(const int divider = 1) { scaler = 1.0f / static_cast(divider); }; - - __device__ inline constexpr T operator()(T a) const - { - a = abs(a); - - return a * type_convert{}(scaler); - }; - - float scaler = 1.0f; -}; - -template -struct unary_abs -{ - __device__ unary_abs(const int divider = 1) { (void)divider; }; - - __device__ inline constexpr T operator()(T a) const { return abs(a); }; -}; - -// We know for sure that 4.0 has __habs(), but 3.0 does not have it. -// Let's assume that __habs() exists since 3.5. -#if HIP_PACKAGE_VERSION_FLAT < 3005000000 -inline __device__ __half __habs(__half x) -{ - union - { - __half half; - unsigned short u16; - } val; - val.half = x; - val.u16 = val.u16 & 0x7fff; - return val.half; -} -#endif - -template -struct unary_abs -{ - __device__ unary_abs(const int divider = 1) { scaler = 1.0f / static_cast(divider); }; - - __device__ inline half_t operator()(half_t a) const - { - a = static_cast(__habs(a)); - - return a * type_convert{}(scaler); - }; - - float scaler = 1.0f; -}; - -template <> -struct unary_abs -{ - __device__ unary_abs(const int divider = 1) { (void)divider; }; - - __device__ inline half_t operator()(half_t a) const { return static_cast(__habs(a)); }; -}; - -template -struct unary_sqrt -{ - __device__ unary_sqrt(const int divider = 1) { (void)divider; }; - - __device__ inline T operator()(T a) const { return sqrtf(a); }; -}; - -template <> -struct unary_sqrt -{ - __device__ unary_sqrt(const int divider = 1) { (void)divider; }; - - __device__ inline half_t operator()(half_t a) const { return static_cast(hsqrt(a)); }; -}; - -}; // end of namespace reduce - -// The templated struct reduce_binary_operator maps the enum Ids of binary operators to their -// respective functor classes. -// The "GetReductionZeroVal()" interface and boolean member "indexable" are also provided in -// reduce_binary_operactor for -// easier checking by the upper-layer codes in the kernels. - -template -struct reduce_binary_operator; - -template -struct reduce_binary_operator -{ - using opType = reduce::Add; - using dataType = T; - - static constexpr bool indexable = reduce::Add::indexable; -}; - -template -struct reduce_binary_operator -{ - using opType = reduce::Mul; - using dataType = T; - - static constexpr bool indexable = reduce::Mul::indexable; -}; - -template -struct reduce_binary_operator -{ - using opType = reduce::Min; - using dataType = T; - - static constexpr bool indexable = reduce::Min::indexable; -}; - -template -struct reduce_binary_operator -{ - using opType = reduce::Max; - using dataType = T; - - static constexpr bool indexable = reduce::Max::indexable; -}; - -template -struct reduce_binary_operator -{ - using opType = reduce::AMax; - using dataType = T; - - static constexpr bool indexable = reduce::Max::indexable; -}; - -template -struct reduce_binary_operator -{ - using opType = reduce::Add; - using dataType = T; - - static constexpr bool indexable = reduce::Add::indexable; -}; - -template -struct reduce_binary_operator -{ - using opType = reduce::Add; - using dataType = T; - - static constexpr bool indexable = reduce::Add::indexable; -}; - -template -struct reduce_binary_operator -{ - using opType = reduce::Add; - using dataType = T; - - static constexpr bool indexable = reduce::Add::indexable; -}; - -// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary -// functor classes. -// The two unary functors are called before and afer the Reduction is executed respectively -template -struct reduce_unary_operator -{ - using preUnaryOp = reduce::unary_identic; - using posUnaryOp = reduce::unary_identic; -}; - -template -struct reduce_unary_operator -{ - using preUnaryOp = reduce::unary_identic; - using posUnaryOp = reduce::unary_identic; -}; - -template -struct reduce_unary_operator -{ - using preUnaryOp = reduce::unary_abs; - using posUnaryOp = reduce::unary_identic; -}; - -template -struct reduce_unary_operator -{ - using preUnaryOp = reduce::unary_abs; - using posUnaryOp = reduce::unary_identic; -}; - -template -struct reduce_unary_operator -{ - using preUnaryOp = reduce::unary_square; - using posUnaryOp = reduce::unary_identic; -}; - -template -struct reduce_unary_operator -{ - using preUnaryOp = reduce::unary_square; - using posUnaryOp = reduce::unary_sqrt; -}; - -template -struct reduce_unary_operator -{ - using preUnaryOp = reduce::unary_identic; - using posUnaryOp = reduce::unary_sqrt; -}; - -} // end of namespace ck - -#endif diff --git a/composable_kernel/include/utility/static_buffer.hpp b/composable_kernel/include/utility/static_buffer.hpp deleted file mode 100644 index 9615d10c597..00000000000 --- a/composable_kernel/include/utility/static_buffer.hpp +++ /dev/null @@ -1,163 +0,0 @@ -#ifndef CK_STATIC_BUFFER_HPP -#define CK_STATIC_BUFFER_HPP - -#include "statically_indexed_array.hpp" - -namespace ck { - -template -struct StaticBuffer : public StaticallyIndexedArray -{ - using type = T; - using base = StaticallyIndexedArray; - - T invalid_element_value_ = T{0}; - - __host__ __device__ constexpr StaticBuffer() : base{} {} - - __host__ __device__ constexpr StaticBuffer(T invalid_element_value) - : base{}, invalid_element_value_{invalid_element_value} - { - } - - __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() - { - return BufferAddressSpace; - } - - template - __host__ __device__ constexpr auto Get(Number i, bool is_valid_element) const - { - if constexpr(InvalidElementUseNumericalZeroValue) - { - return is_valid_element ? At(i) : T{0}; - } - else - { - return is_valid_element ? At(i) : invalid_element_value_; - } - } - - template - __host__ __device__ void Set(Number i, bool is_valid_element, const T& x) - { - if(is_valid_element) - { - At(i) = x; - } - } - - __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } - - __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } -}; - -template -struct StaticBufferV2 : public StaticallyIndexedArray -{ - using type = T; - using base = StaticallyIndexedArray; - - using VecBaseType = typename T::d1_t; - - __host__ __device__ static constexpr index_t GetVectorSize() - { - return sizeof(typename T::type) / sizeof(VecBaseType); - } - - static constexpr index_t vector_size = GetVectorSize(); - - VecBaseType invalid_element_value_ = VecBaseType{0}; - - T invalid_vec_value_ = T{0}; - - __host__ __device__ constexpr StaticBufferV2() : base{} {} - - __host__ __device__ constexpr StaticBufferV2(VecBaseType invalid_element_value) - : base{}, - invalid_vec_value_{invalid_element_value}, - invalid_element_value_{invalid_element_value} - { - } - - __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() - { - return BufferAddressSpace; - } - - template - __host__ __device__ constexpr auto& GetVector(Number vec_id) - { - return this->At(vec_id); - } - - template - __host__ __device__ constexpr const auto& GetVector(Number vec_id) const - { - return this->At(vec_id); - } - - template - __host__ __device__ constexpr auto& GetElement(Number i, bool) - { - constexpr auto vec_id = Number{}; - constexpr auto vec_off = Number{}; - - return this->At(vec_id).template AsType()(vec_off); - } - - template - __host__ __device__ constexpr auto GetElement(Number i, bool is_valid_element) const - { - constexpr auto vec_id = Number{}; - constexpr auto vec_off = Number{}; - - if constexpr(InvalidElementUseNumericalZeroValue) - { - return is_valid_element ? this->At(vec_id).template AsType()[vec_off] - : VecBaseType{0}; - } - else - { - return is_valid_element ? this->At(vec_id).template AsType()[vec_off] - : invalid_element_value_; - } - } - - template - __host__ __device__ constexpr auto operator[](Number i) const - { - return GetElement(i, true); - } - - template - __host__ __device__ constexpr auto& operator()(Number i) - { - return GetElement(i, true); - } - - __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } - - __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } -}; - -template -__host__ __device__ constexpr auto make_static_buffer(Number) -{ - return StaticBuffer{}; -} - -template -__host__ __device__ constexpr auto make_static_buffer(Number, T invalid_element_value) -{ - return StaticBuffer{invalid_element_value}; -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/statically_indexed_array.hpp b/composable_kernel/include/utility/statically_indexed_array.hpp deleted file mode 100644 index f30a3a9ee63..00000000000 --- a/composable_kernel/include/utility/statically_indexed_array.hpp +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP -#define CK_STATICALLY_INDEXED_ARRAY_HPP - -#include "functional2.hpp" -#include "sequence.hpp" -#include "tuple.hpp" - -namespace ck { - -namespace detail { - -template -__host__ __device__ constexpr auto generate_same_type_tuple() -{ - return generate_tuple([](auto) -> T { return T{}; }, Number{}); -} - -template -using same_type_tuple = decltype(generate_same_type_tuple()); - -} // namespace detail - -template -using StaticallyIndexedArray = detail::same_type_tuple; - -template -__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs) -{ - return StaticallyIndexedArray(x, static_cast(xs)...); -} - -// make empty StaticallyIndexedArray -template -__host__ __device__ constexpr auto make_statically_indexed_array() -{ - return StaticallyIndexedArray(); -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/utility.hpp b/composable_kernel/include/utility/utility.hpp deleted file mode 100644 index 9f34e044b71..00000000000 --- a/composable_kernel/include/utility/utility.hpp +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef CK_UTILITY_HPP -#define CK_UTILITY_HPP - -#include "config.hpp" - -namespace ck { - -__device__ index_t get_thread_local_1d_id() { return threadIdx.x; } - -__device__ index_t get_block_1d_id() { return blockIdx.x; } - -} // namespace ck - -#endif diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp deleted file mode 100644 index 09a7fffa3ed..00000000000 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp +++ /dev/null @@ -1,370 +0,0 @@ -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_dlops_v1r2.hpp" -#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" - -using namespace ck; - -constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); -constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); -constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); - -using FloatAB = typename get_datatype_from_enum::type; -using FloatAcc = typename get_datatype_from_enum::type; -using FloatC = typename get_datatype_from_enum::type; - -constexpr index_t BlockSize = CK_PARAM_BlockSize; - -constexpr index_t MPerBlock = CK_PARAM_MPerBlock; -constexpr index_t NPerBlock = CK_PARAM_NPerBlock; -constexpr index_t KPerBlock = CK_PARAM_KPerBlock; -constexpr index_t M1PerThread = CK_PARAM_M1PerThread; -constexpr index_t N1PerThread = CK_PARAM_N1PerThread; -constexpr index_t KPerThread = CK_PARAM_KPerThread; -constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10; -constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10; -constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11; -constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11; - -using ABlockTransferThreadSliceLengths_K_M0_M1 = - Sequence; -using ABlockTransferThreadClusterLengths_K_M0_M1 = - Sequence; -using ABlockTransferThreadClusterArrangeOrder = - Sequence; -using ABlockTransferSrcAccessOrder = Sequence; - -constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; -constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; -constexpr index_t ABlockTransferDstScalarPerVector_M1 = - CK_PARAM_ABlockTransferDstScalarPerVector_M1; -constexpr bool AThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); - -using BBlockTransferThreadSliceLengths_K_N0_N1 = - Sequence; -using BBlockTransferThreadClusterLengths_K_N0_N1 = - Sequence; -using BBlockTransferThreadClusterArrangeOrder = - Sequence; -using BBlockTransferSrcAccessOrder = Sequence; - -constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; -constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; -constexpr index_t BBlockTransferDstScalarPerVector_N1 = - CK_PARAM_BBlockTransferDstScalarPerVector_N1; -constexpr bool BThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); - -using CThreadTransferSrcDstAccessOrder = Sequence; -constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; -constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; - -constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); -constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); - -extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare( - int n, - int c, - int hi, - int wi, - int k, - int y, - int x, - int convStrideH, - int convStrideW, - int convDilationY, - int convDilationX, - int leftPadH, - int leftPadW, - int rightPadH, - int rightPadW, - void* p_a_k_m0_m1_grid_desc, - void* p_b_k_n0_n1_grid_desc, - void* p_c_m0_m10_m11_n0_n10_n11_grid_desc, - void* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; - const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi)); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x)); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo)); - - const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(convStrideH, convStrideW), - make_tuple(convDilationY, convDilationX), - make_tuple(leftPadH, leftPadW), - make_tuple(rightPadH, rightPadW)); - - const auto a_k_m_grid_desc = descs[I0]; - const auto b_k_n_grid_desc = descs[I1]; - const auto c_m_n_grid_desc = descs[I2]; - - using AKMGridDesc = decltype(a_k_m_grid_desc); - using BKNGridDesc = decltype(b_k_n_grid_desc); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}))); - - using BGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); - - using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - - using GridwiseGemm = - GridwiseGemmDlops_km_kn_mn_v1r2; - - auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); - auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); - auto c_m0_m10_m11_n0_n10_n11_grid_desc = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - if(hipThreadIdx_x == 0) - { - *static_cast(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc; - *static_cast(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc; - *static_cast( - p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc; - *static_cast( - p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; - }; -}; - -extern "C" __global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k_m0_m1_grid_desc, - const void CONSTANT* p_b_k_n0_n1_grid_desc, - const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - constexpr auto in_n_c_hi_wi_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); - constexpr auto wei_k_c_y_x_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3)); - constexpr auto out_n_k_ho_wo_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); - - constexpr auto descs = - transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1)); - - constexpr auto a_k_m_grid_desc = descs[I0]; - constexpr auto b_k_n_grid_desc = descs[I1]; - constexpr auto c_m_n_grid_desc = descs[I2]; - - using AKMGridDesc = decltype(a_k_m_grid_desc); - using BKNGridDesc = decltype(b_k_n_grid_desc); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}))); - - using BGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); - - using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - - using GridwiseGemm = - GridwiseGemmDlops_km_kn_mn_v1r2; - - constexpr auto a_k_m0_m1_grid_desc_tmp = - GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); - constexpr auto b_k_n0_n1_grid_desc_tmp = - GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); - constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp); - using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp); - using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); - - const auto a_k_m0_m1_grid_desc = - *reinterpret_cast((const void*)p_a_k_m0_m1_grid_desc); - const auto b_k_n0_n1_grid_desc = - *reinterpret_cast((const void*)p_b_k_n0_n1_grid_desc); - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - *reinterpret_cast( - (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); - - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -}; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp deleted file mode 100644 index 51d852617f8..00000000000 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp +++ /dev/null @@ -1,358 +0,0 @@ -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" - -using namespace ck; - -constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); -constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); -constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); - -using FloatAB = typename get_datatype_from_enum::type; -using FloatAcc = typename get_datatype_from_enum::type; -using FloatC = typename get_datatype_from_enum::type; - -constexpr index_t BlockSize = CK_PARAM_BlockSize; - -constexpr index_t MPerBlock = CK_PARAM_MPerBlock; -constexpr index_t NPerBlock = CK_PARAM_NPerBlock; -constexpr index_t KPerBlock = CK_PARAM_KPerBlock; - -constexpr index_t MPerWave = CK_PARAM_MPerWave; -constexpr index_t NPerWave = CK_PARAM_NPerWave; -constexpr index_t MRepeat = CK_PARAM_MRepeat; -constexpr index_t NRepeat = CK_PARAM_NRepeat; -constexpr index_t K1 = CK_PARAM_K1; - -using ABlockTransferThreadSliceLengths_K0_M_K1 = - Sequence; -using ABlockTransferThreadClusterLengths_K0_M_K1 = - Sequence; -using ABlockTransferThreadClusterArrangeOrder = - Sequence; -using ABlockTransferSrcAccessOrder = Sequence; - -constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; -constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; -constexpr index_t ABlockTransferDstScalarPerVector_K1 = - CK_PARAM_ABlockTransferDstScalarPerVector_K1; -constexpr bool AThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); - -using BBlockTransferThreadSliceLengths_K0_N_K1 = - Sequence; -using BBlockTransferThreadClusterLengths_K0_N_K1 = - Sequence; -using BBlockTransferThreadClusterArrangeOrder = - Sequence; -using BBlockTransferSrcAccessOrder = Sequence; - -constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; -constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; -constexpr index_t BBlockTransferDstScalarPerVector_K1 = - CK_PARAM_BBlockTransferDstScalarPerVector_K1; -constexpr bool BThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); - -using CThreadTransferSrcDstAccessOrder = Sequence; -constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; -constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; - -extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare( - int n, - int c, - int hi, - int wi, - int k, - int y, - int x, - int convStrideH, - int convStrideW, - int convDilationY, - int convDilationX, - int leftPadH, - int leftPadW, - int rightPadH, - int rightPadW, - void* p_a_k0_m_k1_grid_desc, - void* p_b_k0_n_k1_grid_desc, - void* p_c_m0_m1_m2_n_grid_desc, - void* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; - const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi)); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x)); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo)); - - const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(convStrideH, convStrideW), - make_tuple(convDilationY, convDilationX), - make_tuple(leftPadH, leftPadW), - make_tuple(rightPadH, rightPadW), - Number{}); - - const auto a_k0_m_k1_grid_desc = descs[I0]; - const auto b_k0_n_k1_grid_desc = descs[I1]; - const auto c_m_n_grid_desc = descs[I2]; - - using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); - using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using AGridStepHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - - using BGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - - using GridwiseGemm = - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - - auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - if(hipThreadIdx_x == 0) - { - *static_cast*>(p_a_k0_m_k1_grid_desc) = - a_k0_m_k1_grid_desc; - *static_cast*>(p_b_k0_n_k1_grid_desc) = - b_k0_n_k1_grid_desc; - *static_cast(p_c_m0_m1_m2_n_grid_desc) = - c_m0_m1_m2_n_grid_desc; - *static_cast( - p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; - } -}; - -extern "C" __global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k0_m_k1_grid_desc, - const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - constexpr auto in_n_c_hi_wi_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); - constexpr auto wei_k_c_y_x_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3)); - constexpr auto out_n_k_ho_wo_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); - - constexpr auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - Number{}); - - constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; - constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; - constexpr auto c_m_n_grid_desc = descs[I2]; - - using AGridStepHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - - using BGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - - using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); - using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using GridwiseGemm = - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - constexpr auto c_m0_m1_m2_n_grid_desc_tmp = - GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); - - const auto a_k0_m_k1_grid_desc = - *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); - const auto b_k0_n_k1_grid_desc = - *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); - const auto c_m0_m1_m2_n_grid_desc = - *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); - - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); -}; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp deleted file mode 100644 index 30e4c518ced..00000000000 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp +++ /dev/null @@ -1,357 +0,0 @@ -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" - -using namespace ck; - -constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); -constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); -constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); - -using FloatAB = typename get_datatype_from_enum::type; -using FloatAcc = typename get_datatype_from_enum::type; -using FloatC = typename get_datatype_from_enum::type; - -constexpr index_t BlockSize = CK_PARAM_BlockSize; - -constexpr index_t MPerBlock = CK_PARAM_MPerBlock; -constexpr index_t NPerBlock = CK_PARAM_NPerBlock; -constexpr index_t KPerBlock = CK_PARAM_KPerBlock; - -constexpr index_t MPerWave = CK_PARAM_MPerWave; -constexpr index_t NPerWave = CK_PARAM_NPerWave; -constexpr index_t MRepeat = CK_PARAM_MRepeat; -constexpr index_t NRepeat = CK_PARAM_NRepeat; -constexpr index_t K1 = CK_PARAM_K1; - -using ABlockTransferThreadSliceLengths_K0_M_K1 = - Sequence; -using ABlockTransferThreadClusterLengths_K0_M_K1 = - Sequence; -using ABlockTransferThreadClusterArrangeOrder = - Sequence; -using ABlockTransferSrcAccessOrder = Sequence; - -constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; -constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; -constexpr index_t ABlockTransferDstScalarPerVector_K1 = - CK_PARAM_ABlockTransferDstScalarPerVector_K1; -constexpr bool AThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); - -using BBlockTransferThreadSliceLengths_K0_N_K1 = - Sequence; -using BBlockTransferThreadClusterLengths_K0_N_K1 = - Sequence; -using BBlockTransferThreadClusterArrangeOrder = - Sequence; -using BBlockTransferSrcAccessOrder = Sequence; - -constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; -constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; -constexpr index_t BBlockTransferDstScalarPerVector_K1 = - CK_PARAM_BBlockTransferDstScalarPerVector_K1; -constexpr bool BThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); - -using CThreadTransferSrcDstAccessOrder = Sequence; -constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; -constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; - -extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare( - int n, - int hi, - int wi, - int c, - int k, - int y, - int x, - int convStrideH, - int convStrideW, - int convDilationY, - int convDilationX, - int leftPadH, - int leftPadW, - int rightPadH, - int rightPadW, - void* p_a_k0_m_k1_grid_desc, - void* p_b_k0_n_k1_grid_desc, - void* p_c_m0_m1_m2_n_grid_desc, - void* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; - const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(n, hi, wi, c)); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(make_tuple(k, y, x, c)); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(make_tuple(n, ho, wo, k)); - - const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( - in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - make_tuple(convStrideH, convStrideW), - make_tuple(convDilationY, convDilationX), - make_tuple(leftPadH, leftPadW), - make_tuple(rightPadH, rightPadW), - Number{}); - - const auto a_k0_m_k1_grid_desc = descs[I0]; - const auto b_k0_n_k1_grid_desc = descs[I1]; - const auto c_m_n_grid_desc = descs[I2]; - - using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); - using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using BGridStepHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - - using AGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; - - using GridwiseGemm = - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - - auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - if(hipThreadIdx_x == 0) - { - *static_cast*>(p_a_k0_m_k1_grid_desc) = - a_k0_m_k1_grid_desc; - *static_cast*>(p_b_k0_n_k1_grid_desc) = - b_k0_n_k1_grid_desc; - *static_cast(p_c_m0_m1_m2_n_grid_desc) = - c_m0_m1_m2_n_grid_desc; - *static_cast( - p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; - } -}; - -extern "C" __global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k0_m_k1_grid_desc, - const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - constexpr auto in_n_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256)); - constexpr auto wei_k_y_x_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 3, 3, 256)); - constexpr auto out_n_ho_wo_k_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256)); - - constexpr auto descs = - transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - Number{}); - - constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; - constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; - constexpr auto c_m_n_grid_desc = descs[I2]; - - using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); - using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using BGridStepHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - - using AGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; - - using GridwiseGemm = - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - constexpr auto c_m0_m1_m2_n_grid_desc_tmp = - GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); - - const auto a_k0_m_k1_grid_desc = - *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); - const auto b_k0_n_k1_grid_desc = - *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); - const auto c_m0_m1_m2_n_grid_desc = - *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); - - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); -}; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp deleted file mode 100644 index 71239e0ecc9..00000000000 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp +++ /dev/null @@ -1,400 +0,0 @@ -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_contraction_dlops_v1r2.hpp" -#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" - -using namespace ck; - -constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); -constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); -constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); - -using FloatAB = typename get_datatype_from_enum::type; -using FloatAcc = typename get_datatype_from_enum::type; -using FloatC = typename get_datatype_from_enum::type; - -constexpr index_t BlockSize = CK_PARAM_BlockSize; - -constexpr auto GN0 = Number{}; -constexpr auto GK1 = Number{}; - -constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11; -constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11; -constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock; - -constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11; -constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11; -constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread; - -using BM10BN10ThreadClusterBM10Xs = Sequence; -using BM10BN10ThreadClusterBN10Xs = Sequence; - -using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = - Sequence; -using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = - Sequence; -using ABlockTransferThreadClusterArrangeOrder = Sequence<1, 2, 3, 0, 4>; -using ABlockTransferSrcAccessOrder = Sequence<3, 2, 1, 0, 4>; -using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = - Sequence; -using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = - Sequence; -using ABlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>; - -using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = - Sequence; -using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = - Sequence; -using BBlockTransferThreadClusterArrangeOrder = Sequence<0, 4, 1, 2, 3>; -using BBlockTransferSrcAccessOrder = Sequence<4, 3, 2, 0, 1>; -using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = - Sequence; -using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = - Sequence; -using BBlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>; - -using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>; -constexpr index_t CThreadTransferSrcDstVectorDim = 5; -constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; - -constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HasMainKBlockLoop); -constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HasDoubleTailKBlockLoop); - -extern "C" __global__ void -convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(int N_, - int C_, - int Hi_, - int Wi_, - int K_, - int Y_, - int X_, - int ConvStrideH_, - int ConvStrideW_, - int ConvDilationH_, - int ConvDilationW_, - int InLeftPadH_, - int InLeftPadW_, - int InRightPadH_, - int InRightPadW_, - void* p_desc_tuple) -{ - index_t N = static_cast(N_); - index_t C = static_cast(C_); - index_t Hi = static_cast(Hi_); - index_t Wi = static_cast(Wi_); - index_t K = static_cast(K_); - index_t Y = static_cast(Y_); - index_t X = static_cast(X_); - index_t ConvStrideH = static_cast(ConvStrideH_); - index_t ConvStrideW = static_cast(ConvStrideW_); - index_t ConvDilationH = static_cast(ConvDilationH_); - index_t ConvDilationW = static_cast(ConvDilationW_); - index_t InLeftPadH = static_cast(InLeftPadH_); - index_t InLeftPadW = static_cast(InLeftPadW_); - index_t InRightPadH = static_cast(InRightPadH_); - index_t InRightPadW = static_cast(InRightPadW_); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - const index_t Ho = - (Hi + InLeftPadH + InRightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1; - const index_t Wo = - (Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1; - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C, Hi, Wi)); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C, Y, X)); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo)); - - const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(ConvStrideH, ConvStrideW), - make_tuple(ConvDilationH, ConvDilationW), - make_tuple(InLeftPadH, InLeftPadW), - make_tuple(InRightPadH, InRightPadW), - GN0, - GK1); - - const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; - const auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; - const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; - - using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1); - using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); - using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); - - using AGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 - Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 - Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - - using BGridStepHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - - using CGridStepHacks = decltype(make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; - - using BGridMoveSliceWindowStepHacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; - - using GridwiseContraction = - GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperationEnum_t::Set, - AGridDesc_GK0_GM0_GM1_GK1, - BGridDesc_GK0_GN0_GN1_GK1, - CGridDesc_GM0_GM1_GN0_GN1, - GM1PerBlockGM11, - GN1PerBlockGN11, - GK0PerBlock, - BM1PerThreadBM11, - BN1PerThreadBN11, - BK0PerThread, - BM10BN10ThreadClusterBM10Xs, - BM10BN10ThreadClusterBN10Xs, - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferSrcVectorTensorContiguousDimOrder, - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferSrcVectorTensorContiguousDimOrder, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks>; - - if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0) - { - auto desc_tuple = - make_tuple(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( - a_grid_desc_gk0_gm0_gm1_gk1), - GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( - b_grid_desc_gk0_gn0_gn1_gk1), - GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( - c_grid_desc_gm0_gm1_gn0_gn1), - GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( - c_grid_desc_gm0_gm1_gn0_gn1)); - - *static_cast(p_desc_tuple) = desc_tuple; - } -}; - -extern "C" __global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_desc_tuple) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_n_c_hi_wi_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); - constexpr auto wei_k_c_y_x_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3)); - constexpr auto out_n_k_ho_wo_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); - - constexpr auto descs = - transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - GN0, - GK1); - - constexpr auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; - constexpr auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; - constexpr auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; - - using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1); - using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); - using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); - - using AGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 - Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 - Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - - using BGridStepHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - - using CGridStepHacks = decltype(make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; - - using BGridMoveSliceWindowStepHacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; - - using GridwiseContraction = - GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperationEnum_t::Set, - AGridDesc_GK0_GM0_GM1_GK1, - BGridDesc_GK0_GN0_GN1_GK1, - CGridDesc_GM0_GM1_GN0_GN1, - GM1PerBlockGM11, - GN1PerBlockGN11, - GK0PerBlock, - BM1PerThreadBM11, - BN1PerThreadBN11, - BK0PerThread, - BM10BN10ThreadClusterBM10Xs, - BM10BN10ThreadClusterBN10Xs, - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferSrcVectorTensorContiguousDimOrder, - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferSrcVectorTensorContiguousDimOrder, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks>; - - using AGridDesc_GK0_GM0_GM10_GM11_GK1 = - decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( - a_grid_desc_gk0_gm0_gm1_gk1)); - using BGridDesc_GK0_GN0_GN10_GN11_GK1 = - decltype(GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( - b_grid_desc_gk0_gn0_gn1_gk1)); - using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = - decltype(GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( - c_grid_desc_gm0_gm1_gn0_gn1)); - using CGridBlockCluster_BlockId_To_GM10_GN10 = - decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( - c_grid_desc_gm0_gm1_gn0_gn1)); - - using DescTuple = decltype(make_tuple(AGridDesc_GK0_GM0_GM10_GM11_GK1{}, - BGridDesc_GK0_GN0_GN10_GN11_GK1{}, - CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{}, - CGridBlockCluster_BlockId_To_GM10_GN10{})); - - const auto desc_tuple = - *reinterpret_cast(cast_pointer_to_generic_address_space(p_desc_tuple)); - - const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0]; - const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1]; - const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = desc_tuple[I2]; - const auto c_grid_block_cluster_blockid_to_gm10_gn10 = desc_tuple[I3]; - - constexpr index_t shared_block_size = - GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseContraction::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_grid_desc_gk0_gm0_gm10_gm11_gk1, - b_grid_desc_gk0_gn0_gn10_gn11_gk1, - c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - c_grid_block_cluster_blockid_to_gm10_gn10, - integral_constant{}, - integral_constant{}); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp deleted file mode 100644 index ca6b415910e..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp +++ /dev/null @@ -1,271 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_blockwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - void* __restrict__ ws_global) -{ - (void)GridSize; - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto one_dim_srcDesc = transform_tensor_descriptor( - srcDesc, - make_tuple(make_merge_transform(tupleSrcLengths)), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - auto src2dDesc = transform_tensor_descriptor( - one_dim_srcDesc, - make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - constexpr int invariantLen = 1; - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; - - if constexpr(src2d_need_padding) - { - const auto srcPad = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pass_through_transform(invariantLen), - make_pad_transform(toReduceLen, 0, srcPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); - - static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_src2dDesc = - transform_tensor_descriptor(ref_one_dim_srcDesc, - make_tuple(make_unmerge_transform( - make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the BlockWise and MultiBlock method - using refType_src2dDesc_padded_34 = decltype( - transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pass_through_transform(ref_invariantLen), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_34 = - typename get_ref_desc_types::refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)BlkGroupSize; - (void)ws_buf2_bytes_offset; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(p_src_global), - beta, - static_cast(p_dst_global), - static_cast(nullptr), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp deleted file mode 100644 index a3daeaf1639..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp +++ /dev/null @@ -1,305 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_blockwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; -constexpr index_t num_invariantDims = srcDims - num_toReduceDims; - -using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; -using toReduceDims = typename arithmetic_sequence_gen::type; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!"); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)GridSize; - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{}); - const auto invariantDimLengths = - make_tuple_from_array_and_index_seq(srcLengths, invariantDims{}); - - auto src2dDesc = - transform_tensor_descriptor(srcDesc, - make_tuple(make_merge_transform(invariantDimLengths), - make_merge_transform(toReduceDimLengths)), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; - - if constexpr(src2d_need_padding) - { - const auto srcPad = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pass_through_transform(invariantLen), - make_pad_transform(toReduceLen, 0, srcPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_toReduceDimLengths = - typename uniform_sequence_gen::type{}; - static constexpr auto ref_invariantDimLengths = - typename uniform_sequence_gen::type{}; - - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); - - static constexpr auto ref_src2dDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)), - make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the BlockWise and MultiBlock method - using refType_src2dDesc_padded_34 = decltype( - transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pass_through_transform(ref_invariantLen), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); -}; - -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_34 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)BlkGroupSize; - (void)ws_buf2_bytes_offset; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(p_src_global), - beta, - static_cast(p_dst_global), - static_cast(nullptr), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp deleted file mode 100644 index 81899dfb021..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp +++ /dev/null @@ -1,276 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_multiblock.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - void* __restrict__ ws_global) -{ - (void)GridSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto one_dim_srcDesc = transform_tensor_descriptor( - srcDesc, - make_tuple(make_merge_transform(tupleSrcLengths)), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - auto src2dDesc = transform_tensor_descriptor( - one_dim_srcDesc, - make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - constexpr int invariantLen = 1; - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; - const index_t reduceSizePerBlock = - (((toReduceLen + BlkGroupSize - 1) / BlkGroupSize + copySliceLen - 1) / copySliceLen) * - copySliceLen; - - if constexpr(src2d_need_padding) - { - const auto srcPad = reduceSizePerBlock * BlkGroupSize - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pass_through_transform(invariantLen), - make_pad_transform(toReduceLen, 0, srcPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); - - static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_src2dDesc = - transform_tensor_descriptor(ref_one_dim_srcDesc, - make_tuple(make_unmerge_transform( - make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the BlockWise and MultiBlock method - using refType_src2dDesc_padded_34 = decltype( - transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pass_through_transform(ref_invariantLen), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_34 = - typename get_ref_desc_types::refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_dst_global; - (void)indices_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_multiblock; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - BlkGroupSize, - alpha, - static_cast(p_src_global), - beta, - static_cast(ws_buf1_global), - static_cast(ws_buf2_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp deleted file mode 100644 index 0e578f4d1d8..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp +++ /dev/null @@ -1,310 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_multiblock.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; -constexpr index_t num_invariantDims = srcDims - num_toReduceDims; - -using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; -using toReduceDims = typename arithmetic_sequence_gen::type; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!"); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)GridSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{}); - const auto invariantDimLengths = - make_tuple_from_array_and_index_seq(srcLengths, invariantDims{}); - - auto src2dDesc = - transform_tensor_descriptor(srcDesc, - make_tuple(make_merge_transform(invariantDimLengths), - make_merge_transform(toReduceDimLengths)), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; - const index_t reduceSizePerBlock = - (((toReduceLen + BlkGroupSize - 1) / BlkGroupSize + copySliceLen - 1) / copySliceLen) * - copySliceLen; - - if constexpr(src2d_need_padding) - { - const auto srcPad = reduceSizePerBlock * BlkGroupSize - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pass_through_transform(invariantLen), - make_pad_transform(toReduceLen, 0, srcPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_toReduceDimLengths = - typename uniform_sequence_gen::type{}; - static constexpr auto ref_invariantDimLengths = - typename uniform_sequence_gen::type{}; - - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); - - static constexpr auto ref_src2dDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)), - make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the BlockWise and MultiBlock method - using refType_src2dDesc_padded_34 = decltype( - transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pass_through_transform(ref_invariantLen), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); -}; - -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_34 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_dst_global; - (void)indices_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_multiblock; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - BlkGroupSize, - alpha, - static_cast(p_src_global), - beta, - static_cast(ws_buf1_global), - static_cast(ws_buf2_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp deleted file mode 100644 index e63a1254e4d..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp +++ /dev/null @@ -1,284 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_threadwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto one_dim_srcDesc = transform_tensor_descriptor( - srcDesc, - make_tuple(make_merge_transform(tupleSrcLengths)), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - auto src2dDesc = transform_tensor_descriptor( - one_dim_srcDesc, - make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - constexpr int invariantLen = 1; - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = GredThreadBufferLength; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dstdDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; - } -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); - - static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_src2dDesc = - transform_tensor_descriptor(ref_one_dim_srcDesc, - make_tuple(make_unmerge_transform( - make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = - typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)BlkGroupSize; - (void)ws_buf2_bytes_offset; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(p_src_global), - beta, - static_cast(p_dst_global), - static_cast(nullptr), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp deleted file mode 100644 index 698f740058f..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp +++ /dev/null @@ -1,318 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_threadwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; -constexpr index_t num_invariantDims = srcDims - num_toReduceDims; - -using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; -using toReduceDims = typename arithmetic_sequence_gen::type; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!"); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{}); - const auto invariantDimLengths = - make_tuple_from_array_and_index_seq(srcLengths, invariantDims{}); - - auto src2dDesc = - transform_tensor_descriptor(srcDesc, - make_tuple(make_merge_transform(invariantDimLengths), - make_merge_transform(toReduceDimLengths)), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = GredThreadBufferLength; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dst1dDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; - } -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_toReduceDimLengths = - typename uniform_sequence_gen::type{}; - static constexpr auto ref_invariantDimLengths = - typename uniform_sequence_gen::type{}; - - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); - - static constexpr auto ref_src2dDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)), - make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); -}; - -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)BlkGroupSize; - (void)ws_buf2_bytes_offset; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(p_src_global), - beta, - static_cast(p_dst_global), - static_cast(nullptr), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp deleted file mode 100644 index 4a607372e95..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp +++ /dev/null @@ -1,285 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_warpwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto one_dim_srcDesc = transform_tensor_descriptor( - srcDesc, - make_tuple(make_merge_transform(tupleSrcLengths)), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - auto src2dDesc = transform_tensor_descriptor( - one_dim_srcDesc, - make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - constexpr int invariantLen = 1; - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize / warpSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dstDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; - } -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); - - static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_src2dDesc = - transform_tensor_descriptor(ref_one_dim_srcDesc, - make_tuple(make_unmerge_transform( - make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)BlkGroupSize; - (void)ws_buf2_bytes_offset; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = - GridwiseReduction_xy_to_x_direct_warpwise; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(p_src_global), - beta, - static_cast(p_dst_global), - static_cast(nullptr), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp deleted file mode 100644 index a6415279006..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp +++ /dev/null @@ -1,320 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_warpwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; -constexpr index_t num_invariantDims = srcDims - num_toReduceDims; - -using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; -using toReduceDims = typename arithmetic_sequence_gen::type; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!"); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{}); - const auto invariantDimLengths = - make_tuple_from_array_and_index_seq(srcLengths, invariantDims{}); - - auto src2dDesc = - transform_tensor_descriptor(srcDesc, - make_tuple(make_merge_transform(invariantDimLengths), - make_merge_transform(toReduceDimLengths)), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize / warpSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dst1dDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; - } -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_toReduceDimLengths = - typename uniform_sequence_gen::type{}; - static constexpr auto ref_invariantDimLengths = - typename uniform_sequence_gen::type{}; - - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); - - static constexpr auto ref_src2dDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)), - make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); -}; - -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)BlkGroupSize; - (void)ws_buf2_bytes_offset; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = - GridwiseReduction_xy_to_x_direct_warpwise; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(p_src_global), - beta, - static_cast(p_dst_global), - static_cast(nullptr), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp deleted file mode 100644 index 7e9d46612ef..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp +++ /dev/null @@ -1,205 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_blockwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable - -extern "C" __global__ void -gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global) -{ - (void)GridSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const index_t invariantLen = dstDesc.GetLength(Number<0>{}); - const index_t toReduceLen = BlkGroupSize; - - auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); - - constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; - - if constexpr(src2d_need_padding) - { - const auto srcPad = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pass_through_transform(invariantLen), - make_pad_transform(toReduceLen, 0, srcPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; -}; - -struct get_ref_desc_types -{ - static constexpr auto ref_tupleDstLengths = make_tuple(8); - static constexpr auto ref_dstDesc = - make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); - - static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{}); - static constexpr index_t ref_toReduceLen = 8; - - static constexpr auto ref_src2dDesc = - make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); - - // used by the BlockWise and MultiBlock method - using refType_src2dDesc_padded_34 = decltype( - transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pass_through_transform(ref_invariantLen), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_34 = typename get_ref_desc_types::refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_src_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 3 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(ws_buf1_global), - beta, - static_cast(p_dst_global), - static_cast(ws_buf2_global), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp deleted file mode 100644 index 3f37d01e21e..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp +++ /dev/null @@ -1,263 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_blockwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, - int BlkGroupSize, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)GridSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const index_t invariantLen = dst1dDesc.GetLength(Number<0>{}); - const index_t toReduceLen = BlkGroupSize; - - auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); - - constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; - - if constexpr(src2d_need_padding) - { - const auto srcPad = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pass_through_transform(invariantLen), - make_pad_transform(toReduceLen, 0, srcPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_tupleDstLengths = - make_tuple_from_seq(typename uniform_sequence_gen::type{}); - static constexpr auto ref_dstDesc = - make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(ref_tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{}); - static constexpr index_t ref_toReduceLen = 8; - - static constexpr auto ref_src2dDesc = - make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); - - // used by the BlockWise and MultiBlock method - using refType_src2dDesc_padded_34 = decltype( - transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pass_through_transform(ref_invariantLen), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_34 = - typename get_ref_desc_types::refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_src_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 3 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(ws_buf1_global), - beta, - static_cast(p_dst_global), - static_cast(ws_buf2_global), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp deleted file mode 100644 index 77841d1312b..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp +++ /dev/null @@ -1,222 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_threadwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -using toReduceDims = Sequence; -using invariantDims = Sequence; // this could be empty - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable - -extern "C" __global__ void -gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const index_t invariantLen = dstDesc.GetLength(Number<0>{}); - const index_t toReduceLen = BlkGroupSize; - - auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); - - constexpr auto copySliceLen = GredThreadBufferLength; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dstDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; - } -}; - -struct get_ref_desc_types -{ - static constexpr auto ref_tupleDstLengths = make_tuple(8); - static constexpr auto ref_dstDesc = - make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); - - static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{}); - static constexpr index_t ref_toReduceLen = 8; - - static constexpr auto ref_src2dDesc = - make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_src_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 3 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(ws_buf1_global), - beta, - static_cast(p_dst_global), - static_cast(ws_buf2_global), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp deleted file mode 100644 index 2de461ad0fa..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp +++ /dev/null @@ -1,277 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_threadwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, - int BlkGroupSize, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const index_t invariantLen = dst1dDesc.GetLength(Number<0>{}); - const index_t toReduceLen = BlkGroupSize; - - auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); - - constexpr auto copySliceLen = GredThreadBufferLength; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dst1dDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; - } -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_tupleDstLengths = - make_tuple_from_seq(typename uniform_sequence_gen::type{}); - static constexpr auto ref_dstDesc = - make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(ref_tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{}); - static constexpr index_t ref_toReduceLen = 8; - - static constexpr auto ref_src2dDesc = - make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = - typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_src_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 3 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(ws_buf1_global), - beta, - static_cast(p_dst_global), - static_cast(ws_buf2_global), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp deleted file mode 100644 index 1ba5e496579..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp +++ /dev/null @@ -1,221 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_warpwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable - -extern "C" __global__ void -gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const index_t invariantLen = dstDesc.GetLength(Number<0>{}); - const index_t toReduceLen = BlkGroupSize; - - auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); - - constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize / warpSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dstDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; - } -}; - -struct get_ref_desc_types -{ - static constexpr auto ref_tupleDstLengths = make_tuple(8); - static constexpr auto ref_dstDesc = - make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); - - static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{}); - static constexpr index_t ref_toReduceLen = 8; - - static constexpr auto ref_src2dDesc = - make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_src_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = - GridwiseReduction_xy_to_x_direct_warpwise; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 3 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(ws_buf1_global), - beta, - static_cast(p_dst_global), - static_cast(ws_buf2_global), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp deleted file mode 100644 index aef1545f118..00000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp +++ /dev/null @@ -1,279 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_warpwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, - int BlkGroupSize, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const index_t invariantLen = dst1dDesc.GetLength(Number<0>{}); - const index_t toReduceLen = BlkGroupSize; - - auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); - - constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize / warpSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dst1dDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; - } -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_tupleDstLengths = - make_tuple_from_seq(typename uniform_sequence_gen::type{}); - static constexpr auto ref_dstDesc = - make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(ref_tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{}); - static constexpr index_t ref_toReduceLen = 8; - - static constexpr auto ref_src2dDesc = - make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = - typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_src_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = - GridwiseReduction_xy_to_x_direct_warpwise; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 3 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(ws_buf1_global), - beta, - static_cast(p_dst_global), - static_cast(ws_buf2_global), - static_cast(indices_global)); -}; diff --git a/dev-requirements.txt b/dev-requirements.txt new file mode 100644 index 00000000000..5d123edb856 --- /dev/null +++ b/dev-requirements.txt @@ -0,0 +1,3 @@ +ROCmSoftwarePlatform/rocm-recipes +# 1.90+ +danmar/cppcheck@dd05839a7e63ef04afd34711cb3e1e0ef742882f \ No newline at end of file diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt new file mode 100644 index 00000000000..a0fe1fe2fa2 --- /dev/null +++ b/example/01_gemm/CMakeLists.txt @@ -0,0 +1,6 @@ +add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) +add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) +add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) +add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) +add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) +add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) diff --git a/example/01_gemm/README.md b/example/01_gemm/README.md new file mode 100644 index 00000000000..226783b03b0 --- /dev/null +++ b/example/01_gemm/README.md @@ -0,0 +1,23 @@ +# Instructions for ```example_gemm_xdl``` + +## Run ```example_gemm_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +./bin/example_gemm_xdl 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s +``` diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp new file mode 100644 index 00000000000..6e8e04f9e51 --- /dev/null +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -0,0 +1,211 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(1); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + bool pass = true; + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp new file mode 100644 index 00000000000..65c806bf07e --- /dev/null +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -0,0 +1,210 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = float; +using BDataType = float; +using CDataType = float; +using AccDataType = float; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(1); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + bool pass = true; + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp new file mode 100644 index 00000000000..a9590030c7f --- /dev/null +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -0,0 +1,208 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(1); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + bool pass = true; + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp new file mode 100644 index 00000000000..060750e6768 --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -0,0 +1,239 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using CDataType = BF16; +using AccDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle + , // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 8, // index_t ABlockTransferSrcScalarPerVector + 8, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto c_element_op = PassThrough{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_f32_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + bf16_to_f32_(a_m_k, a_f32_m_k); + bf16_to_f32_(b_k_n, b_f32_k_n); + bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_f32_m_k, b_f32_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp new file mode 100644 index 00000000000..06523037f96 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -0,0 +1,203 @@ +#include +#include +#include +#include +#include +#include +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp new file mode 100644 index 00000000000..a22c21e40e2 --- /dev/null +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -0,0 +1,226 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int8_t; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle< + ALayout, // typename ALayout + BLayout, // typename BLayout + CLayout, // typename CLayout + ADataType, // typename ADataType + BDataType, // typename BDataType + CDataType, // typename CDataType + AccDataType, // typename GemmAccDataType + CShuffleDataType, // typename CShuffleDataType + PassThrough, // typename AElementwiseOperation + PassThrough, // typename BElementwiseOperation + PassThrough, // typename CElementwiseOperation + GemmDefault, // GemmSpecialization GemmSpec + 1, // index_t NumGemmKPrefetchStage + 256, // index_t BlockSize + 256, // index_t MPerBlock + 128, // index_t NPerBlock + 64, // index_t KPerBlock + 16, // index_t AK1 + 16, // index_t BK1 + 32, // index_t MPerXDL + 32, // index_t NPerXDL + 4, // index_t MXdlPerWave + 2, // index_t NXdlPerWave + S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 16, // index_t ABlockTransferSrcScalarPerVector + 16, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto c_element_op = PassThrough{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/02_gemm_alpha_beta/CMakeLists.txt b/example/02_gemm_alpha_beta/CMakeLists.txt new file mode 100644 index 00000000000..1b81cf21622 --- /dev/null +++ b/example/02_gemm_alpha_beta/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_xdl_alpha_beta gemm_xdl_alpha_beta.cpp) diff --git a/example/02_gemm_alpha_beta/README.md b/example/02_gemm_alpha_beta/README.md new file mode 100644 index 00000000000..ba2a3068f3e --- /dev/null +++ b/example/02_gemm_alpha_beta/README.md @@ -0,0 +1,26 @@ +# Instructions for ```example_gemm_xdl_alpha_beta``` + +## Run ```example_gemm_xdl_alpha_beta``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +./bin/example_gemm_xdl_alpha_beta 1 1 1 0.5 0.5 +``` +Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c0_grid_desc_m_n_{ 3840, 4096} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Perf: 0.936965 ms, 137.517 TFlops, 102.959 GB/s +error: 0 +max_diff: 0, 558.5, 558.5 +``` diff --git a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp b/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp new file mode 100644 index 00000000000..1a6e1de4dcf --- /dev/null +++ b/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp @@ -0,0 +1,253 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_base.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm_bias_2d.hpp" + +template +using S = ck::Sequence; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_2d< + ADataType, // ADataType + BDataType, // BDataType + CDataType, // CDataType + AccDataType, // AccDataType + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 12) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + beta = std::stof(argv[11]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, alpha, beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c0_m_n: " << c0_m_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + c0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c0_m_n_device_buf(sizeof(CDataType) * c0_m_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c0_m_n_device_buf.ToDevice(c0_m_n.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c0_m_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + AElementOp{}, + BElementOp{}, + CElementOp{alpha, beta}); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + b_k_n, + c0_m_n, + c_m_n_host_result, + AElementOp{}, + BElementOp{}, + CElementOp{alpha, beta}); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt new file mode 100644 index 00000000000..d07ad6e36c3 --- /dev/null +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_xdl_bias_relu gemm_xdl_bias_relu.cpp) diff --git a/example/03_gemm_bias_relu/README.md b/example/03_gemm_bias_relu/README.md new file mode 100644 index 00000000000..f8d9bd61529 --- /dev/null +++ b/example/03_gemm_bias_relu/README.md @@ -0,0 +1,28 @@ +# Instructions for ```example_gemm_xdl_bias_relu_add``` + +## Run ```example_gemm_xdl_bias_relu_add``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC +./bin/example_gemm_xdl_bias_relu_add 0 1 5 3840 4096 4096 4096 4096 4096 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +c1_m_n: dim 2, lengths {3840, 4096}, strides {1, 0} +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +arg.c0_grid_desc_m_n_{ 3840, 4096} +arg.c1_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.27583 ms, 100.992 TFlops, 73.9688 GB/s +``` diff --git a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp b/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp new file mode 100644 index 00000000000..3bf3003c147 --- /dev/null +++ b/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp @@ -0,0 +1,239 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" +#include "reference_gemm_bias_activation.hpp" + +template +using S = ck::Sequence; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::AddRelu; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation< + ADataType, // ADataType + BDataType, // BDataType + CDataType, // CDataType + AccDataType, // AccDataType + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBiasActivation; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + // c0_n[n] + Tensor c0_n(HostTensorDescriptor( + std::vector({static_cast(N)}), std::vector({1}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "c0_n: " << c0_n.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); + c0_n_device_buf.ToDevice(c0_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + static_cast(c0_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + + sizeof(CDataType) * M * N + sizeof(CDataType) * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, c0_n, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/04_gemm_bias_relu_add/CMakeLists.txt b/example/04_gemm_bias_relu_add/CMakeLists.txt new file mode 100644 index 00000000000..4f48db94a88 --- /dev/null +++ b/example/04_gemm_bias_relu_add/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_xdl_bias_relu_add gemm_xdl_bias_relu_add.cpp) diff --git a/example/04_gemm_bias_relu_add/README.md b/example/04_gemm_bias_relu_add/README.md new file mode 100644 index 00000000000..f8d9bd61529 --- /dev/null +++ b/example/04_gemm_bias_relu_add/README.md @@ -0,0 +1,28 @@ +# Instructions for ```example_gemm_xdl_bias_relu_add``` + +## Run ```example_gemm_xdl_bias_relu_add``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC +./bin/example_gemm_xdl_bias_relu_add 0 1 5 3840 4096 4096 4096 4096 4096 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +c1_m_n: dim 2, lengths {3840, 4096}, strides {1, 0} +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +arg.c0_grid_desc_m_n_{ 3840, 4096} +arg.c1_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.27583 ms, 100.992 TFlops, 73.9688 GB/s +``` diff --git a/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp b/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp new file mode 100644 index 00000000000..73e92f9d116 --- /dev/null +++ b/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp @@ -0,0 +1,257 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" +#include "reference_gemm_bias_activation_add.hpp" + +template +using S = ck::Sequence; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< + ADataType, // ADataType + BDataType, // BDataType + CDataType, // CDataType + AccDataType, // AccDataType + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmBiasActivationAdd; +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + ck::index_t StrideC1 = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + StrideC1 = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, StrideC1\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + // c0_n[n] + Tensor c0_n(HostTensorDescriptor( + std::vector({static_cast(N)}), std::vector({1}))); + + // c1_m_n[m ,n] + Tensor c1_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "c0_n: " << c0_n.mDesc << std::endl; + std::cout << "c1_m_n: " << c1_m_n.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + c1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); + DeviceMem c1_m_n_device_buf(sizeof(CDataType) * c1_m_n.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); + c0_n_device_buf.ToDevice(c0_n.mData.data()); + c1_m_n_device_buf.ToDevice(c1_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + static_cast(c0_n_device_buf.GetDeviceBuffer()), + static_cast(c1_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideC1, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + + sizeof(CDataType) * M * N + sizeof(CDataType) * N + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + b_k_n, + c_m_n_host_result, + c0_n, + c1_m_n, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/06_conv2d_fwd_bias_relu/CMakeLists.txt b/example/06_conv2d_fwd_bias_relu/CMakeLists.txt new file mode 100644 index 00000000000..4e1dd1f3e6e --- /dev/null +++ b/example/06_conv2d_fwd_bias_relu/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_conv2d_fwd_xdl_bias_relu conv2d_fwd_xdl_bias_relu.cpp) +target_link_libraries(example_conv2d_fwd_xdl_bias_relu PRIVATE conv_util) diff --git a/example/06_conv2d_fwd_bias_relu/README.md b/example/06_conv2d_fwd_bias_relu/README.md new file mode 100644 index 00000000000..4c30563ef01 --- /dev/null +++ b/example/06_conv2d_fwd_bias_relu/README.md @@ -0,0 +1,22 @@ +# Instructions for ```example_conv_xdl_bias_relu``` + +## Run ```example_conv_xdl_bias_relu``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx +./bin/example_conv_xdl_bias_relu 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +bias_k: dim 1, lengths {256}, strides {1} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.39009 ms, 105.581 TFlops, 239.981 GB/s +``` diff --git a/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp b/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp new file mode 100644 index 00000000000..d50afb6854c --- /dev/null +++ b/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp @@ -0,0 +1,312 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "conv_util.hpp" +#include "device.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "reference_conv_fwd_bias_activation.hpp" +#include "tensor_layout.hpp" + +namespace { + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + +static constexpr auto MemorySet = ck::InMemoryDataOperationEnum::Set; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +// clang-format off +using DeviceConvFwdInstance = ck::tensor_operation::device:: + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + MemorySet, // OutGlobalMemoryDataOperation + ConvFwdDefault, // ConvForwardSpecialization + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +using ReferenceConvFwdInstance = + ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation; + +void PrintUseMsg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "Following arguments:\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int num_dim_spatial = 2; + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 4; + if(cmdline_nargs != argc) + { + PrintUseMsg(); + exit(0); + } + + ck::utils::conv::ConvParams params; + int arg_idx = 4; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // anonymous namespace + +int main(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + const int num_dim_spatial = 2; + + ck::utils::conv::ConvParams params; + + if(argc >= 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + + if(argc >= 5) + { + params = ParseConvParams(argc, argv); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + // bias: assume contiguous 1d vector + Tensor bias( + HostTensorDescriptor(std::vector({static_cast(params.K_)}))); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + + auto conv = DeviceConvFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = + conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device operator with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = + get_btype(params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths) + + sizeof(OutDataType) * (params.K_); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + bias, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(device_output.mData.data()); + return ck::utils::check_err(device_output.mData, host_output.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt b/example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt new file mode 100644 index 00000000000..b4dd39d83a7 --- /dev/null +++ b/example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt @@ -0,0 +1,3 @@ +# FIXME: should fix validation failure +add_example_executable_no_testing(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp) +target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_util) diff --git a/example/07_conv2d_fwd_bias_relu_add/README.md b/example/07_conv2d_fwd_bias_relu_add/README.md new file mode 100644 index 00000000000..99afcae9c86 --- /dev/null +++ b/example/07_conv2d_fwd_bias_relu_add/README.md @@ -0,0 +1,24 @@ +# Instructions for ```example_conv_xdl_bias_relu_add``` + + +## Run ```example_conv_xdl_bias_relu_add``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx +./bin/example_conv_xdl_bias_relu_add 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +bias_k: dim 1, lengths {256}, strides {1} +resi_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.44711 ms, 101.421 TFlops, 289.218 GB/s +``` diff --git a/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp new file mode 100644 index 00000000000..53d882778a2 --- /dev/null +++ b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp @@ -0,0 +1,327 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "conv_util.hpp" +#include "device.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "reference_conv_fwd_bias_activation_add.hpp" +#include "tensor_layout.hpp" + +namespace { + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +// clang-format off +using DeviceConvFwdInstance = ck::tensor_operation::device:: + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvFwdDefault, // ConvForwardSpecialization + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +using ReferenceConvFwdInstance = + ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation_Add; + +void PrintUseMsg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "Following arguments:\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int num_dim_spatial = 2; + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 4; + if(cmdline_nargs != argc) + { + PrintUseMsg(); + exit(0); + } + + ck::utils::conv::ConvParams params; + int arg_idx = 4; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // anonymous namespace + +int main(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + const int num_dim_spatial = 2; + + ck::utils::conv::ConvParams params; + + if(argc >= 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + + if(argc >= 5) + { + params = ParseConvParams(argc, argv); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + // bias: assume contiguous 1d vector + Tensor bias( + HostTensorDescriptor(std::vector({static_cast(params.K_)}))); + + // residual: assume same layout as output tensor + Tensor residual(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + std::cout << "residual: " << residual.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + residual.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + residual.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpace()); + DeviceMem resi_device_buf(sizeof(OutDataType) * residual.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + resi_device_buf.ToDevice(residual.mData.data()); + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + auto conv = DeviceConvFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = + conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + static_cast(resi_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device operator with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = + get_btype(params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths) + + sizeof(OutDataType) * (params.K_) + + sizeof(OutDataType) * + (params.N_ * params.K_ * output_spatial_lengths[0] * output_spatial_lengths[1]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + bias, + residual, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(device_output.mData.data()); + return ck::utils::check_err(device_output.mData, host_output.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt new file mode 100644 index 00000000000..ceceb4aedc9 --- /dev/null +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -0,0 +1,6 @@ +add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) +add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) +target_link_libraries(example_convnd_fwd_xdl_fp32 PRIVATE conv_util) +target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util) +target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_util) diff --git a/example/09_convnd_fwd/README.md b/example/09_convnd_fwd/README.md new file mode 100644 index 00000000000..9ab5fee549d --- /dev/null +++ b/example/09_convnd_fwd/README.md @@ -0,0 +1,32 @@ +# Instructions for ```example_convnd_fwd_xdl``` + +## Run ```example_convnd_fwd_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4: N spatial dimensions (default 2) +#Following arguments (depending on number of spatial dims): +# N, K, C, +# , (ie Y, X for 2D) +# , (ie Hi, Wi for 2D) +# , (ie Sy, Sx for 2D) +# , (ie Dy, Dx for 2D) +# , (ie LeftPy, LeftPx for 2D) +# , (ie RightPy, RightPx for 2D) +./bin/example_convnd_fwd_xdl 0 1 100 +``` + +Result (MI100 @ 1087Mhz, 33.4TFlops peak FP32) +``` +input: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +weights: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +output: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_{432, 165888, 4} +arg.b_grid_desc_k0_n_k1_{432, 256, 4} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 100 times... +Perf: 4.43736 ms, 33.0753 TFlops, 150.357 GB/s +``` diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp new file mode 100644 index 00000000000..7ad83d5ad63 --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -0,0 +1,342 @@ +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "conv_util.hpp" +#include "device.hpp" +#include "device_tensor.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" + +namespace { + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +using DeviceConvFwdBasePtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector + +template +using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +DeviceConvFwdBasePtr get_conv_instance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +void print_use_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 5; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(0); + } + + ck::utils::conv::ConvParams params; + int arg_idx = 5; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // anonymous namespace + +int main(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + + ck::utils::conv::ConvParams params; + + if(argc >= 5) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + } + + if(argc >= 6) + { + params = parse_conv_params(num_dim_spatial, argc, argv); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv->IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = get_btype( + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto verify_f = [&input, &weights, &host_output, ¶ms, &out_device_buf, &device_output]( + const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(device_output.mData.data()); + return ck::utils::check_err( + host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f) ? 0 : 1; + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvNDFwdInstance<3>(); + verify_f(ref_conv); + break; + } + case 2: { + auto ref_conv = ReferenceConvNDFwdInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvNDFwdInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } + return 0; +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp new file mode 100644 index 00000000000..8a9633d84a9 --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -0,0 +1,346 @@ +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "conv_util.hpp" +#include "device.hpp" +#include "device_tensor.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" + +namespace { + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +using DeviceConvFwdBasePtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 4, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockTransferAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector +// clang-format on + +template +using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +DeviceConvFwdBasePtr get_conv_instance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +void print_use_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 5; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(0); + } + + ck::utils::conv::ConvParams params; + int arg_idx = 5; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // anonymous namespace + +int main(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + + ck::utils::conv::ConvParams params; + + if(argc >= 5) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + } + + if(argc >= 6) + { + params = parse_conv_params(num_dim_spatial, argc, argv); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv->IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = + get_btype(params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto verify_f = [&input, &weights, &host_output, ¶ms, &out_device_buf, &device_output]( + const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(device_output.mData.data()); + return ck::utils::check_err(device_output.mData, + host_output.mData, + "Error: incorrect results!", + 1e-5f, + 1e-4f) + ? 0 + : 1; + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvNDFwdInstance<3>(); + verify_f(ref_conv); + break; + } + case 2: { + auto ref_conv = ReferenceConvNDFwdInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvNDFwdInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } + return 0; +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp new file mode 100644 index 00000000000..f196d271828 --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -0,0 +1,344 @@ +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "conv_util.hpp" +#include "device.hpp" +#include "device_tensor.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" + +namespace { + +using InDataType = int8_t; +using WeiDataType = int8_t; +using OutDataType = int8_t; +using AccDataType = int32_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +using DeviceConvFwdBasePtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 16, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector + +template +using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +DeviceConvFwdBasePtr get_conv_instance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +void print_use_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 5; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(0); + } + + ck::utils::conv::ConvParams params; + int arg_idx = 5; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // anonymous namespace + +int main(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + + ck::utils::conv::ConvParams params; + + if(argc >= 5) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + } + + if(argc >= 6) + { + params = parse_conv_params(num_dim_spatial, argc, argv); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv->IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = get_btype( + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto verify_f = [&input, &weights, &host_output, ¶ms, &out_device_buf, &device_output]( + const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(device_output.mData.data()); + return ck::utils::check_err( + host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f) ? 0 : 1; + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvNDFwdInstance<3>(); + verify_f(ref_conv); + break; + } + case 2: { + auto ref_conv = ReferenceConvNDFwdInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvNDFwdInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } + return 0; +} diff --git a/example/10_conv2d_bwd_data/CMakeLists.txt b/example/10_conv2d_bwd_data/CMakeLists.txt new file mode 100644 index 00000000000..17aca1481bf --- /dev/null +++ b/example/10_conv2d_bwd_data/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_conv2d_bwd_data_xdl conv2d_bwd_data_xdl.cpp) +target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_util) diff --git a/example/10_conv2d_bwd_data/README.md b/example/10_conv2d_bwd_data/README.md new file mode 100644 index 00000000000..7503ff6d1e0 --- /dev/null +++ b/example/10_conv2d_bwd_data/README.md @@ -0,0 +1,47 @@ +# Instructions for ```example_conv2d_bwd_data_xdl``` Example + + +## Run ```example_conv2d_bwd_data_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx +./bin/example_conv2d_bwd_data_xdl 0 1 5 +``` + +Result +``` +in_n_c_hi_wi: dim 4, lengths {128, 256, 71, 71}, strides {1290496, 1, 18176, 256} +wei_k_c_y_x: dim 4, lengths {256, 256, 3, 3}, strides {2304, 1, 768, 256} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_container_{128, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{128, 256, 8} +arg.c_grid_desc_m_n_container_{ 175232, 256} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 4, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {2738, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{64, 256, 8} +arg.c_grid_desc_m_n_container_{ 175232, 256} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 4, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {2738, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{64, 256, 8} +arg.c_grid_desc_m_n_container_{ 175232, 256} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 4, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {2738, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{32, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{32, 256, 8} +arg.c_grid_desc_m_n_container_{ 175232, 256} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 4, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {2738, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Perf: 2.45966 ms, 79.5597 TFlops, 169.325 GB/s +``` diff --git a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp new file mode 100644 index 00000000000..2d25f5ac2f1 --- /dev/null +++ b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp @@ -0,0 +1,258 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "element_wise_operation.hpp" +#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "reference_conv_bwd_data.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +using DeviceConvBwdDataInstance = ck::tensor_operation::device:: + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdDefault, // ConvolutionBackwardDataSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 2, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, + 1>; // GemmCThreadTransferDstScalarPerVector + +using ReferenceConvBwdInstance = ck::tensor_operation::host::ReferenceConvBwdData; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 256; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 19) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + C = std::stoi(argv[6]); + Y = std::stoi(argv[7]); + X = std::stoi(argv[8]); + Hi = std::stoi(argv[9]); + Wi = std::stoi(argv[10]); + conv_stride_h = std::stoi(argv[11]); + conv_stride_w = std::stoi(argv[12]); + conv_dilation_h = std::stoi(argv[13]); + conv_dilation_w = std::stoi(argv[14]); + in_left_pad_h = std::stoi(argv[15]); + in_left_pad_w = std::stoi(argv[16]); + in_right_pad_h = std::stoi(argv[17]); + in_right_pad_w = std::stoi(argv[18]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + const std::vector conv_filter_strides{{conv_stride_h, conv_stride_w}}; + const std::vector conv_filter_dilations{{conv_dilation_h, conv_dilation_w}}; + const std::vector input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::vector input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + // tensor layout + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + }; + + Tensor out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo)); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X)); + Tensor in_n_c_hi_wi_host_result(f_host_tensor_descriptor(N, C, Hi, Wi)); + Tensor in_n_c_hi_wi_device_result(f_host_tensor_descriptor(N, C, Hi, Wi)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * + in_n_c_hi_wi_device_result.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + + // reset input to zero + in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1{0}); + in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data()); + + // do GEMM + auto conv = DeviceConvBwdDataInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + std::vector{{Hi, Wi}}, + std::vector{{Y, X}}, + std::vector{{Ho, Wo}}, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto ref_conv = ReferenceConvBwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result, + wei_k_c_y_x, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); + + return ck::utils::check_err(in_n_c_hi_wi_device_result.mData, + in_n_c_hi_wi_host_result.mData) + ? 0 + : 1; + } + return 0; +} diff --git a/example/11_conv2d_bwd_weight/CMakeLists.txt b/example/11_conv2d_bwd_weight/CMakeLists.txt new file mode 100644 index 00000000000..3d771b55697 --- /dev/null +++ b/example/11_conv2d_bwd_weight/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_conv2d_bwd_weight_xdl conv2d_bwd_weight_xdl.cpp) +target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_util) diff --git a/example/11_conv2d_bwd_weight/README.md b/example/11_conv2d_bwd_weight/README.md new file mode 100644 index 00000000000..c7627427849 --- /dev/null +++ b/example/11_conv2d_bwd_weight/README.md @@ -0,0 +1,25 @@ +# Instructions for ```example_conv2d_bwd_weight_xdl``` Example + +## Run ```example_conv2d_bwd_weight_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4: is show log (0=no, 1=yes) +#arg5 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx, split-k +./bin/example_conv2d_bwd_weight_xdl 0 1 5 0 4 +``` + +Result +``` +in_n_c_hi_wi: dim 4, lengths {128, 1024, 14, 14}, strides {200704, 1, 14336, 1024} +wei_k_c_y_x: dim 4, lengths {256, 1024, 3, 3}, strides {9216, 1, 3072, 1024} +out_n_k_ho_wo: dim 4, lengths {128, 256, 6, 6}, strides {9216, 1, 1536, 256} +arg.a_grid_desc_kbatch_k0_m_k1_{4, 144, 256, 8} +arg.b_grid_desc_kbatch_k0_n_k1_{4, 144, 9216, 8} +arg.c_grid_desc_m_n_{ 256, 9216} +launch_and_time_kernel: grid_dim {576, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 0.401084 ms, 54.2112 TFlops, 145.75 GB/s +``` diff --git a/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp b/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp new file mode 100644 index 00000000000..1578161116c --- /dev/null +++ b/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp @@ -0,0 +1,299 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "element_wise_operation.hpp" +#include "device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "reference_conv_backward_weight.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +using DeviceConvBwdWeightInstance = ck::tensor_operation::device:: + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +using ReferenceConvBwdWeightInstance = + ck::tensor_operation::host::ReferenceConvBwdWeight; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int do_log = 0; + int split_k = 4; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 1024; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 14; + ck::index_t Wi = 14; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 0; + ck::index_t in_left_pad_w = 0; + ck::index_t in_right_pad_h = 0; + ck::index_t in_right_pad_w = 0; + + if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + } + else if(argc == 21) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + + N = std::stoi(argv[6]); + K = std::stoi(argv[7]); + C = std::stoi(argv[8]); + Y = std::stoi(argv[9]); + X = std::stoi(argv[10]); + Hi = std::stoi(argv[11]); + Wi = std::stoi(argv[12]); + conv_stride_h = std::stoi(argv[13]); + conv_stride_w = std::stoi(argv[14]); + conv_dilation_h = std::stoi(argv[15]); + conv_dilation_w = std::stoi(argv[16]); + in_left_pad_h = std::stoi(argv[17]); + in_left_pad_w = std::stoi(argv[18]); + in_right_pad_h = std::stoi(argv[19]); + in_right_pad_w = std::stoi(argv[20]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: is show log (0=no, 1=yes)\n"); + printf("arg5: split-k \n"); + printf("arg6 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + const std::vector conv_filter_strides{{conv_stride_h, conv_stride_w}}; + const std::vector conv_filter_dilations{{conv_dilation_h, conv_dilation_w}}; + const std::vector input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::vector input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + // tensor layout + auto f_host_tensor_descriptor = [](std::size_t N_, + std::size_t C_, + std::size_t H, + std::size_t W, + auto layout) { + if constexpr(ck::is_same::value || + ck::is_same::value || + ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(ck::is_same::value || + ck::is_same::value || + ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x_host_result(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor wei_k_c_y_x_device_result( + f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + } + wei_k_c_y_x_device_result.GenerateTensorValue(GeneratorTensor_1{0}); + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * + wei_k_c_y_x_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x_device_result.mData.data()); + + // do GEMM + auto conv = DeviceConvBwdWeightInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + std::vector{{Hi, Wi}}, + std::vector{{Y, X}}, + std::vector{{Ho, Wo}}, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); + + if(!conv.IsSupportedArgument(argument)) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto ref_conv = ReferenceConvBwdWeightInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x_host_result, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") + << std::endl; + } + return ck::utils::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData) + ? 0 + : 1; + } + return 0; +} diff --git a/example/12_reduce/CMakeLists.txt b/example/12_reduce/CMakeLists.txt new file mode 100644 index 00000000000..9045a78a85b --- /dev/null +++ b/example/12_reduce/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_reduce_blockwise reduce_blockwise.cpp) +add_example_executable(example_reduce_blockwise_two_call reduce_blockwise_two_call.cpp) diff --git a/example/12_reduce/README.md b/example/12_reduce/README.md new file mode 100644 index 00000000000..a6442984e7c --- /dev/null +++ b/example/12_reduce/README.md @@ -0,0 +1,42 @@ +# Instructions for ```example_reduce_blockwise``` + +## Run ```example_reduce_blockwise``` +```bash +# -D : input 4-d tensor lengths +# -v : verification (0=no, 1=yes) +#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg2: time kernel (0=no, 1=yes) +./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1 +``` + +Result +``` +./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1 +launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1> +``` + +# Instructions for ```example_reduce_blockwise_two_call``` + +## Run ```example_reduce_blockwise_two_call``` +```bash +#arg1: verification (0=no, 1=yes( +#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg3: time kernel (0=no, 1=yes) +./bin/example_reduce_blockwise_two_call 1 2 1 + + +Result +``` +./bin/example_reduce_blockwise_two_call 1 2 1 +launch_and_time_kernel: grid_dim {204800, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +launch_and_time_kernel: grid_dim {6400, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 2.1791 ms, 771.42 GB/s, DeviceReduceBlockWise<256,M_C32_S1,K_C8_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> => DeviceReduceBlockWise<256,M_C256_S1,K_C1_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> +``` + diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp new file mode 100644 index 00000000000..e1e3afc58a6 --- /dev/null +++ b/example/12_reduce/reduce_blockwise.cpp @@ -0,0 +1,331 @@ +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_base.hpp" +#include "device_reduce_multiblock.hpp" +#include "host_common_util.hpp" +#include "host_reduction.hpp" + +#include "reduction_enums.hpp" +#include "reduction_operator_mapping.hpp" + +using namespace ck; +using namespace ck::tensor_operation::device; + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +constexpr int Rank = 4; +constexpr int NumReduceDim = 3; + +constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2; +constexpr bool PropagateNan = true; +constexpr bool OutputIndex = false; + +using ReduceOperation = typename reduce_binary_operator::opType; +using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; +using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + +using DeviceReduceInstance = DeviceReduceMultiBlock; + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class SimpleAppArgs +{ + private: + int option_index = 0; + + public: + std::vector inLengths = {16, 64, 32, 960}; + std::vector scales = {1.0f, 0.0f}; + + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " + "comparing with the host-based reduction" + << std::endl; + std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 2 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind])); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + return (0); + }; +}; + +int main(int argc, char* argv[]) +{ + using namespace ck::host_reduce; + + const std::vector reduceDims{0, 1, 2}; + const std::vector invariantDims{3}; + + SimpleAppArgs args; + + if(argc > 1) + { + if(args.processArgs(argc, argv) < 0) + return (-1); + }; + + constexpr bool op_support_indices = + (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || + ReduceOpId == ReduceTensorOp::AMAX); + + // if input is half type, no reason to use float for indiced reduction operation and must use + // float for non-indiced reduction operation for accuracy + constexpr bool invalid_reduce_1 = + std::is_same::value && + ((!op_support_indices && !std::is_same::value) || + (op_support_indices && !std::is_same::value)); + + // if input is float type, no reason to use double for indiced reduction operation + constexpr bool invalid_reduce_2 = + std::is_same::value && + (op_support_indices && !std::is_same::value); + + // indices option can only be used when it is really needed + constexpr bool invalid_reduce_3 = (!op_support_indices && OutputIndex); + + constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3); + + if constexpr(invalid_reduce) + std::cout << "Reduction setting is not supported, exiting!" << std::endl; + + Tensor in(args.inLengths); + + std::vector outLengths; + + if(invariantDims.empty()) + outLengths.push_back(1); + else + for(auto dim : invariantDims) + outLengths.push_back(args.inLengths[dim]); + + Tensor out_ref(outLengths); + Tensor out(outLengths); + Tensor out_indices_ref(outLengths); + Tensor out_indices(outLengths); + + auto inStrides = in.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + size_t invariant_total_length = out.mDesc.GetElementSize(); + size_t reduce_total_length = in.mDesc.GetElementSize() / invariant_total_length; + + float alpha = args.scales[0]; + float beta = args.scales[1]; + + std::size_t num_thread = 1; + + if(args.do_verification) + { + switch(args.init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) + out.mData[i] = out_ref.mData[i]; + }; + + // these buffers are usually provided by the user application + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); + + in_dev.ToDevice(in.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0; + + DeviceMem out_index_dev(indicesSizeInBytes); + + if(args.do_verification) + { + ReductionHost + hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); + + hostReduce.Run( + alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data()); + }; + + std::vector i_inLengths; + std::vector i_inStrides; + std::vector i_outLengths; + std::vector i_outStrides; + + i_inLengths.assign(args.inLengths.begin(), args.inLengths.end()); + i_inStrides.assign(inStrides.begin(), inStrides.end()); + i_outLengths.assign(outLengths.begin(), outLengths.end()); + i_outStrides.assign(outStrides.begin(), outStrides.end()); + + auto reduce = DeviceReduceInstance{}; + + auto argument_ptr = reduce.MakeArgumentPointer( + i_inLengths, + i_inStrides, + i_outLengths, + i_outStrides, + reduceDims, + alpha, + beta, + in_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + out_index_dev.GetDeviceBuffer(), + InElementwiseOperation{static_cast(reduce_total_length)}, + AccElementwiseOperation{static_cast(reduce_total_length)}); + + if(!reduce.IsSupportedArgument(argument_ptr.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + }; + + std::string reduce_name = reduce.GetTypeString(); + + auto invoker_ptr = reduce.MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel}); + + std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InDataType) + + invariant_total_length * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name + << std::endl; + + bool pass = true; + + if(args.do_verification) + { + out_dev.FromDevice(out.mData.data()); + pass = pass && ck::utils::check_err(out.mData, out_ref.mData); + + if(OutputIndex) + { + out_index_dev.FromDevice(out_indices.mData.data()); + pass = pass && ck::utils::check_err(out_indices.mData, out_indices_ref.mData); + }; + }; + + return (pass ? 0 : 1); +} diff --git a/example/12_reduce/reduce_blockwise_two_call.cpp b/example/12_reduce/reduce_blockwise_two_call.cpp new file mode 100644 index 00000000000..cd166c40fe6 --- /dev/null +++ b/example/12_reduce/reduce_blockwise_two_call.cpp @@ -0,0 +1,290 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_base.hpp" +#include "device_reduce_multiblock.hpp" +#include "host_common_util.hpp" +#include "host_reduction.hpp" + +#include "reduction_enums.hpp" +#include "reduction_operator_mapping.hpp" + +using namespace ck; +using namespace ck::tensor_operation::device; + +using InOutDataType = ck::half_t; +using InOutDataType = ck::half_t; +using AccDataType = float; + +constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2; +constexpr bool PropagateNan = true; +constexpr bool OutputIndex = false; + +using ReduceOperation = typename reduce_binary_operator::opType; +using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; +using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + +using PassThroughOp = tensor_operation::element_wise::UnaryIdentic; + +using DeviceReduceInstance_1 = DeviceReduceMultiBlock; + +using DeviceReduceInstance_2 = DeviceReduceMultiBlock; + +static bool do_verify; +static int init_method; +static float alpha; +static float beta; +static bool time_kernel; + +int main(int argc, char* argv[]) +{ + // used by the device reduction + const std::vector reduceDims_1 = {4}; + const std::vector invariantDims_1 = {0, 1, 2, 3}; + + const std::vector reduceDims_2 = {3}; + const std::vector invariantDims_2 = {0, 1, 2}; + + // used by the host reduction + const std::vector reduceDims = {3, 4}; + const std::vector invariantDims = {0, 1, 2}; + + const std::vector inLengths_1 = {64, 320, 80, 4, 128}; + + // input lengths of the second reduction, which is also the output lengths of the first + // reduction + const std::vector inLengths_2 = {64, 320, 80, 4}; + + const std::vector outLengths = {64, 320, 80}; + + using namespace ck::host_reduce; + + if(argc == 1) + { + do_verify = true; + init_method = 2; + time_kernel = true; + } + else if(argc == 4) + { + do_verify = static_cast(argv[1]); + init_method = atoi(argv[2]); + time_kernel = static_cast(atoi(argv[3])); + } + else + { + std::ostringstream ostr; + + ostr << "Wrong parameter! " << std::endl + << "Usage: " << argv[0] << "[verify 0/1] init_method time_kernel" << std::endl; + + throw std::runtime_error(ostr.str()); + }; + + alpha = 1.0f; + beta = 0.0f; + + Tensor in_1(inLengths_1); + + Tensor out_ref(outLengths); + Tensor in_2(inLengths_2); // also the output tensor of the first reduction + Tensor out(outLengths); + + auto inStrides_1 = in_1.mDesc.GetStrides(); + auto inStrides_2 = in_2.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + size_t invariant_total_length = out.mDesc.GetElementSize(); + size_t reduce_total_length = in_1.mDesc.GetElementSize() / invariant_total_length; + + std::size_t num_thread = 1; + + if(do_verify) + { + switch(init_method) + { + case 0: break; + case 1: + in_1.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in_1.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in_1.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, + num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) + out.mData[i] = out_ref.mData[i]; + }; + + DeviceMem in_1_dev(sizeof(InOutDataType) * in_1.mDesc.GetElementSpace()); + DeviceMem in_2_dev(sizeof(InOutDataType) * in_2.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(InOutDataType) * out.mDesc.GetElementSpace()); + + in_1_dev.ToDevice(in_1.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + if(do_verify) + { + ReductionHost + hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims); + + hostReduce.Run(alpha, in_1.mData.data(), beta, out_ref.mData.data(), nullptr); + }; + + std::vector i_inLengths_1; + std::vector i_inStrides_1; + std::vector i_inLengths_2; + std::vector i_inStrides_2; + std::vector i_outLengths; + std::vector i_outStrides; + + i_inLengths_1.assign(inLengths_1.begin(), inLengths_1.end()); + i_inStrides_1.assign(inStrides_1.begin(), inStrides_1.end()); + i_inLengths_2.assign(inLengths_2.begin(), inLengths_2.end()); + i_inStrides_2.assign(inStrides_2.begin(), inStrides_2.end()); + i_outLengths.assign(outLengths.begin(), outLengths.end()); + i_outStrides.assign(outStrides.begin(), outStrides.end()); + + auto reduce_1 = DeviceReduceInstance_1{}; + + auto argument_ptr_1 = reduce_1.MakeArgumentPointer( + i_inLengths_1, + i_inStrides_1, + i_inLengths_2, + i_inStrides_2, + reduceDims_1, + 1.0f, + 0.0f, + in_1_dev.GetDeviceBuffer(), + nullptr, + in_2_dev.GetDeviceBuffer(), + nullptr, + InElementwiseOperation{static_cast(reduce_total_length)}, + PassThroughOp{}); + + if(!reduce_1.IsSupportedArgument(argument_ptr_1.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + }; + + auto invoker_ptr_1 = reduce_1.MakeInvokerPointer(); + + auto reduce_2 = DeviceReduceInstance_2{}; + + auto argument_ptr_2 = reduce_2.MakeArgumentPointer( + i_inLengths_2, + i_inStrides_2, + i_outLengths, + i_outStrides, + reduceDims_2, + alpha, + beta, + in_2_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + nullptr, + PassThroughOp{}, + AccElementwiseOperation{static_cast(reduce_total_length)}); + + if(!reduce_2.IsSupportedArgument(argument_ptr_2.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + }; + + auto invoker_ptr_2 = reduce_2.MakeInvokerPointer(); + + float avg_time_1 = invoker_ptr_1->Run(argument_ptr_1.get(), StreamConfig{nullptr, time_kernel}); + float avg_time_2 = invoker_ptr_2->Run(argument_ptr_2.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) + + invariant_total_length * sizeof(InOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / (avg_time_1 + avg_time_2); + + std::cout << "Perf: " << avg_time_1 + avg_time_2 << " ms, " << gb_per_sec << " GB/s, " + << reduce_1.GetTypeString() << " => " << reduce_2.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verify) + { + out_dev.FromDevice(out.mData.data()); + pass = pass && ck::utils::check_err(out.mData, out_ref.mData); + }; + + return (pass ? 0 : 1); +} diff --git a/example/13_pool2d_fwd/CMakeLists.txt b/example/13_pool2d_fwd/CMakeLists.txt new file mode 100644 index 00000000000..1fdeb4c5858 --- /dev/null +++ b/example/13_pool2d_fwd/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_pool2d_fwd pool2d_fwd.cpp) diff --git a/example/13_pool2d_fwd/README.md b/example/13_pool2d_fwd/README.md new file mode 100644 index 00000000000..2314cfd6701 --- /dev/null +++ b/example/13_pool2d_fwd/README.md @@ -0,0 +1,20 @@ +# Instructions for ```example_pool2d_fwd``` Example + +## Run ```example_pool2d_fwd``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx +./bin/example_pool2d_fwd 1 1 1 +``` + +Result +``` +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192} +launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.397436 ms, 1.44252 TFlops, 783.713 GB/s +``` diff --git a/example/13_pool2d_fwd/pool2d_fwd.cpp b/example/13_pool2d_fwd/pool2d_fwd.cpp new file mode 100644 index 00000000000..662a48500f5 --- /dev/null +++ b/example/13_pool2d_fwd/pool2d_fwd.cpp @@ -0,0 +1,329 @@ +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_reduce_util.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "reduction_operator.hpp" +#include "device_pool2d_fwd_nhwc_nhwc.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +using IndexDataType = int32_t; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using OutLayout = ck::tensor_layout::convolution::NHWC; + +#if 1 +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +#else +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +#endif + +static constexpr bool OutputIndex = false; +static constexpr bool PropagateNan = false; + +using DevicePoolFwdInstance = + ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C< + InDataType, // InDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + ReduceOpId, + OutputIndex, + 64, // BlockSize + 64, // ReduceMThreadClusterSize + 1, // ReduceKThreadClusterSize + 4, // ReduceMThreadSliceSize + 1, // ReduceKThreadSliceSize + 4>; // InSrcOutDstVectorSize + +template +static void pool_host_verify(const Tensor& in, + Tensor& out, + Tensor& out_indices, + const std::array& window_spatial_lengths, + const std::array& window_strides, + const std::array& in_left_pads, + const std::array& /*in_right_pads*/) +{ + using namespace ck::host_reduce; + + const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1]; + + const auto PreUnaryOp = PreUnaryOpFn(divider); + const auto PosUnaryOp = PosUnaryOpFn(divider); + + if constexpr(!OutputIndex) + { + auto opReduce = ReduceOpFn(); + + auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { + auto accuVal = ReduceOpZeroVal(); + + for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) + { + ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0]; + for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x) + { + ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1]; + if(hi >= 0 && hi < static_cast(in.mDesc.GetLengths()[2]) && + wi >= 0 && wi < static_cast(in.mDesc.GetLengths()[3])) + { + AccDataType currVal = static_cast(in(n, c, hi, wi)); + + PreUnaryOp(currVal); + + binop_with_nan_check(opReduce, accuVal, currVal); + } + } + } + + PosUnaryOp(accuVal); + + out(n, c, ho, wo) = accuVal; + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + auto opReduce = ReduceOpFn2(); + + auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { + auto accuVal = ReduceOpZeroVal(); + IndexDataType accuIndex = 0; + + for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) + { + ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0]; + for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x) + { + ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + AccDataType currVal = static_cast(in(n, c, hi, wi)); + IndexDataType currIndex = y * window_spatial_lengths[1] + x; + + PreUnaryOp(currVal); + + binop_with_index_and_nan_check( + opReduce, accuVal, currVal, accuIndex, currIndex); + } + } + } + + PosUnaryOp(accuVal); + + out(n, c, ho, wo) = accuVal; + out_indices(n, c, ho, wo) = accuIndex; + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + }; +} + +int main(int argc, char* argv[]) +{ + using namespace ck::host_reduce; + + bool do_verification; + int init_method; + bool time_kernel; + + // Pool shape + ck::index_t N = 128; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 1) + { + do_verification = true; + init_method = 1; + time_kernel = true; + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + } + else if(argc == 16) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + + N = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + window_stride_h = std::stoi(argv[10]); + window_stride_w = std::stoi(argv[11]); + in_left_pad_h = std::stoi(argv[12]); + in_left_pad_w = std::stoi(argv[13]); + in_right_pad_h = std::stoi(argv[14]); + in_right_pad_w = std::stoi(argv[15]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; + + const std::array window_spatial_lengths{{Y, X}}; + const std::array window_strides{{window_stride_h, window_stride_w}}; + const std::array input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::array input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + // tensor layout + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + Tensor out_indices_n_c_ho_wo_host( + f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + Tensor out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + Tensor out_indices_n_c_ho_wo_device( + f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); break; + case 2: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_c_ho_wo_device.mDesc.GetElementSpace()); + DeviceMem out_indices_device_buf(sizeof(IndexDataType) * + out_indices_n_c_ho_wo_device.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + + auto pool = DevicePoolFwdInstance{}; + auto invoker_ptr = pool.MakeInvokerPointer(); + auto argument_ptr = pool.MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + N, + C, + std::array{{Hi, Wi}}, + std::array{{Y, X}}, + std::array{{Ho, Wo}}, + window_strides, + input_left_pads, + input_right_pads); + + if(!pool.IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error("wrong! device_op with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * C * Ho * Wo * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(OutDataType) * (N * C * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + pool_host_verify(in_n_c_hi_wi, + out_n_c_ho_wo_host, + out_indices_n_c_ho_wo_host, + window_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + + out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData); + + if constexpr(OutputIndex) + { + out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_indices_n_c_ho_wo_device.mData, + out_indices_n_c_ho_wo_host.mData); + }; + } + + return (pass ? 0 : 1); +} diff --git a/example/14_gemm_xdl_requant_relu_requant/CMakeLists.txt b/example/14_gemm_xdl_requant_relu_requant/CMakeLists.txt new file mode 100644 index 00000000000..0f5b8e1bc72 --- /dev/null +++ b/example/14_gemm_xdl_requant_relu_requant/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_xdl_requant_relu_requant_int8 gemm_xdl_requant_relu_requant_int8.cpp) \ No newline at end of file diff --git a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp new file mode 100644 index 00000000000..9f6408a84ae --- /dev/null +++ b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp @@ -0,0 +1,251 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +struct RequantReluRequant +{ + // FIXME: We just need one scale for Relu / Leaky Relu / PRelu + RequantReluRequant(float scaleGemm, float scaleRelu) + : scaleGemm_(scaleGemm), scaleRelu_(scaleRelu) + { + } + + __host__ __device__ constexpr void operator()(float& y, const float& x) const + { + float gemm_requant = scaleGemm_ * x; + float relu = gemm_requant > 0 ? gemm_requant : 0; + float relu_requant = scaleRelu_ * relu; + y = relu_requant > 127 ? 127 : relu_requant < -128 ? -128 : relu_requant; + } + + float scaleGemm_; + float scaleRelu_; +}; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle< + ALayout, // typename ALayout, + BLayout, // typename BLayout, + CLayout, // typename CLayout, + ADataType, // typename ADataType, + BDataType, // typename BDataType, + CDataType, // typename CDataType, + AccDataType, // typename GemmAccDataType, + CShuffleDataType, // typename CShuffleDataType, + PassThrough, // typename AElementwiseOperation, + PassThrough, // typename BElementwiseOperation, + RequantReluRequant, // typename CElementwiseOperation, + GemmDefault, // GemmSpecialization GemmSpec, + 1, // index_t NumGemmKPrefetchStage, + 256, // index_t BlockSize, + 256, // index_t MPerBlock, + 128, // index_t NPerBlock, + 64, // index_t KPerBlock, + 16, // index_t AK1, + 16, // index_t BK1, + 32, // index_t MPerXDL, + 32, // index_t NPerXDL, + 4, // index_t MXdlPerWave, + 2, // index_t NXdlPerWave, + S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder, + 2, // index_t ABlockTransferSrcVectorDim, + 16, // index_t ABlockTransferSrcScalarPerVector, + 16, // index_t ABlockTransferDstScalarPerVector_AK1, + 1, // bool ABlockLdsExtraM, + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder, + 2, // index_t BBlockTransferSrcVectorDim, + 8, // index_t BBlockTransferSrcScalarPerVector, + 8, // index_t BBlockTransferDstScalarPerVector_BK1, + 1, // bool BBlockLdsExtraN, + 1, // index_t CShuffleMXdlPerWavePerShuffle, + 1, // index_t CShuffleNXdlPerWavePerShuffle, + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + float scale_gemm = 0.03; + float scale_relu = 1; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto c_element_op = RequantReluRequant{scale_gemm, scale_relu}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt new file mode 100644 index 00000000000..a8cac069306 --- /dev/null +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) diff --git a/example/15_grouped_gemm/README.md b/example/15_grouped_gemm/README.md new file mode 100644 index 00000000000..c83b23e08cc --- /dev/null +++ b/example/15_grouped_gemm/README.md @@ -0,0 +1,25 @@ +# Instructions for ```example_grouped_gemm_xdl``` + +## Run ```example_grouped_gemm_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +./bin/example_grouped_gemm_xdl_fp16 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +gemm[0] a_m_k: dim 2, lengths {256, 64}, strides {64, 1} b_k_n: dim 2, lengths {64, 128}, strides {1, 64} c_m_n: dim 2, lengths {256, 128}, strides {128, 1} +gemm[1] a_m_k: dim 2, lengths {512, 128}, strides {128, 1} b_k_n: dim 2, lengths {128, 256}, strides {1, 128} c_m_n: dim 2, lengths {512, 256}, strides {256, 1} +gemm[2] a_m_k: dim 2, lengths {768, 192}, strides {192, 1} b_k_n: dim 2, lengths {192, 384}, strides {1, 192} c_m_n: dim 2, lengths {768, 384}, strides {384, 1} +gemm[3] a_m_k: dim 2, lengths {1024, 256}, strides {256, 1} b_k_n: dim 2, lengths {256, 512}, strides {1, 256} c_m_n: dim 2, lengths {1024, 512}, strides {512, 1} +group: 0 arg.a_grid_desc_k0_m_k1_{8, 256, 8}, arg.b_grid_desc_k0_n_k1_{8, 128, 8}, arg.c_grid_desc_m_n_{ 256, 128} +group: 1 arg.a_grid_desc_k0_m_k1_{16, 512, 8}, arg.b_grid_desc_k0_n_k1_{16, 256, 8}, arg.c_grid_desc_m_n_{ 512, 256} +group: 2 arg.a_grid_desc_k0_m_k1_{24, 768, 8}, arg.b_grid_desc_k0_n_k1_{24, 384, 8}, arg.c_grid_desc_m_n_{ 768, 384} +group: 3 arg.a_grid_desc_k0_m_k1_{32, 1024, 8}, arg.b_grid_desc_k0_n_k1_{32, 512, 8}, arg.c_grid_desc_m_n_{ 1024, 512} +launch_and_time_kernel: grid_dim {30, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 0.037887 ms, 11.0706 TFlops, 90.8132 GB/s, DeviceGroupedGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> +``` diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp new file mode 100644 index 00000000000..8c3491c8c9f --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -0,0 +1,236 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_grouped_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +// static constexpr auto GemmMNPadding = +// ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + int group_count = 4; + + // GEMM shape + std::vector gemm_shapes; + std::vector p_a, p_b; + std::vector p_c; + + gemm_shapes.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + int M = 256 + 256 * i; + int N = 128 + 128 * i; + int K = 64 + 64 * i; + + gemm_shapes.push_back({M, N, K, K, K, N}); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + std::vector> a_tensors; + ; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); + + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace())); + b_tensors_device.emplace_back( + std::make_unique(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpace())); + c_tensors_device.emplace_back(std::make_unique( + sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSpace())); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = + gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + if(do_verification) + { + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData); + } + } + + return pass ? 0 : 1; +} diff --git a/example/16_gemm_reduce/CMakeLists.txt b/example/16_gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000000..5441247a56b --- /dev/null +++ b/example/16_gemm_reduce/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp) +add_example_executable(example_gemm_reduce_xdl_sum_squaresum_fp16 gemm_reduce_xdl_sum_squaresum_fp16.cpp) diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp new file mode 100644 index 00000000000..4d837c4675c --- /dev/null +++ b/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp @@ -0,0 +1,249 @@ +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" +#include "element_wise_reduce_operation.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using ReduceAccDataType = F32; +using DDataType = F64; +using DPtrsGlobal = ck::Tuple; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using DsReduceOp = ck::Tuple>; +using DsElementOp = ck::Tuple< + ck::tensor_operation::element_wise::UnaryIdentic>; +using DGlobalMemOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "d_m: " << d_m_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto ds_element_op = DsElementOp{}; + auto p_ds_global = ck::make_tuple(static_cast(d_device_buf.GetDeviceBuffer())); + + // do GEMM + auto gemm = DeviceGemmReduceInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + p_ds_global, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + ds_element_op, + ds_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + // init D + d_device_buf.SetValue(ck::NumericLimits::Lowest()); + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + d_device_buf.FromDevice(d_m_device_result.mData.data()); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + auto d_reduce_op = DsReduceOp{}[ck::Number<0>{}]; + + for(int m = 0; m < M; ++m) + { + ReduceAccDataType d_acc = d_reduce_op.GetReductionZeroVal(); + + for(int n = 0; n < N; ++n) + d_reduce_op(d_acc, c_m_n_host_result(m, n)); + + d_m_host_result(m) = d_acc; + } + + pass = ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "Error: Incorrect results c") && + ck::utils::check_err(d_m_device_result.mData, + d_m_host_result.mData, + "Error: Incorrect results d", + 1e-3, + 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_sum_squaresum_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_sum_squaresum_fp16.cpp new file mode 100644 index 00000000000..dff9c02f449 --- /dev/null +++ b/example/16_gemm_reduce/gemm_reduce_xdl_sum_squaresum_fp16.cpp @@ -0,0 +1,285 @@ +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" +#include "reduction_operator.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using ReduceAccDataType = F32; +using DDataType = F32; +using DPtrsGlobal = ck::Tuple; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using D0ReduceOp = ck::reduce::Add; +using D1ReduceOp = ck::reduce::Add; +using DxsReduceOp = ck::Tuple; + +using UnaryIdenticElementOp = + ck::tensor_operation::element_wise::UnaryIdentic; +using UnarySquareElementOp = + ck::tensor_operation::element_wise::UnarySquare; +using DxsInElementOp = ck::Tuple; +using DxsOutElementOp = ck::Tuple; + +using DGlobalMemOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d0_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d1_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d0_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d1_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl; + std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); + DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer())); + + // do GEMM + auto gemm = DeviceGemmReduceInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + dxs_global, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + DxsInElementOp{}, + DxsOutElementOp{}); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + // init DO, D1 to 0 + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); + + // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result + // will not be correct. need to set time_kernel = false for correctness test + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + d0_device_buf.FromDevice(d0_m_device_result.mData.data()); + d1_device_buf.FromDevice(d1_m_device_result.mData.data()); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + auto d0_reduce_op = D0ReduceOp{}; + auto d1_reduce_op = D1ReduceOp{}; + + for(int m = 0; m < M; ++m) + { + float d0_acc = d0_reduce_op.GetReductionZeroVal(); + float d1_acc = d1_reduce_op.GetReductionZeroVal(); + + for(int n = 0; n < N; ++n) + { + float c_val = ck::type_convert(c_m_n_host_result(m, n)); + float d0_val = 0; + float d1_val = 0; + + UnaryIdenticElementOp{}(d0_val, c_val); + UnarySquareElementOp{}(d1_val, c_val); + d0_reduce_op(d0_acc, d0_val); + d1_reduce_op(d1_acc, d1_val); + } + + d0_m_host_result(m) = ck::type_convert(d0_acc); + d1_m_host_result(m) = ck::type_convert(d1_acc); + } + + pass = ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "Error: Incorrect results c") && + ck::utils::check_err(d0_m_device_result.mData, + d0_m_host_result.mData, + "Error: Incorrect results d0", + 1e-4, + 1e-5) && + ck::utils::check_err(d1_m_device_result.mData, + d1_m_host_result.mData, + "Error: Incorrect results d1", + 1e-3, + 1e-5); + } + + return pass ? 0 : 1; +} diff --git a/example/17_convnd_bwd_data_xdl/CMakeLists.txt b/example/17_convnd_bwd_data_xdl/CMakeLists.txt new file mode 100644 index 00000000000..963f3117034 --- /dev/null +++ b/example/17_convnd_bwd_data_xdl/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp) +target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_util) diff --git a/example/17_convnd_bwd_data_xdl/README.md b/example/17_convnd_bwd_data_xdl/README.md new file mode 100644 index 00000000000..b5c8281ed8a --- /dev/null +++ b/example/17_convnd_bwd_data_xdl/README.md @@ -0,0 +1,47 @@ +# Instructions for ```example_convnd_bwd_data_xdl``` + +## Run ```example_example_convnd_bwd_data_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4: num_dim_spatial(1|2|3) +#arg5 to ...: N, K, C, [Z,] [Y,] X, [Di,] [Hi,] Wi, S[z,] [Sy,] Sx, [Dz,] [Dy,] Dx, [LeftPz,] [LeftPy,] LeftPx, [RightPy,] [RightPy,] RightPx +./bin/example_convnd_bwd_data_xdl 0 1 5 +``` + +Result +``` +in_n_c_hi_wi: dim 4, lengths {128, 128, 71, 71}, strides {645248, 1, 9088, 128} +wei_k_c_y_x: dim 4, lengths {256, 128, 3, 3}, strides {1152, 1, 384, 128} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_container_{128, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{128, 128, 8} +arg.c_grid_desc_m_n_container_{ 175232, 128} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{64, 128, 8} +arg.c_grid_desc_m_n_container_{ 175232, 128} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{64, 128, 8} +arg.c_grid_desc_m_n_container_{ 175232, 128} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{32, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{32, 128, 8} +arg.c_grid_desc_m_n_container_{ 175232, 128} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Perf: 1.40031 ms, 69.8734 TFlops, 179.037 GB/s +``` diff --git a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp new file mode 100644 index 00000000000..ff2cfac1fa7 --- /dev/null +++ b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp @@ -0,0 +1,354 @@ +#include +#include +#include +#include +#include +#include + +#include "config.hpp" +#include "conv_util.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "element_wise_operation.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "reference_conv_bwd_data.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +using DeviceConvBwdDataBasePtr = + ck::tensor_operation::device::DeviceConvBwdDataPtr; + +template +using DeviceConvNDBwdDataInstance = ck::tensor_operation::device:: + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdDefault, // ConvolutionBackwardDataSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 2, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, + 1>; // GemmCThreadTransferDstScalarPerVector + +template +using ReferenceConvBwdDataInstance = + ck::tensor_operation::host::ReferenceConvBwdData; + +void print_use_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + ck::utils::conv::ConvParams params; + int arg_idx = 5; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + + ck::utils::conv::ConvParams params; + params.C_ = 128; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc > 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + // check args number + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 5; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(1); + } + + params = parse_conv_params(num_dim_spatial, argv); + } + else if(argc != 1) + { + print_use_msg(); + exit(1); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor in_n_c_hi_wi_host_result( + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor in_n_c_hi_wi_device_result( + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor wei_k_c_y_x( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor out_n_k_ho_wo( + ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3{-0.2, 0.2}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.2, 0.2}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * + in_n_c_hi_wi_device_result.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + // reset input to zero + in_device_buf.SetZero(); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv->IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = ck::utils::conv::get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::get_btype( + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto verify_f = [&](const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result, + wei_k_c_y_x, + out_n_k_ho_wo, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); + + return ck::utils::check_err(in_n_c_hi_wi_device_result.mData, + in_n_c_hi_wi_host_result.mData) + ? 0 + : 1; + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvBwdDataInstance<3>(); + verify_f(ref_conv); + break; + } + case 2: { + auto ref_conv = ReferenceConvBwdDataInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvBwdDataInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } + return 0; +} diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000000..99fc0043d28 --- /dev/null +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp) + diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp new file mode 100644 index 00000000000..df63053c801 --- /dev/null +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -0,0 +1,298 @@ +#include +#include +#include +#include +#include +#include +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "reference_batched_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using ReduceAccDataType = F32; +using DDataType = F32; +using DPtrsGlobal = ck::Tuple; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using D0ReduceOp = ck::reduce::Add; +using D1ReduceOp = ck::reduce::Add; +using DxsReduceOp = ck::Tuple; + +using UnaryIdenticElementOp = + ck::tensor_operation::element_wise::UnaryIdentic; +using UnarySquareElementOp = + ck::tensor_operation::element_wise::UnarySquare; +using DxsInElementOp = ck::Tuple; +using DxsOutElementOp = ck::Tuple; + +using DGlobalMemOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: + ReferenceBatchedGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 2048; + ck::index_t N = 1920; + ck::index_t K = 2048; + + ck::index_t StrideA = 2048; + ck::index_t StrideB = 2048; + ck::index_t StrideC = 1920; + + ck::index_t BatchCount = 4; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + + BatchCount = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount\n"); + exit(0); + } + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({row * stride, stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({col * stride, 1, stride})); + } + }; + + Tensor a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); + Tensor b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); + + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_host_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + Tensor d1_g_m_host_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_device_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + Tensor d1_g_m_device_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + std::cout << "d0_g_m: " << d0_g_m_host_result.mDesc << std::endl; + std::cout << "d1_g_m: " << d1_g_m_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace()); + DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer())); + + // do GEMM + auto batched_gemm = DeviceBatchedGemmReduceInstance{}; + auto invoker = batched_gemm.MakeInvoker(); + auto argument = + batched_gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + dxs_global, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + DxsInElementOp{}, + DxsOutElementOp{}, + BatchCount); + + if(!batched_gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + // init DO, D1 to 0 + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); + + // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result + // will not be correct. need to set time_kernel = false for correctness test + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * BatchCount * M * N * K; + std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K + + sizeof(BDataType) * BatchCount * K * N + + sizeof(CDataType) * BatchCount * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << batched_gemm.GetTypeString() << std::endl; + + bool pass = true; + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + d0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); + d1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + auto d0_reduce_op = D0ReduceOp{}; + auto d1_reduce_op = D1ReduceOp{}; + + for(int batch = 0; batch < BatchCount; ++batch) + { + for(int m = 0; m < M; ++m) + { + float d0_acc = d0_reduce_op.GetReductionZeroVal(); + float d1_acc = d1_reduce_op.GetReductionZeroVal(); + + for(int n = 0; n < N; ++n) + { + float c_val = ck::type_convert(c_g_m_n_host_result(batch, m, n)); + float d0_val = 0; + float d1_val = 0; + + UnaryIdenticElementOp{}(d0_val, c_val); + UnarySquareElementOp{}(d1_val, c_val); + d0_reduce_op(d0_acc, d0_val); + d1_reduce_op(d1_acc, d1_val); + } + + d0_g_m_host_result(batch, m) = ck::type_convert(d0_acc); + d1_g_m_host_result(batch, m) = ck::type_convert(d1_acc); + } + } + + pass = ck::utils::check_err(c_g_m_n_host_result.mData, + c_g_m_n_device_result.mData, + "Error: Incorrect results c") && + ck::utils::check_err(d0_g_m_device_result.mData, + d0_g_m_host_result.mData, + "Error: Incorrect results! D0", + 1e-4, + 1e-5) && + ck::utils::check_err(d1_g_m_device_result.mData, + d1_g_m_host_result.mData, + "Error: Incorrect results! D1", + 1e-3, + 1e-5); + } + + return pass ? 0 : 1; +} diff --git a/example/19_binary_elementwise/CMakeLists.txt b/example/19_binary_elementwise/CMakeLists.txt new file mode 100644 index 00000000000..6c95b2e55e8 --- /dev/null +++ b/example/19_binary_elementwise/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_broadcast_add_2d broadcast_add_2d.cpp) +add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp) +add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp) \ No newline at end of file diff --git a/example/19_binary_elementwise/broadcast_add_2d.cpp b/example/19_binary_elementwise/broadcast_add_2d.cpp new file mode 100644 index 00000000000..2a3ef421ff0 --- /dev/null +++ b/example/19_binary_elementwise/broadcast_add_2d.cpp @@ -0,0 +1,130 @@ +#include +#include +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" + +#include "device_tensor.hpp" +#include "binary_element_wise_operation.hpp" +#include "device_binary_elementwise.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::binary_element_wise::Add; + +using DeviceElementwiseAddInstance = ck::tensor_operation::device:: + DeviceBinaryElementwise; + +template +void host_broadcast2D( + HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, int N, Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + ComputeDataType Amn = static_cast(A(m, n)); + ComputeDataType Cmn = 0; + if constexpr(broadcastDim == 0) + { + ComputeDataType Bn = static_cast(B(n)); + functor(Cmn, Amn, Bn); + } + else + { + ComputeDataType Bm = static_cast(B(m)); + functor(Cmn, Amn, Bm); + } + C(m, n) = static_cast(Cmn); + } + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t Stride = 1024; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + }; + + Tensor a_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + Tensor b_n(f_host_tensor_descriptor1d(N, 1)); + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + + a_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpace()); + DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); + + a_m_n_device_buf.ToDevice(a_m_n.mData.data()); + b_n_device_buf.ToDevice(b_n.mData.data()); + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer(a_m_n_device_buf.GetDeviceBuffer(), + b_n_device_buf.GetDeviceBuffer(), + c_m_n_device_buf.GetDeviceBuffer(), + {M, N}, + {Stride, 1}, + {0, 1}, // broadcast in first dimension + {Stride, 1}, + Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise_2D instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); + Tensor host_c_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + + host_broadcast2D, + Tensor, + Tensor, + EltwiseComputeDataType, + Add, + 0>(host_c_m_n, a_m_n, b_n, M, N, Add{}); + + pass &= ck::utils::check_err( + c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results d1", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/19_binary_elementwise/elementwise_add_1d.cpp b/example/19_binary_elementwise/elementwise_add_1d.cpp new file mode 100644 index 00000000000..455ff24c31b --- /dev/null +++ b/example/19_binary_elementwise/elementwise_add_1d.cpp @@ -0,0 +1,110 @@ +#include +#include +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" + +#include "device_tensor.hpp" +#include "binary_element_wise_operation.hpp" +#include "device_binary_elementwise.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::binary_element_wise::Add; + +using DeviceElementwiseAddInstance = ck::tensor_operation::device:: + DeviceBinaryElementwise; + +template +void host_elementwise1D( + HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(int m = 0; m < M; ++m) + { + ComputeDataType Am = static_cast(A(m)); + ComputeDataType Bm = static_cast(B(m)); + ComputeDataType Cm = 0; + functor(Cm, Am, Bm); + C(m) = static_cast(Cm); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + ck::index_t M = 1024; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + Tensor a_m(f_host_tensor_descriptor1d(M, 1)); + Tensor b_m(f_host_tensor_descriptor1d(M, 1)); + Tensor c_m(f_host_tensor_descriptor1d(M, 1)); + + a_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace()); + DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpace()); + DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpace()); + + a_m_device_buf.ToDevice(a_m.mData.data()); + b_m_device_buf.ToDevice(b_m.mData.data()); + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(), + b_m_device_buf.GetDeviceBuffer(), + c_m_device_buf.GetDeviceBuffer(), + {M}, + {1}, + {1}, + {1}, + Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise_2D instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_device_buf.FromDevice(c_m.mData.data()); + Tensor host_c_m(f_host_tensor_descriptor1d(M, 1)); + + host_elementwise1D, + Tensor, + Tensor, + EltwiseComputeDataType, + Add>(host_c_m, a_m, b_m, M, Add{}); + + pass &= ck::utils::check_err( + c_m.mData, host_c_m.mData, "Error: Incorrect results d1", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/19_binary_elementwise/elementwise_add_4d.cpp b/example/19_binary_elementwise/elementwise_add_4d.cpp new file mode 100644 index 00000000000..937a6c8c1dc --- /dev/null +++ b/example/19_binary_elementwise/elementwise_add_4d.cpp @@ -0,0 +1,112 @@ +#include +#include +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" + +#include "device_tensor.hpp" +#include "binary_element_wise_operation.hpp" +#include "device_binary_elementwise.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::binary_element_wise::Add; + +using DeviceElementwiseAddInstance = ck::tensor_operation::device:: + DeviceBinaryElementwise; + +template +void host_elementwise4D(HostTensorC& C, + const HostTensorA& A, + const HostTensorB& B, + const std::vector& shape, + Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(std::size_t n = 0; n < shape[0]; ++n) + for(std::size_t c = 0; c < shape[1]; ++c) + for(std::size_t h = 0; h < shape[2]; ++h) + for(std::size_t w = 0; w < shape[3]; ++w) + { + ComputeDataType a_val = static_cast(A(n, c, h, w)); + ComputeDataType b_val = static_cast(B(n, c, h, w)); + ComputeDataType c_val = 0; + functor(c_val, a_val, b_val); + C(n, c, h, w) = static_cast(c_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + std::vector nchw = {4, 16, 32, 32}; + + Tensor a(nchw); + Tensor b(nchw); + Tensor c(nchw); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ABDataType) * a.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(ABDataType) * b.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a.mData.data()); + b_device_buf.ToDevice(b.mData.data()); + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + std::vector{nchw.begin(), nchw.end()}, + std::vector{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()}, + std::vector{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()}, + std::vector{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()}, + Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise_2D instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_device_buf.FromDevice(c.mData.data()); + Tensor host_c(nchw); + + host_elementwise4D, + Tensor, + Tensor, + EltwiseComputeDataType, + Add>(host_c, a, b, nchw, Add{}); + + pass &= + ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results d1", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/20_convnd_bwd_weight_xdl/CMakeLists.txt b/example/20_convnd_bwd_weight_xdl/CMakeLists.txt new file mode 100644 index 00000000000..1a644d94794 --- /dev/null +++ b/example/20_convnd_bwd_weight_xdl/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp) +target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util) \ No newline at end of file diff --git a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp new file mode 100644 index 00000000000..0fc976c34a6 --- /dev/null +++ b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp @@ -0,0 +1,425 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "conv_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "element_wise_operation.hpp" +#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "reference_conv_backward_weight.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +using DeviceConvBwdWeightBasePtr = + ck::tensor_operation::device::DeviceConvBwdWeightPtr; + +// clang-format off +template +using DeviceConvndBwdWeightInstance = ck::tensor_operation::device:: + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +template +using ReferenceConvBwdWeightInstance = + ck::tensor_operation::host::ReferenceConvBwdWeight; + +void print_use_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4: is show log (0=no, 1=yes)\n" + << "arg5: split-k \n" + << "arg6: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + ck::utils::conv::ConvParams params; + int arg_idx = 7; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + int do_log = 0; + int split_k = 1; + + ck::utils::conv::ConvParams params; + params.C_ = 128; + + if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + } + else if(argc > 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + num_dim_spatial = std::stoi(argv[6]); + // check args number + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 7; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(1); + } + + params = parse_conv_params(num_dim_spatial, argv); + } + else if(argc != 1) + { + print_use_msg(); + exit(1); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor in_n_c_hi_wi( + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_host_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_device_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor out_n_k_ho_wo( + ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * + wei_k_c_y_x_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + // reset input to zero + wei_device_buf.SetZero(); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); + + // alloc work space + size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get()); + float ave_time = 0.f; + if(std::is_same::value && split_k > 1) + { + DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); + wei_work_space_device_buf.SetZero(); + argument = conv->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_work_space_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); + + if(!conv->IsSupportedArgument(argument.get())) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + + ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + } + else + { + if(!conv->IsSupportedArgument(argument.get())) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + } + + std::size_t flop = ck::utils::conv::get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::get_btype( + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto verify_f = [&](const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x_host_result, + out_n_k_ho_wo, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") + << std::endl; + } + + return ck::utils::check_err(wei_k_c_y_x_device_result.mData, + wei_k_c_y_x_host_result.mData) + ? 0 + : 1; + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvBwdWeightInstance<3>(); + verify_f(ref_conv); + break; + } + case 2: { + auto ref_conv = ReferenceConvBwdWeightInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvBwdWeightInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt new file mode 100644 index 00000000000..e595ca23333 --- /dev/null +++ b/example/CMakeLists.txt @@ -0,0 +1,56 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/include/ck + ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility + ${PROJECT_SOURCE_DIR}/include/ck/tensor_description + ${PROJECT_SOURCE_DIR}/include/ck/tensor + ${PROJECT_SOURCE_DIR}/include/ck/problem_transform + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/device + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/grid + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/block + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/warp + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/thread + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element + ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor + ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu + ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu + ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility + ${PROJECT_SOURCE_DIR}/external/include/half +) + +add_custom_target(examples) + +function(add_example_executable EXAMPLE_NAME FILE_NAME) + message("adding example ${EXAMPLE_NAME}") + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor) + add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) + add_dependencies(examples ${EXAMPLE_NAME}) + add_dependencies(check ${EXAMPLE_NAME}) +endfunction(add_example_executable EXAMPLE_NAME) + +function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) + message("adding example ${EXAMPLE_NAME}") + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor) + add_dependencies(examples ${EXAMPLE_NAME}) +endfunction(add_example_executable_no_testing EXAMPLE_NAME) + +add_subdirectory(01_gemm) +add_subdirectory(02_gemm_alpha_beta) +add_subdirectory(03_gemm_bias_relu) +add_subdirectory(04_gemm_bias_relu_add) +add_subdirectory(06_conv2d_fwd_bias_relu) +add_subdirectory(07_conv2d_fwd_bias_relu_add) +add_subdirectory(09_convnd_fwd) +add_subdirectory(10_conv2d_bwd_data) +add_subdirectory(11_conv2d_bwd_weight) +add_subdirectory(12_reduce) +add_subdirectory(13_pool2d_fwd) +add_subdirectory(14_gemm_xdl_requant_relu_requant) +add_subdirectory(17_convnd_bwd_data_xdl) +add_subdirectory(15_grouped_gemm) +add_subdirectory(16_gemm_reduce) +add_subdirectory(18_batched_gemm_reduce) +add_subdirectory(19_binary_elementwise) +add_subdirectory(20_convnd_bwd_weight_xdl) diff --git a/external/include/half/half.hpp b/external/include/half/half.hpp new file mode 100644 index 00000000000..25f543881f6 --- /dev/null +++ b/external/include/half/half.hpp @@ -0,0 +1,5670 @@ +// half - IEEE 754-based half-precision floating-point library. +// +// Copyright (c) 2012-2019 Christian Rau +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +// associated documentation +// files (the "Software"), to deal in the Software without restriction, including without limitation +// the rights to use, copy, +// modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit +// persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT +// NOT LIMITED TO THE +// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF +// CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Version 2.1.0 + +/// \file +/// Main header file for half-precision functionality. + +#ifndef HALF_HALF_HPP +#define HALF_HALF_HPP + +#define HALF_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) + +#if defined(__INTEL_COMPILER) +#define HALF_ICC_VERSION __INTEL_COMPILER +#elif defined(__ICC) +#define HALF_ICC_VERSION __ICC +#elif defined(__ICL) +#define HALF_ICC_VERSION __ICL +#else +#define HALF_ICC_VERSION 0 +#endif + +// check C++11 language features +#if defined(__clang__) // clang +#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if __has_feature(cxx_thread_local) && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if(defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \ + !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#elif HALF_ICC_VERSION && defined(__INTEL_CXX11_MODE__) // Intel C++ +#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#elif defined(__GNUC__) // gcc +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L +#if HALF_GCC_VERSION >= 408 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if HALF_GCC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#endif +#define HALF_TWOS_COMPLEMENT_INT 1 +#elif defined(_MSC_VER) // Visual C++ +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#define HALF_TWOS_COMPLEMENT_INT 1 +#define HALF_POP_WARNINGS 1 +#pragma warning(push) +#pragma warning(disable : 4099 4127 4146) // struct vs class, constant in if, negative unsigned +#endif + +// check C++11 library features +#include +#if defined(_LIBCPP_VERSION) // libc++ +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 +#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#ifndef HALF_ENABLE_CPP11_CSTDINT +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#ifndef HALF_ENABLE_CPP11_CMATH +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#ifndef HALF_ENABLE_CPP11_HASH +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#ifndef HALF_ENABLE_CPP11_CFENV +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#endif +#elif defined(__GLIBCXX__) // libstdc++ +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 +#ifdef __clang__ +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CFENV) +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#else +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CFENV) +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#endif +#endif +#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++ +#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CFENV) +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#endif +#undef HALF_GCC_VERSION +#undef HALF_ICC_VERSION + +// any error throwing C++ exceptions? +#if defined(HALF_ERRHANDLING_THROW_INVALID) || defined(HALF_ERRHANDLING_THROW_DIVBYZERO) || \ + defined(HALF_ERRHANDLING_THROW_OVERFLOW) || defined(HALF_ERRHANDLING_THROW_UNDERFLOW) || \ + defined(HALF_ERRHANDLING_THROW_INEXACT) +#define HALF_ERRHANDLING_THROWS 1 +#endif + +// any error handling enabled? +#define HALF_ERRHANDLING \ + (HALF_ERRHANDLING_FLAGS || HALF_ERRHANDLING_ERRNO || HALF_ERRHANDLING_FENV || \ + HALF_ERRHANDLING_THROWS) + +#if HALF_ERRHANDLING +#define HALF_UNUSED_NOERR(name) name +#else +#define HALF_UNUSED_NOERR(name) +#endif + +// support constexpr +#if HALF_ENABLE_CPP11_CONSTEXPR +#define HALF_CONSTEXPR constexpr +#define HALF_CONSTEXPR_CONST constexpr +#if HALF_ERRHANDLING +#define HALF_CONSTEXPR_NOERR +#else +#define HALF_CONSTEXPR_NOERR constexpr +#endif +#else +#define HALF_CONSTEXPR +#define HALF_CONSTEXPR_CONST const +#define HALF_CONSTEXPR_NOERR +#endif + +// support noexcept +#if HALF_ENABLE_CPP11_NOEXCEPT +#define HALF_NOEXCEPT noexcept +#define HALF_NOTHROW noexcept +#else +#define HALF_NOEXCEPT +#define HALF_NOTHROW throw() +#endif + +// support thread storage +#if HALF_ENABLE_CPP11_THREAD_LOCAL +#define HALF_THREAD_LOCAL thread_local +#else +#define HALF_THREAD_LOCAL static +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if HALF_ENABLE_CPP11_TYPE_TRAITS +#include +#endif +#if HALF_ENABLE_CPP11_CSTDINT +#include +#endif +#if HALF_ERRHANDLING_ERRNO +#include +#endif +#if HALF_ENABLE_CPP11_CFENV +#include +#endif +#if HALF_ENABLE_CPP11_HASH +#include +#endif +#if HALF_ENABLE_F16C_INTRINSICS +#include +#endif + +#ifndef HALF_ENABLE_F16C_INTRINSICS +/// Enable F16C intruction set intrinsics. +/// Defining this to 1 enables the use of [F16C compiler +/// intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between +/// half-precision and single-precision values which may result in improved performance. This will +/// not perform additional checks +/// for support of the F16C instruction set, so an appropriate target platform is required when +/// enabling this feature. +/// +/// Unless predefined it will be enabled automatically when the `__F16C__` symbol is defined, which +/// some compilers do on supporting platforms. +#define HALF_ENABLE_F16C_INTRINSICS __F16C__ +#endif + +#ifdef HALF_DOXYGEN_ONLY +/// Type for internal floating-point computations. +/// This can be predefined to a built-in floating-point type (`float`, `double` or `long double`) to +/// override the internal +/// half-precision implementation to use this type for computing arithmetic operations and +/// mathematical function (if available). +/// This can result in improved performance for arithmetic operators and mathematical functions but +/// might cause results to +/// deviate from the specified half-precision rounding mode and inhibits proper detection of +/// half-precision exceptions. +#define HALF_ARITHMETIC_TYPE (undefined) + +/// Enable internal exception flags. +/// Defining this to 1 causes operations on half-precision values to raise internal floating-point +/// exception flags according to +/// the IEEE 754 standard. These can then be cleared and checked with clearexcept(), testexcept(). +#define HALF_ERRHANDLING_FLAGS 0 + +/// Enable exception propagation to `errno`. +/// Defining this to 1 causes operations on half-precision values to propagate floating-point +/// exceptions to +/// [errno](https://en.cppreference.com/w/cpp/error/errno) from ``. Specifically this will +/// propagate domain errors as +/// [EDOM](https://en.cppreference.com/w/cpp/error/errno_macros) and pole, overflow and underflow +/// errors as +/// [ERANGE](https://en.cppreference.com/w/cpp/error/errno_macros). Inexact errors won't be +/// propagated. +#define HALF_ERRHANDLING_ERRNO 0 + +/// Enable exception propagation to built-in floating-point platform. +/// Defining this to 1 causes operations on half-precision values to propagate floating-point +/// exceptions to the built-in +/// single- and double-precision implementation's exception flags using the +/// [C++11 floating-point environment control](https://en.cppreference.com/w/cpp/numeric/fenv) from +/// ``. However, this +/// does not work in reverse and single- or double-precision exceptions will not raise the +/// corresponding half-precision +/// exception flags, nor will explicitly clearing flags clear the corresponding built-in flags. +#define HALF_ERRHANDLING_FENV 0 + +/// Throw C++ exception on domain errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified +/// message on domain errors. +#define HALF_ERRHANDLING_THROW_INVALID (undefined) + +/// Throw C++ exception on pole errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified +/// message on pole errors. +#define HALF_ERRHANDLING_THROW_DIVBYZERO (undefined) + +/// Throw C++ exception on overflow errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::overflow_error](https://en.cppreference.com/w/cpp/error/overflow_error) with the specified +/// message on overflows. +#define HALF_ERRHANDLING_THROW_OVERFLOW (undefined) + +/// Throw C++ exception on underflow errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::underflow_error](https://en.cppreference.com/w/cpp/error/underflow_error) with the +/// specified message on underflows. +#define HALF_ERRHANDLING_THROW_UNDERFLOW (undefined) + +/// Throw C++ exception on rounding errors. +/// Defining this to 1 causes operations on half-precision values to throw a +/// [std::range_error](https://en.cppreference.com/w/cpp/error/range_error) with the specified +/// message on general rounding errors. +#define HALF_ERRHANDLING_THROW_INEXACT (undefined) +#endif + +#ifndef HALF_ERRHANDLING_OVERFLOW_TO_INEXACT +/// Raise INEXACT exception on overflow. +/// Defining this to 1 (default) causes overflow errors to automatically raise inexact exceptions in +/// addition. +/// These will be raised after any possible handling of the underflow exception. +#define HALF_ERRHANDLING_OVERFLOW_TO_INEXACT 1 +#endif + +#ifndef HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT +/// Raise INEXACT exception on underflow. +/// Defining this to 1 (default) causes underflow errors to automatically raise inexact exceptions +/// in addition. +/// These will be raised after any possible handling of the underflow exception. +/// +/// **Note:** This will actually cause underflow (and the accompanying inexact) exceptions to be +/// raised *only* when the result +/// is inexact, while if disabled bare underflow errors will be raised for *any* (possibly exact) +/// subnormal result. +#define HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT 1 +#endif + +/// Default rounding mode. +/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s +/// and more precise types +/// (unless using half_cast() and specifying the rounding mode directly) as well as in arithmetic +/// operations and mathematical +/// functions. It can be redefined (before including half.hpp) to one of the standard rounding modes +/// using their respective +/// constants or the equivalent values of +/// [std::float_round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/float_round_style): +/// +/// `std::float_round_style` | value | rounding +/// ---------------------------------|-------|------------------------- +/// `std::round_indeterminate` | -1 | fastest +/// `std::round_toward_zero` | 0 | toward zero +/// `std::round_to_nearest` | 1 | to nearest (default) +/// `std::round_toward_infinity` | 2 | toward positive infinity +/// `std::round_toward_neg_infinity` | 3 | toward negative infinity +/// +/// By default this is set to `1` (`std::round_to_nearest`), which rounds results to the nearest +/// representable value. It can even +/// be set to +/// [std::numeric_limits::round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/round_style) +/// to synchronize +/// the rounding mode with that of the built-in single-precision implementation (which is likely +/// `std::round_to_nearest`, though). +#ifndef HALF_ROUND_STYLE +#define HALF_ROUND_STYLE 1 // = std::round_to_nearest +#endif + +/// Value signaling overflow. +/// In correspondence with `HUGE_VAL[F|L]` from `` this symbol expands to a positive value +/// signaling the overflow of an +/// operation, in particular it just evaluates to positive infinity. +/// +/// **See also:** Documentation for +/// [HUGE_VAL](https://en.cppreference.com/w/cpp/numeric/math/HUGE_VAL) +#define HUGE_VALH std::numeric_limits::infinity() + +/// Fast half-precision fma function. +/// This symbol is defined if the fma() function generally executes as fast as, or faster than, a +/// separate +/// half-precision multiplication followed by an addition, which is always the case. +/// +/// **See also:** Documentation for +/// [FP_FAST_FMA](https://en.cppreference.com/w/cpp/numeric/math/fma) +#define FP_FAST_FMAH 1 + +/// Half rounding mode. +/// In correspondence with `FLT_ROUNDS` from `` this symbol expands to the rounding mode +/// used for +/// half-precision operations. It is an alias for [HALF_ROUND_STYLE](\ref HALF_ROUND_STYLE). +/// +/// **See also:** Documentation for +/// [FLT_ROUNDS](https://en.cppreference.com/w/cpp/types/climits/FLT_ROUNDS) +#define HLF_ROUNDS HALF_ROUND_STYLE + +#ifndef FP_ILOGB0 +#define FP_ILOGB0 INT_MIN +#endif +#ifndef FP_ILOGBNAN +#define FP_ILOGBNAN INT_MAX +#endif +#ifndef FP_SUBNORMAL +#define FP_SUBNORMAL 0 +#endif +#ifndef FP_ZERO +#define FP_ZERO 1 +#endif +#ifndef FP_NAN +#define FP_NAN 2 +#endif +#ifndef FP_INFINITE +#define FP_INFINITE 3 +#endif +#ifndef FP_NORMAL +#define FP_NORMAL 4 +#endif + +#if !HALF_ENABLE_CPP11_CFENV && !defined(FE_ALL_EXCEPT) +#define FE_INVALID 0x10 +#define FE_DIVBYZERO 0x08 +#define FE_OVERFLOW 0x04 +#define FE_UNDERFLOW 0x02 +#define FE_INEXACT 0x01 +#define FE_ALL_EXCEPT (FE_INVALID | FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW | FE_INEXACT) +#endif + +/// Main namespace for half-precision functionality. +/// This namespace contains all the functionality provided by the library. +namespace half_float { +class half; + +#if HALF_ENABLE_CPP11_USER_LITERALS +/// Library-defined half-precision literals. +/// Import this namespace to enable half-precision floating-point literals: +/// ~~~~{.cpp} +/// using namespace half_float::literal; +/// half_float::half = 4.2_h; +/// ~~~~ +namespace literal { +half operator"" _h(long double); +} +#endif + +/// \internal +/// \brief Implementation details. +namespace detail { +#if HALF_ENABLE_CPP11_TYPE_TRAITS +/// Conditional type. +template +struct conditional : std::conditional +{ +}; + +/// Helper for tag dispatching. +template +struct bool_type : std::integral_constant +{ +}; +using std::false_type; +using std::true_type; + +/// Type traits for floating-point types. +template +struct is_float : std::is_floating_point +{ +}; +#else +/// Conditional type. +template +struct conditional +{ + typedef T type; +}; +template +struct conditional +{ + typedef F type; +}; + +/// Helper for tag dispatching. +template +struct bool_type +{ +}; +typedef bool_type true_type; +typedef bool_type false_type; + +/// Type traits for floating-point types. +template +struct is_float : false_type +{ +}; +template +struct is_float : is_float +{ +}; +template +struct is_float : is_float +{ +}; +template +struct is_float : is_float +{ +}; +template <> +struct is_float : true_type +{ +}; +template <> +struct is_float : true_type +{ +}; +template <> +struct is_float : true_type +{ +}; +#endif + +/// Type traits for floating-point bits. +template +struct bits +{ + typedef unsigned char type; +}; +template +struct bits : bits +{ +}; +template +struct bits : bits +{ +}; +template +struct bits : bits +{ +}; + +#if HALF_ENABLE_CPP11_CSTDINT +/// Unsigned integer of (at least) 16 bits width. +typedef std::uint_least16_t uint16; + +/// Fastest unsigned integer of (at least) 32 bits width. +typedef std::uint_fast32_t uint32; + +/// Fastest signed integer of (at least) 32 bits width. +typedef std::int_fast32_t int32; + +/// Unsigned integer of (at least) 32 bits width. +template <> +struct bits +{ + typedef std::uint_least32_t type; +}; + +/// Unsigned integer of (at least) 64 bits width. +template <> +struct bits +{ + typedef std::uint_least64_t type; +}; +#else +/// Unsigned integer of (at least) 16 bits width. +typedef unsigned short uint16; + +/// Fastest unsigned integer of (at least) 32 bits width. +typedef unsigned long uint32; + +/// Fastest unsigned integer of (at least) 32 bits width. +typedef long int32; + +/// Unsigned integer of (at least) 32 bits width. +template <> +struct bits + : conditional::digits >= 32, unsigned int, unsigned long> +{ +}; + +#if HALF_ENABLE_CPP11_LONG_LONG +/// Unsigned integer of (at least) 64 bits width. +template <> +struct bits : conditional::digits >= 64, + unsigned long, + unsigned long long> +{ +}; +#else +/// Unsigned integer of (at least) 64 bits width. +template <> +struct bits +{ + typedef unsigned long type; +}; +#endif +#endif + +#ifdef HALF_ARITHMETIC_TYPE +/// Type to use for arithmetic computations and mathematic functions internally. +typedef HALF_ARITHMETIC_TYPE internal_t; +#endif + +/// Tag type for binary construction. +struct binary_t +{ +}; + +/// Tag for binary construction. +HALF_CONSTEXPR_CONST binary_t binary = binary_t(); + +/// \name Implementation defined classification and arithmetic +/// \{ + +/// Check for infinity. +/// \tparam T argument type (builtin floating-point type) +/// \param arg value to query +/// \retval true if infinity +/// \retval false else +template +bool builtin_isinf(T arg) +{ +#if HALF_ENABLE_CPP11_CMATH + return std::isinf(arg); +#elif defined(_MSC_VER) + return !::_finite(static_cast(arg)) && !::_isnan(static_cast(arg)); +#else + return arg == std::numeric_limits::infinity() || arg == -std::numeric_limits::infinity(); +#endif +} + +/// Check for NaN. +/// \tparam T argument type (builtin floating-point type) +/// \param arg value to query +/// \retval true if not a number +/// \retval false else +template +bool builtin_isnan(T arg) +{ +#if HALF_ENABLE_CPP11_CMATH + return std::isnan(arg); +#elif defined(_MSC_VER) + return ::_isnan(static_cast(arg)) != 0; +#else + return arg != arg; +#endif +} + +/// Check sign. +/// \tparam T argument type (builtin floating-point type) +/// \param arg value to query +/// \retval true if signbit set +/// \retval false else +template +bool builtin_signbit(T arg) +{ +#if HALF_ENABLE_CPP11_CMATH + return std::signbit(arg); +#else + return arg < T() || (arg == T() && T(1) / arg < T()); +#endif +} + +/// Platform-independent sign mask. +/// \param arg integer value in two's complement +/// \retval -1 if \a arg negative +/// \retval 0 if \a arg positive +inline uint32 sign_mask(uint32 arg) +{ + static const int N = std::numeric_limits::digits - 1; +#if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> N; +#else + return -((arg >> N) & 1); +#endif +} + +/// Platform-independent arithmetic right shift. +/// \param arg integer value in two's complement +/// \param i shift amount (at most 31) +/// \return \a arg right shifted for \a i bits with possible sign extension +inline uint32 arithmetic_shift(uint32 arg, int i) +{ +#if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> i; +#else + return static_cast(arg) / (static_cast(1) << i) - + ((arg >> (std::numeric_limits::digits - 1)) & 1); +#endif +} + +/// \} +/// \name Error handling +/// \{ + +/// Internal exception flags. +/// \return reference to global exception flags +inline int& errflags() +{ + HALF_THREAD_LOCAL int flags = 0; + return flags; +} + +/// Raise floating-point exception. +/// \param flags exceptions to raise +/// \param cond condition to raise exceptions for +inline void raise(int HALF_UNUSED_NOERR(flags), bool HALF_UNUSED_NOERR(cond) = true) +{ +#if HALF_ERRHANDLING + if(!cond) + return; +#if HALF_ERRHANDLING_FLAGS + errflags() |= flags; +#endif +#if HALF_ERRHANDLING_ERRNO + if(flags & FE_INVALID) + errno = EDOM; + else if(flags & (FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW)) + errno = ERANGE; +#endif +#if HALF_ERRHANDLING_FENV && HALF_ENABLE_CPP11_CFENV + std::feraiseexcept(flags); +#endif +#ifdef HALF_ERRHANDLING_THROW_INVALID + if(flags & FE_INVALID) + throw std::domain_error(HALF_ERRHANDLING_THROW_INVALID); +#endif +#ifdef HALF_ERRHANDLING_THROW_DIVBYZERO + if(flags & FE_DIVBYZERO) + throw std::domain_error(HALF_ERRHANDLING_THROW_DIVBYZERO); +#endif +#ifdef HALF_ERRHANDLING_THROW_OVERFLOW + if(flags & FE_OVERFLOW) + throw std::overflow_error(HALF_ERRHANDLING_THROW_OVERFLOW); +#endif +#ifdef HALF_ERRHANDLING_THROW_UNDERFLOW + if(flags & FE_UNDERFLOW) + throw std::underflow_error(HALF_ERRHANDLING_THROW_UNDERFLOW); +#endif +#ifdef HALF_ERRHANDLING_THROW_INEXACT + if(flags & FE_INEXACT) + throw std::range_error(HALF_ERRHANDLING_THROW_INEXACT); +#endif +#if HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + if((flags & FE_UNDERFLOW) && !(flags & FE_INEXACT)) + raise(FE_INEXACT); +#endif +#if HALF_ERRHANDLING_OVERFLOW_TO_INEXACT + if((flags & FE_OVERFLOW) && !(flags & FE_INEXACT)) + raise(FE_INEXACT); +#endif +#endif +} + +/// Check and signal for any NaN. +/// \param x first half-precision value to check +/// \param y second half-precision value to check +/// \retval true if either \a x or \a y is NaN +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool compsignal(unsigned int x, unsigned int y) +{ +#if HALF_ERRHANDLING + raise(FE_INVALID, (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00); +#endif + return (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00; +} + +/// Signal and silence signaling NaN. +/// \param nan half-precision NaN value +/// \return quiet NaN +/// \exception FE_INVALID if \a nan is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int nan) +{ +#if HALF_ERRHANDLING + raise(FE_INVALID, !(nan & 0x200)); +#endif + return nan | 0x200; +} + +/// Signal and silence signaling NaNs. +/// \param x first half-precision value to check +/// \param y second half-precision value to check +/// \return quiet NaN +/// \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y) +{ +#if HALF_ERRHANDLING + raise(FE_INVALID, + ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200))); +#endif + return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) : (y | 0x200); +} + +/// Signal and silence signaling NaNs. +/// \param x first half-precision value to check +/// \param y second half-precision value to check +/// \param z third half-precision value to check +/// \return quiet NaN +/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y, unsigned int z) +{ +#if HALF_ERRHANDLING + raise(FE_INVALID, + ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200)) || + ((z & 0x7FFF) > 0x7C00 && !(z & 0x200))); +#endif + return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) + : ((y & 0x7FFF) > 0x7C00) ? (y | 0x200) : (z | 0x200); +} + +/// Select value or signaling NaN. +/// \param x preferred half-precision value +/// \param y ignored half-precision value except for signaling NaN +/// \return \a y if signaling NaN, \a x otherwise +/// \exception FE_INVALID if \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int select(unsigned int x, unsigned int HALF_UNUSED_NOERR(y)) +{ +#if HALF_ERRHANDLING + return (((y & 0x7FFF) > 0x7C00) && !(y & 0x200)) ? signal(y) : x; +#else + return x; +#endif +} + +/// Raise domain error and return NaN. +/// return quiet NaN +/// \exception FE_INVALID +inline HALF_CONSTEXPR_NOERR unsigned int invalid() +{ +#if HALF_ERRHANDLING + raise(FE_INVALID); +#endif + return 0x7FFF; +} + +/// Raise pole error and return infinity. +/// \param sign half-precision value with sign bit only +/// \return half-precision infinity with sign of \a sign +/// \exception FE_DIVBYZERO +inline HALF_CONSTEXPR_NOERR unsigned int pole(unsigned int sign = 0) +{ +#if HALF_ERRHANDLING + raise(FE_DIVBYZERO); +#endif + return sign | 0x7C00; +} + +/// Check value for underflow. +/// \param arg non-zero half-precision value to check +/// \return \a arg +/// \exception FE_UNDERFLOW if arg is subnormal +inline HALF_CONSTEXPR_NOERR unsigned int check_underflow(unsigned int arg) +{ +#if HALF_ERRHANDLING && !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + raise(FE_UNDERFLOW, !(arg & 0x7C00)); +#endif + return arg; +} + +/// \} +/// \name Conversion and rounding +/// \{ + +/// Half-precision overflow. +/// \tparam R rounding mode to use +/// \param sign half-precision value with sign bit only +/// \return rounded overflowing half-precision value +/// \exception FE_OVERFLOW +template +HALF_CONSTEXPR_NOERR unsigned int overflow(unsigned int sign = 0) +{ +#if HALF_ERRHANDLING + raise(FE_OVERFLOW); +#endif + return (R == std::round_toward_infinity) + ? (sign + 0x7C00 - (sign >> 15)) + : (R == std::round_toward_neg_infinity) + ? (sign + 0x7BFF + (sign >> 15)) + : (R == std::round_toward_zero) ? (sign | 0x7BFF) : (sign | 0x7C00); +} + +/// Half-precision underflow. +/// \tparam R rounding mode to use +/// \param sign half-precision value with sign bit only +/// \return rounded underflowing half-precision value +/// \exception FE_UNDERFLOW +template +HALF_CONSTEXPR_NOERR unsigned int underflow(unsigned int sign = 0) +{ +#if HALF_ERRHANDLING + raise(FE_UNDERFLOW); +#endif + return (R == std::round_toward_infinity) + ? (sign + 1 - (sign >> 15)) + : (R == std::round_toward_neg_infinity) ? (sign + (sign >> 15)) : sign; +} + +/// Round half-precision number. +/// \tparam R rounding mode to use +/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results +/// \param value finite half-precision number to round +/// \param g guard bit (most significant discarded bit) +/// \param s sticky bit (or of all but the most significant discarded bits) +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded or \a I is `true` +template +HALF_CONSTEXPR_NOERR unsigned int rounded(unsigned int value, int g, int s) +{ +#if HALF_ERRHANDLING + value += (R == std::round_to_nearest) + ? (g & (s | value)) + : (R == std::round_toward_infinity) + ? (~(value >> 15) & (g | s)) + : (R == std::round_toward_neg_infinity) ? ((value >> 15) & (g | s)) : 0; + if((value & 0x7C00) == 0x7C00) + raise(FE_OVERFLOW); + else if(value & 0x7C00) + raise(FE_INEXACT, I || (g | s) != 0); + else + raise(FE_UNDERFLOW, !(HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT) || I || (g | s) != 0); + return value; +#else + return (R == std::round_to_nearest) + ? (value + (g & (s | value))) + : (R == std::round_toward_infinity) + ? (value + (~(value >> 15) & (g | s))) + : (R == std::round_toward_neg_infinity) ? (value + ((value >> 15) & (g | s))) + : value; +#endif +} + +/// Round half-precision number to nearest integer value. +/// \tparam R rounding mode to use +/// \tparam E `true` for round to even, `false` for round away from zero +/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it +/// \param value half-precision value to round +/// \return half-precision bits for nearest integral value +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded and \a I is `true` +template +unsigned int integral(unsigned int value) +{ + unsigned int abs = value & 0x7FFF; + if(abs < 0x3C00) + { + raise(FE_INEXACT, I); + return ((R == std::round_to_nearest) + ? (0x3C00 & -static_cast(abs >= (0x3800 + E))) + : (R == std::round_toward_infinity) + ? (0x3C00 & -(~(value >> 15) & (abs != 0))) + : (R == std::round_toward_neg_infinity) + ? (0x3C00 & -static_cast(value > 0x8000)) + : 0) | + (value & 0x8000); + } + if(abs >= 0x6400) + return (abs > 0x7C00) ? signal(value) : value; + unsigned int exp = 25 - (abs >> 10), mask = (1 << exp) - 1; + raise(FE_INEXACT, I && (value & mask)); + return (((R == std::round_to_nearest) + ? ((1 << (exp - 1)) - (~(value >> exp) & E)) + : (R == std::round_toward_infinity) + ? (mask & ((value >> 15) - 1)) + : (R == std::round_toward_neg_infinity) ? (mask & -(value >> 15)) : 0) + + value) & + ~mask; +} + +/// Convert fixed point to half-precision floating-point. +/// \tparam R rounding mode to use +/// \tparam F number of fractional bits (at least 11) +/// \tparam S `true` for signed, `false` for unsigned +/// \tparam N `true` for additional normalization step, `false` if already normalized to 1.F +/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results +/// \param m mantissa in Q1.F fixed point format +/// \param exp exponent +/// \param sign half-precision value with sign bit only +/// \param s sticky bit (or of all but the most significant already discarded bits) +/// \return value converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded or \a I is `true` +template +unsigned int fixed2half(uint32 m, int exp = 14, unsigned int sign = 0, int s = 0) +{ + if(S) + { + uint32 msign = sign_mask(m); + m = (m ^ msign) - msign; + sign = msign & 0x8000; + } + if(N) + for(; m < (static_cast(1) << F) && exp; m <<= 1, --exp) + ; + else if(exp < 0) + return rounded(sign + (m >> (F - 10 - exp)), + (m >> (F - 11 - exp)) & 1, + s | ((m & ((static_cast(1) << (F - 11 - exp)) - 1)) != 0)); + return rounded(sign + (exp << 10) + (m >> (F - 10)), + (m >> (F - 11)) & 1, + s | ((m & ((static_cast(1) << (F - 11)) - 1)) != 0)); +} + +/// Convert IEEE single-precision to half-precision. +/// Credit for this goes to [Jeroen van der +/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). +/// \tparam R rounding mode to use +/// \param value single-precision value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half_impl(float value, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(value), + (R == std::round_to_nearest) + ? _MM_FROUND_TO_NEAREST_INT + : (R == std::round_toward_zero) + ? _MM_FROUND_TO_ZERO + : (R == std::round_toward_infinity) + ? _MM_FROUND_TO_POS_INF + : (R == std::round_toward_neg_infinity) + ? _MM_FROUND_TO_NEG_INF + : _MM_FROUND_CUR_DIRECTION)); +#else + bits::type fbits; + std::memcpy(&fbits, &value, sizeof(float)); +#if 1 + unsigned int sign = (fbits >> 16) & 0x8000; + fbits &= 0x7FFFFFFF; + if(fbits >= 0x7F800000) + return sign | 0x7C00 | ((fbits > 0x7F800000) ? (0x200 | ((fbits >> 13) & 0x3FF)) : 0); + if(fbits >= 0x47800000) + return overflow(sign); + if(fbits >= 0x38800000) + return rounded(sign | (((fbits >> 23) - 112) << 10) | ((fbits >> 13) & 0x3FF), + (fbits >> 12) & 1, + (fbits & 0xFFF) != 0); + if(fbits >= 0x33000000) + { + int i = 125 - (fbits >> 23); + fbits = (fbits & 0x7FFFFF) | 0x800000; + return rounded(sign | (fbits >> (i + 1)), + (fbits >> i) & 1, + (fbits & ((static_cast(1) << i) - 1)) != 0); + } + if(fbits != 0) + return underflow(sign); + return sign; +#else + static const uint16 base_table[512] = { + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, + 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800, 0x1C00, 0x2000, + 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, 0x4000, 0x4400, 0x4800, 0x4C00, + 0x5000, 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7C00, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, 0x8002, 0x8004, 0x8008, + 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, + 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, 0xC000, + 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, + 0xF000, 0xF400, 0xF800, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFC00}; + static const unsigned char shift_table[256] = { + 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 24, 23, 22, 21, 20, 19, 18, 17, + 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13}; + int sexp = fbits >> 23, exp = sexp & 0xFF, i = shift_table[exp]; + fbits &= 0x7FFFFF; + uint32 m = (fbits | ((exp != 0) << 23)) & -static_cast(exp != 0xFF); + return rounded(base_table[sexp] + (fbits >> i), + (m >> (i - 1)) & 1, + (((static_cast(1) << (i - 1)) - 1) & m) != 0); +#endif +#endif +} + +/// Convert IEEE double-precision to half-precision. +/// \tparam R rounding mode to use +/// \param value double-precision value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half_impl(double value, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + if(R == std::round_indeterminate) + return _mm_cvtsi128_si32( + _mm_cvtps_ph(_mm_cvtpd_ps(_mm_set_sd(value)), _MM_FROUND_CUR_DIRECTION)); +#endif + bits::type dbits; + std::memcpy(&dbits, &value, sizeof(double)); + uint32 hi = dbits >> 32, lo = dbits & 0xFFFFFFFF; + unsigned int sign = (hi >> 16) & 0x8000; + hi &= 0x7FFFFFFF; + if(hi >= 0x7FF00000) + return sign | 0x7C00 | ((dbits & 0xFFFFFFFFFFFFF) ? (0x200 | ((hi >> 10) & 0x3FF)) : 0); + if(hi >= 0x40F00000) + return overflow(sign); + if(hi >= 0x3F100000) + return rounded(sign | (((hi >> 20) - 1008) << 10) | ((hi >> 10) & 0x3FF), + (hi >> 9) & 1, + ((hi & 0x1FF) | lo) != 0); + if(hi >= 0x3E600000) + { + int i = 1018 - (hi >> 20); + hi = (hi & 0xFFFFF) | 0x100000; + return rounded(sign | (hi >> (i + 1)), + (hi >> i) & 1, + ((hi & ((static_cast(1) << i) - 1)) | lo) != 0); + } + if((hi | lo) != 0) + return underflow(sign); + return sign; +} + +/// Convert non-IEEE floating-point to half-precision. +/// \tparam R rounding mode to use +/// \tparam T source type (builtin floating-point type) +/// \param value floating-point value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half_impl(T value, ...) +{ + unsigned int hbits = static_cast(builtin_signbit(value)) << 15; + if(value == T()) + return hbits; + if(builtin_isnan(value)) + return hbits | 0x7FFF; + if(builtin_isinf(value)) + return hbits | 0x7C00; + int exp; + std::frexp(value, &exp); + if(exp > 16) + return overflow(hbits); + if(exp < -13) + value = std::ldexp(value, 25); + else + { + value = std::ldexp(value, 12 - exp); + hbits |= ((exp + 13) << 10); + } + T ival, frac = std::modf(value, &ival); + int m = std::abs(static_cast(ival)); + return rounded(hbits + (m >> 1), m & 1, frac != T()); +} + +/// Convert floating-point to half-precision. +/// \tparam R rounding mode to use +/// \tparam T source type (builtin floating-point type) +/// \param value floating-point value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half(T value) +{ + return float2half_impl(value, + bool_type < std::numeric_limits::is_iec559 && + sizeof(typename bits::type) == sizeof(T) > ()); +} + +/// Convert integer to half-precision floating-point. +/// \tparam R rounding mode to use +/// \tparam T type to convert (builtin integer type) +/// \param value integral value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int int2half(T value) +{ + unsigned int bits = static_cast(value < 0) << 15; + if(!value) + return bits; + if(bits) + value = -value; + if(value > 0xFFFF) + return overflow(bits); + unsigned int m = static_cast(value), exp = 24; + for(; m < 0x400; m <<= 1, --exp) + ; + for(; m > 0x7FF; m >>= 1, ++exp) + ; + bits |= (exp << 10) + m; + return (exp > 24) ? rounded( + bits, (value >> (exp - 25)) & 1, (((1 << (exp - 25)) - 1) & value) != 0) + : bits; +} + +/// Convert half-precision to IEEE single-precision. +/// Credit for this goes to [Jeroen van der +/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). +/// \param value half-precision value to convert +/// \return single-precision value +inline float half2float_impl(unsigned int value, float, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(value))); +#else +#if 0 + bits::type fbits = static_cast::type>(value&0x8000) << 16; + int abs = value & 0x7FFF; + if(abs) + { + fbits |= 0x38000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,fbits-=0x800000) ; + fbits += static_cast::type>(abs) << 13; + } +#else + static const bits::type mantissa_table[2048] = { + 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, + 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, + 0x35600000, 0x35700000, 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, + 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, + 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, 0x36000000, 0x36040000, 0x36080000, + 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, + 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, 0x36400000, + 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, + 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, + 0x367C0000, 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, + 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, + 0x369A0000, 0x369C0000, 0x369E0000, 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, + 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, + 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, 0x36C00000, 0x36C20000, + 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, + 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, + 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, + 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, + 0x36FC0000, 0x36FE0000, 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, + 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, + 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, 0x37100000, 0x37110000, 0x37120000, + 0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, + 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, 0x37200000, + 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, + 0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, + 0x372F0000, 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, + 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, + 0x373D0000, 0x373E0000, 0x373F0000, 0x37400000, 0x37410000, 0x37420000, 0x37430000, + 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, + 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, 0x37500000, 0x37510000, + 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, + 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, + 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, + 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, + 0x376E0000, 0x376F0000, 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, + 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, + 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, 0x37800000, 0x37808000, 0x37810000, + 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, + 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, 0x37880000, + 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, + 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, + 0x378F8000, 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, + 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, + 0x37968000, 0x37970000, 0x37978000, 0x37980000, 0x37988000, 0x37990000, 0x37998000, + 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 0x379C8000, 0x379D0000, + 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, 0x37A00000, 0x37A08000, + 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, + 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, + 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, + 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, + 0x37AF0000, 0x37AF8000, 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, + 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, + 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, 0x37B80000, 0x37B88000, 0x37B90000, + 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, + 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, 0x37C00000, + 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, + 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, + 0x37C78000, 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, + 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, + 0x37CE8000, 0x37CF0000, 0x37CF8000, 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, + 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, + 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, 0x37D80000, 0x37D88000, + 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 0x37DC0000, + 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, + 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 0x37E30000, + 0x37E38000, 0x37E40000, 0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, + 0x37E70000, 0x37E78000, 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, + 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, + 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, 0x37F00000, 0x37F08000, 0x37F10000, + 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, + 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, 0x37F80000, + 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, + 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, + 0x37FF8000, 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, + 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, + 0x38034000, 0x38038000, 0x3803C000, 0x38040000, 0x38044000, 0x38048000, 0x3804C000, + 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, + 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, 0x38080000, 0x38084000, + 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, + 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, + 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, + 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, + 0x380F8000, 0x380FC000, 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, + 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, + 0x38130000, 0x38134000, 0x38138000, 0x3813C000, 0x38140000, 0x38144000, 0x38148000, + 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, + 0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, 0x38180000, + 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, + 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, + 0x381BC000, 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, + 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, + 0x381F4000, 0x381F8000, 0x381FC000, 0x38200000, 0x38204000, 0x38208000, 0x3820C000, + 0x38210000, 0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, + 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, 0x38240000, 0x38244000, + 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, + 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, + 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, + 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, + 0x382B8000, 0x382BC000, 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, + 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, + 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, 0x38300000, 0x38304000, 0x38308000, + 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, + 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, 0x38340000, + 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, + 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, + 0x3837C000, 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, + 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 0x383A8000, 0x383AC000, 0x383B0000, + 0x383B4000, 0x383B8000, 0x383BC000, 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, + 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, + 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, 0x38400000, 0x38404000, + 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, + 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, + 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, + 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, + 0x38478000, 0x3847C000, 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, + 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000, + 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, 0x384C0000, 0x384C4000, 0x384C8000, + 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, + 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, 0x38500000, + 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, + 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, + 0x3853C000, 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, + 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, + 0x38574000, 0x38578000, 0x3857C000, 0x38580000, 0x38584000, 0x38588000, 0x3858C000, + 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, + 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, 0x385C0000, 0x385C4000, + 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, + 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, + 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, + 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, + 0x38638000, 0x3863C000, 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, + 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, + 0x38670000, 0x38674000, 0x38678000, 0x3867C000, 0x38680000, 0x38684000, 0x38688000, + 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, + 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, 0x386C0000, + 0x386C4000, 0x386C8000, 0x386CC000, 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, + 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, + 0x386FC000, 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, + 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, + 0x38734000, 0x38738000, 0x3873C000, 0x38740000, 0x38744000, 0x38748000, 0x3874C000, + 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, + 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, 0x38780000, 0x38784000, + 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, + 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, + 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, + 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, + 0x387F8000, 0x387FC000, 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, + 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 0x38016000, + 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, 0x38020000, 0x38022000, 0x38024000, + 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, + 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, 0x38040000, + 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, + 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, + 0x3805E000, 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, + 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, + 0x3807A000, 0x3807C000, 0x3807E000, 0x38080000, 0x38082000, 0x38084000, 0x38086000, + 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, + 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, 0x380A0000, 0x380A2000, + 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, + 0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, + 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, + 0x380CE000, 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, + 0x380DC000, 0x380DE000, 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, + 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, + 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, 0x38100000, 0x38102000, 0x38104000, + 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, + 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, 0x38120000, + 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, + 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, + 0x3813E000, 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, + 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, + 0x3815A000, 0x3815C000, 0x3815E000, 0x38160000, 0x38162000, 0x38164000, 0x38166000, + 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, + 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, 0x38180000, 0x38182000, + 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, + 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, + 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, + 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, + 0x381BC000, 0x381BE000, 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, + 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, + 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, 0x381E0000, 0x381E2000, 0x381E4000, + 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, + 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000, 0x38200000, + 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, + 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, + 0x3821E000, 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, + 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, + 0x3823A000, 0x3823C000, 0x3823E000, 0x38240000, 0x38242000, 0x38244000, 0x38246000, + 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, + 0x38256000, 0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, 0x38260000, 0x38262000, + 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, + 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, + 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, + 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, + 0x3829C000, 0x3829E000, 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, + 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, + 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, 0x382C0000, 0x382C2000, 0x382C4000, + 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, + 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, 0x382E0000, + 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, + 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, + 0x382FE000, 0x38300000, 0x38302000, 0x38304000, 0x38306000, 0x38308000, 0x3830A000, + 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, + 0x3831A000, 0x3831C000, 0x3831E000, 0x38320000, 0x38322000, 0x38324000, 0x38326000, + 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, + 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, 0x38340000, 0x38342000, + 0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, + 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, + 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, + 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, + 0x3837C000, 0x3837E000, 0x38380000, 0x38382000, 0x38384000, 0x38386000, 0x38388000, + 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, + 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, 0x383A0000, 0x383A2000, 0x383A4000, + 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, + 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, 0x383C0000, + 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, + 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, + 0x383DE000, 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, + 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, + 0x383FA000, 0x383FC000, 0x383FE000, 0x38400000, 0x38402000, 0x38404000, 0x38406000, + 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, + 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, 0x38420000, 0x38422000, + 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, + 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, + 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 0x3844C000, + 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, + 0x3845C000, 0x3845E000, 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, + 0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, + 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, 0x38480000, 0x38482000, 0x38484000, + 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, 0x38490000, 0x38492000, + 0x38494000, 0x38496000, 0x38498000, 0x3849A000, 0x3849C000, 0x3849E000, 0x384A0000, + 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, + 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, + 0x384BE000, 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, + 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, + 0x384DA000, 0x384DC000, 0x384DE000, 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, + 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, + 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, 0x38500000, 0x38502000, + 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, + 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, + 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, + 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, + 0x3853C000, 0x3853E000, 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, + 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, + 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, 0x38560000, 0x38562000, 0x38564000, + 0x38566000, 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, + 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, 0x38580000, + 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, + 0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, + 0x3859E000, 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, + 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, + 0x385BA000, 0x385BC000, 0x385BE000, 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, + 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, + 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, 0x385E0000, 0x385E2000, + 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, + 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, + 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, + 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, + 0x3861C000, 0x3861E000, 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, + 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, + 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, 0x38640000, 0x38642000, 0x38644000, + 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, + 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, 0x38660000, + 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, + 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, + 0x3867E000, 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, + 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, + 0x3869A000, 0x3869C000, 0x3869E000, 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, + 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, + 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, 0x386C0000, 0x386C2000, + 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, + 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 0x386DC000, 0x386DE000, + 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, + 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, + 0x386FC000, 0x386FE000, 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, + 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, + 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, 0x38720000, 0x38722000, 0x38724000, + 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, + 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, 0x38740000, + 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, + 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, + 0x3875E000, 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, + 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, + 0x3877A000, 0x3877C000, 0x3877E000, 0x38780000, 0x38782000, 0x38784000, 0x38786000, + 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, + 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, 0x387A0000, 0x387A2000, + 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, + 0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, + 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, + 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, + 0x387DC000, 0x387DE000, 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, + 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, + 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000}; + static const bits::type exponent_table[64] = { + 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, + 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, + 0x07000000, 0x07800000, 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, + 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, + 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, 0x80000000, 0x80800000, 0x81000000, + 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, + 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, 0x88000000, + 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, + 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, + 0xC7800000}; + static const unsigned short offset_table[64] = { + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 0, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024}; + bits::type fbits = + mantissa_table[offset_table[value >> 10] + (value & 0x3FF)] + exponent_table[value >> 10]; +#endif + float out; + std::memcpy(&out, &fbits, sizeof(float)); + return out; +#endif +} + +/// Convert half-precision to IEEE double-precision. +/// \param value half-precision value to convert +/// \return double-precision value +inline double half2float_impl(unsigned int value, double, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsd_f64(_mm_cvtps_pd(_mm_cvtph_ps(_mm_cvtsi32_si128(value)))); +#else + uint32 hi = static_cast(value & 0x8000) << 16; + unsigned int abs = value & 0x7FFF; + if(abs) + { + hi |= 0x3F000000 << static_cast(abs >= 0x7C00); + for(; abs < 0x400; abs <<= 1, hi -= 0x100000) + ; + hi += static_cast(abs) << 10; + } + bits::type dbits = static_cast::type>(hi) << 32; + double out; + std::memcpy(&out, &dbits, sizeof(double)); + return out; +#endif +} + +/// Convert half-precision to non-IEEE floating-point. +/// \tparam T type to convert to (builtin integer type) +/// \param value half-precision value to convert +/// \return floating-point value +template +T half2float_impl(unsigned int value, T, ...) +{ + T out; + unsigned int abs = value & 0x7FFF; + if(abs > 0x7C00) + out = + (std::numeric_limits::has_signaling_NaN && !(abs & 0x200)) + ? std::numeric_limits::signaling_NaN() + : std::numeric_limits::has_quiet_NaN ? std::numeric_limits::quiet_NaN() : T(); + else if(abs == 0x7C00) + out = std::numeric_limits::has_infinity ? std::numeric_limits::infinity() + : std::numeric_limits::max(); + else if(abs > 0x3FF) + out = std::ldexp(static_cast((abs & 0x3FF) | 0x400), (abs >> 10) - 25); + else + out = std::ldexp(static_cast(abs), -24); + return (value & 0x8000) ? -out : out; +} + +/// Convert half-precision to floating-point. +/// \tparam T type to convert to (builtin integer type) +/// \param value half-precision value to convert +/// \return floating-point value +template +T half2float(unsigned int value) +{ + return half2float_impl(value, + T(), + bool_type < std::numeric_limits::is_iec559 && + sizeof(typename bits::type) == sizeof(T) > ()); +} + +/// Convert half-precision floating-point to integer. +/// \tparam R rounding mode to use +/// \tparam E `true` for round to even, `false` for round away from zero +/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it +/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding +/// any implicit sign bits) +/// \param value half-precision value to convert +/// \return rounded integer value +/// \exception FE_INVALID if value is not representable in type \a T +/// \exception FE_INEXACT if value had to be rounded and \a I is `true` +template +T half2int(unsigned int value) +{ + unsigned int abs = value & 0x7FFF; + if(abs >= 0x7C00) + { + raise(FE_INVALID); + return (value & 0x8000) ? std::numeric_limits::min() : std::numeric_limits::max(); + } + if(abs < 0x3800) + { + raise(FE_INEXACT, I); + return (R == std::round_toward_infinity) + ? T(~(value >> 15) & (abs != 0)) + : (R == std::round_toward_neg_infinity) ? -T(value > 0x8000) : T(); + } + int exp = 25 - (abs >> 10); + unsigned int m = (value & 0x3FF) | 0x400; + int32 i = static_cast( + (exp <= 0) + ? (m << -exp) + : ((m + ((R == std::round_to_nearest) ? ((1 << (exp - 1)) - (~(m >> exp) & E)) + : (R == std::round_toward_infinity) + ? (((1 << exp) - 1) & ((value >> 15) - 1)) + : (R == std::round_toward_neg_infinity) + ? (((1 << exp) - 1) & -(value >> 15)) + : 0)) >> + exp)); + if((!std::numeric_limits::is_signed && (value & 0x8000)) || + (std::numeric_limits::digits < 16 && + ((value & 0x8000) ? (-i < std::numeric_limits::min()) + : (i > std::numeric_limits::max())))) + raise(FE_INVALID); + else if(I && exp > 0 && (m & ((1 << exp) - 1))) + raise(FE_INEXACT); + return static_cast((value & 0x8000) ? -i : i); +} + +/// \} +/// \name Mathematics +/// \{ + +/// upper part of 64-bit multiplication. +/// \tparam R rounding mode to use +/// \param x first factor +/// \param y second factor +/// \return upper 32 bit of \a x * \a y +template +uint32 mulhi(uint32 x, uint32 y) +{ + uint32 xy = (x >> 16) * (y & 0xFFFF), yx = (x & 0xFFFF) * (y >> 16), + c = (xy & 0xFFFF) + (yx & 0xFFFF) + (((x & 0xFFFF) * (y & 0xFFFF)) >> 16); + return (x >> 16) * (y >> 16) + (xy >> 16) + (yx >> 16) + (c >> 16) + + ((R == std::round_to_nearest) + ? ((c >> 15) & 1) + : (R == std::round_toward_infinity) ? ((c & 0xFFFF) != 0) : 0); +} + +/// 64-bit multiplication. +/// \param x first factor +/// \param y second factor +/// \return upper 32 bit of \a x * \a y rounded to nearest +inline uint32 multiply64(uint32 x, uint32 y) +{ +#if HALF_ENABLE_CPP11_LONG_LONG + return static_cast( + (static_cast(x) * static_cast(y) + 0x80000000) >> + 32); +#else + return mulhi(x, y); +#endif +} + +/// 64-bit division. +/// \param x upper 32 bit of dividend +/// \param y divisor +/// \param s variable to store sticky bit for rounding +/// \return (\a x << 32) / \a y +inline uint32 divide64(uint32 x, uint32 y, int& s) +{ +#if HALF_ENABLE_CPP11_LONG_LONG + unsigned long long xx = static_cast(x) << 32; + return s = (xx % y != 0), static_cast(xx / y); +#else + y >>= 1; + uint32 rem = x, div = 0; + for(unsigned int i = 0; i < 32; ++i) + { + div <<= 1; + if(rem >= y) + { + rem -= y; + div |= 1; + } + rem <<= 1; + } + return s = rem > 1, div; +#endif +} + +/// Half precision positive modulus. +/// \tparam Q `true` to compute full quotient, `false` else +/// \tparam R `true` to compute signed remainder, `false` for positive remainder +/// \param x first operand as positive finite half-precision value +/// \param y second operand as positive finite half-precision value +/// \param quo adress to store quotient at, `nullptr` if \a Q `false` +/// \return modulus of \a x / \a y +template +unsigned int mod(unsigned int x, unsigned int y, int* quo = NULL) +{ + unsigned int q = 0; + if(x > y) + { + int absx = x, absy = y, expx = 0, expy = 0; + for(; absx < 0x400; absx <<= 1, --expx) + ; + for(; absy < 0x400; absy <<= 1, --expy) + ; + expx += absx >> 10; + expy += absy >> 10; + int mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + for(int d = expx - expy; d; --d) + { + if(!Q && mx == my) + return 0; + if(mx >= my) + { + mx -= my; + q += Q; + } + mx <<= 1; + q <<= static_cast(Q); + } + if(!Q && mx == my) + return 0; + if(mx >= my) + { + mx -= my; + ++q; + } + if(Q) + { + q &= (1 << (std::numeric_limits::digits - 1)) - 1; + if(!mx) + return *quo = q, 0; + } + for(; mx < 0x400; mx <<= 1, --expy) + ; + x = (expy > 0) ? ((expy << 10) | (mx & 0x3FF)) : (mx >> (1 - expy)); + } + if(R) + { + unsigned int a, b; + if(y < 0x800) + { + a = (x < 0x400) ? (x << 1) : (x + 0x400); + b = y; + } + else + { + a = x; + b = y - 0x400; + } + if(a > b || (a == b && (q & 1))) + { + int exp = (y >> 10) + (y <= 0x3FF), d = exp - (x >> 10) - (x <= 0x3FF); + int m = (((y & 0x3FF) | ((y > 0x3FF) << 10)) << 1) - + (((x & 0x3FF) | ((x > 0x3FF) << 10)) << (1 - d)); + for(; m < 0x800 && exp > 1; m <<= 1, --exp) + ; + x = 0x8000 + ((exp - 1) << 10) + (m >> 1); + q += Q; + } + } + if(Q) + *quo = q; + return x; +} + +/// Fixed point square root. +/// \tparam F number of fractional bits +/// \param r radicand in Q1.F fixed point format +/// \param exp exponent +/// \return square root as Q1.F/2 +template +uint32 sqrt(uint32& r, int& exp) +{ + int i = exp & 1; + r <<= i; + exp = (exp - i) / 2; + uint32 m = 0; + for(uint32 bit = static_cast(1) << F; bit; bit >>= 2) + { + if(r < m + bit) + m >>= 1; + else + { + r -= m + bit; + m = (m >> 1) + bit; + } + } + return m; +} + +/// Fixed point binary exponential. +/// This uses the BKM algorithm in E-mode. +/// \param m exponent in [0,1) as Q0.31 +/// \param n number of iterations (at most 32) +/// \return 2 ^ \a m as Q1.31 +inline uint32 exp2(uint32 m, unsigned int n = 32) +{ + static const uint32 logs[] = { + 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, + 0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, + 0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, + 0x000005C5, 0x000002E3, 0x00000171, 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, + 0x0000000C, 0x00000006, 0x00000003, 0x00000001}; + if(!m) + return 0x80000000; + uint32 mx = 0x80000000, my = 0; + for(unsigned int i = 1; i < n; ++i) + { + uint32 mz = my + logs[i]; + if(mz <= m) + { + my = mz; + mx += mx >> i; + } + } + return mx; +} + +/// Fixed point binary logarithm. +/// This uses the BKM algorithm in L-mode. +/// \param m mantissa in [1,2) as Q1.30 +/// \param n number of iterations (at most 32) +/// \return log2(\a m) as Q0.31 +inline uint32 log2(uint32 m, unsigned int n = 32) +{ + static const uint32 logs[] = { + 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, + 0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, + 0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, + 0x000005C5, 0x000002E3, 0x00000171, 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, + 0x0000000C, 0x00000006, 0x00000003, 0x00000001}; + if(m == 0x40000000) + return 0; + uint32 mx = 0x40000000, my = 0; + for(unsigned int i = 1; i < n; ++i) + { + uint32 mz = mx + (mx >> i); + if(mz <= m) + { + mx = mz; + my += logs[i]; + } + } + return my; +} + +/// Fixed point sine and cosine. +/// This uses the CORDIC algorithm in rotation mode. +/// \param mz angle in [-pi/2,pi/2] as Q1.30 +/// \param n number of iterations (at most 31) +/// \return sine and cosine of \a mz as Q1.30 +inline std::pair sincos(uint32 mz, unsigned int n = 31) +{ + static const uint32 angles[] = { + 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB, + 0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000, + 0x00010000, 0x00008000, 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400, + 0x00000200, 0x00000100, 0x00000080, 0x00000040, 0x00000020, 0x00000010, 0x00000008, + 0x00000004, 0x00000002, 0x00000001}; + uint32 mx = 0x26DD3B6A, my = 0; + for(unsigned int i = 0; i < n; ++i) + { + uint32 sign = sign_mask(mz); + uint32 tx = mx - (arithmetic_shift(my, i) ^ sign) + sign; + uint32 ty = my + (arithmetic_shift(mx, i) ^ sign) - sign; + mx = tx; + my = ty; + mz -= (angles[i] ^ sign) - sign; + } + return std::make_pair(my, mx); +} + +/// Fixed point arc tangent. +/// This uses the CORDIC algorithm in vectoring mode. +/// \param my y coordinate as Q0.30 +/// \param mx x coordinate as Q0.30 +/// \param n number of iterations (at most 31) +/// \return arc tangent of \a my / \a mx as Q1.30 +inline uint32 atan2(uint32 my, uint32 mx, unsigned int n = 31) +{ + static const uint32 angles[] = { + 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB, + 0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000, + 0x00010000, 0x00008000, 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400, + 0x00000200, 0x00000100, 0x00000080, 0x00000040, 0x00000020, 0x00000010, 0x00000008, + 0x00000004, 0x00000002, 0x00000001}; + uint32 mz = 0; + for(unsigned int i = 0; i < n; ++i) + { + uint32 sign = sign_mask(my); + uint32 tx = mx + (arithmetic_shift(my, i) ^ sign) - sign; + uint32 ty = my - (arithmetic_shift(mx, i) ^ sign) + sign; + mx = tx; + my = ty; + mz += (angles[i] ^ sign) - sign; + } + return mz; +} + +/// Reduce argument for trigonometric functions. +/// \param abs half-precision floating-point value +/// \param k value to take quarter period +/// \return \a abs reduced to [-pi/4,pi/4] as Q0.30 +inline uint32 angle_arg(unsigned int abs, int& k) +{ + uint32 m = (abs & 0x3FF) | ((abs > 0x3FF) << 10); + int exp = (abs >> 10) + (abs <= 0x3FF) - 15; + if(abs < 0x3A48) + return k = 0, m << (exp + 20); +#if HALF_ENABLE_CPP11_LONG_LONG + unsigned long long y = m * 0xA2F9836E4E442, mask = (1ULL << (62 - exp)) - 1, + yi = (y + (mask >> 1)) & ~mask, f = y - yi; + uint32 sign = -static_cast(f >> 63); + k = static_cast(yi >> (62 - exp)); + return (multiply64(static_cast((sign ? -f : f) >> (31 - exp)), 0xC90FDAA2) ^ sign) - + sign; +#else + uint32 yh = m * 0xA2F98 + mulhi(m, 0x36E4E442), + yl = (m * 0x36E4E442) & 0xFFFFFFFF; + uint32 mask = (static_cast(1) << (30 - exp)) - 1, yi = (yh + (mask >> 1)) & ~mask, + sign = -static_cast(yi > yh); + k = static_cast(yi >> (30 - exp)); + uint32 fh = (yh ^ sign) + (yi ^ ~sign) - ~sign, fl = (yl ^ sign) - sign; + return (multiply64((exp > -1) + ? (((fh << (1 + exp)) & 0xFFFFFFFF) | ((fl & 0xFFFFFFFF) >> (31 - exp))) + : fh, + 0xC90FDAA2) ^ + sign) - + sign; +#endif +} + +/// Get arguments for atan2 function. +/// \param abs half-precision floating-point value +/// \return \a abs and sqrt(1 - \a abs^2) as Q0.30 +inline std::pair atan2_args(unsigned int abs) +{ + int exp = -15; + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + uint32 my = ((abs & 0x3FF) | 0x400) << 5, r = my * my; + int rexp = 2 * exp; + r = 0x40000000 - + ((rexp > -31) ? ((r >> -rexp) | ((r & ((static_cast(1) << -rexp) - 1)) != 0)) : 1); + for(rexp = 0; r < 0x40000000; r <<= 1, --rexp) + ; + uint32 mx = sqrt<30>(r, rexp); + int d = exp - rexp; + if(d < 0) + return std::make_pair((d < -14) ? ((my >> (-d - 14)) + ((my >> (-d - 15)) & 1)) + : (my << (14 + d)), + (mx << 14) + (r << 13) / mx); + if(d > 0) + return std::make_pair(my << 14, + (d > 14) + ? ((mx >> (d - 14)) + ((mx >> (d - 15)) & 1)) + : ((d == 14) ? mx : ((mx << (14 - d)) + (r << (13 - d)) / mx))); + return std::make_pair(my << 13, (mx << 13) + (r << 12) / mx); +} + +/// Get exponentials for hyperbolic computation +/// \param abs half-precision floating-point value +/// \param exp variable to take unbiased exponent of larger result +/// \param n number of BKM iterations (at most 32) +/// \return exp(abs) and exp(-\a abs) as Q1.31 with same exponent +inline std::pair hyperbolic_args(unsigned int abs, int& exp, unsigned int n = 32) +{ + uint32 mx = detail::multiply64(static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, + 0xB8AA3B29), + my; + int e = (abs >> 10) + (abs <= 0x3FF); + if(e < 14) + { + exp = 0; + mx >>= 14 - e; + } + else + { + exp = mx >> (45 - e); + mx = (mx << (e - 14)) & 0x7FFFFFFF; + } + mx = exp2(mx, n); + int d = exp << 1, s; + if(mx > 0x80000000) + { + my = divide64(0x80000000, mx, s); + my |= s; + ++d; + } + else + my = mx; + return std::make_pair( + mx, (d < 31) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1); +} + +/// Postprocessing for binary exponential. +/// \tparam R rounding mode to use +/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results +/// \param m mantissa as Q1.31 +/// \param exp absolute value of unbiased exponent +/// \param esign sign of actual exponent +/// \param sign sign bit of result +/// \return value converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded or \a I is `true` +template +unsigned int exp2_post(uint32 m, int exp, bool esign, unsigned int sign = 0) +{ + int s = 0; + if(esign) + { + if(m > 0x80000000) + { + m = divide64(0x80000000, m, s); + ++exp; + } + if(exp > 25) + return underflow(sign); + else if(exp == 25) + return rounded(sign, 1, (m & 0x7FFFFFFF) != 0); + exp = -exp; + } + else if(exp > 15) + return overflow(sign); + return fixed2half(m, exp + 14, sign, s); +} + +/// Postprocessing for binary logarithm. +/// \tparam R rounding mode to use +/// \tparam L logarithm for base transformation as Q1.31 +/// \param m fractional part of logarithm as Q0.31 +/// \param ilog signed integer part of logarithm +/// \param exp biased exponent of result +/// \param sign sign bit of result +/// \return value base-transformed and converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int log2_post(uint32 m, int ilog, int exp, unsigned int sign = 0) +{ + uint32 msign = sign_mask(ilog); + m = (((static_cast(ilog) << 27) + (m >> 4)) ^ msign) - msign; + if(!m) + return 0; + for(; m < 0x80000000; m <<= 1, --exp) + ; + int i = m >= L, s; + exp += i; + m >>= 1 + i; + sign ^= msign & 0x8000; + if(exp < -11) + return underflow(sign); + m = divide64(m, L, s); + return fixed2half(m, exp, sign, 1); +} + +/// Hypotenuse square root and postprocessing. +/// \tparam R rounding mode to use +/// \param r mantissa as Q2.30 +/// \param exp unbiased exponent +/// \return square root converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int hypot_post(uint32 r, int exp) +{ + int i = r >> 31; + if((exp += i) > 46) + return overflow(); + if(exp < -34) + return underflow(); + r = (r >> i) | (r & i); + uint32 m = sqrt<30>(r, exp += 15); + return fixed2half(m, exp - 1, 0, r != 0); +} + +/// Division and postprocessing for tangents. +/// \tparam R rounding mode to use +/// \param my dividend as Q1.31 +/// \param mx divisor as Q1.31 +/// \param exp biased exponent of result +/// \param sign sign bit of result +/// \return quotient converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int tangent_post(uint32 my, uint32 mx, int exp, unsigned int sign = 0) +{ + int i = my >= mx, s; + exp += i; + if(exp > 29) + return overflow(sign); + if(exp < -11) + return underflow(sign); + uint32 m = divide64(my >> (i + 1), mx, s); + return fixed2half(m, exp, sign, s); +} + +/// Area function and postprocessing. +/// This computes the value directly in Q2.30 using the representation `asinh|acosh(x) = +/// log(x+sqrt(x^2+|-1))`. +/// \tparam R rounding mode to use +/// \tparam S `true` for asinh, `false` for acosh +/// \param arg half-precision argument +/// \return asinh|acosh(\a arg) converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int area(unsigned int arg) +{ + int abs = arg & 0x7FFF, expx = (abs >> 10) + (abs <= 0x3FF) - 15, expy = -15, ilog, i; + uint32 mx = static_cast((abs & 0x3FF) | ((abs > 0x3FF) << 10)) << 20, my, r; + for(; abs < 0x400; abs <<= 1, --expy) + ; + expy += abs >> 10; + r = ((abs & 0x3FF) | 0x400) << 5; + r *= r; + i = r >> 31; + expy = 2 * expy + i; + r >>= i; + if(S) + { + if(expy < 0) + { + r = 0x40000000 + ((expy > -30) ? ((r >> -expy) | + ((r & ((static_cast(1) << -expy) - 1)) != 0)) + : 1); + expy = 0; + } + else + { + r += 0x40000000 >> expy; + i = r >> 31; + r = (r >> i) | (r & i); + expy += i; + } + } + else + { + r -= 0x40000000 >> expy; + for(; r < 0x40000000; r <<= 1, --expy) + ; + } + my = sqrt<30>(r, expy); + my = (my << 15) + (r << 14) / my; + if(S) + { + mx >>= expy - expx; + ilog = expy; + } + else + { + my >>= expx - expy; + ilog = expx; + } + my += mx; + i = my >> 31; + static const int G = S && (R == std::round_to_nearest); + return log2_post( + log2(my >> i, 26 + S + G) + (G << 3), ilog + i, 17, arg & (static_cast(S) << 15)); +} + +/// Class for 1.31 unsigned floating-point computation +struct f31 +{ + /// Constructor. + /// \param mant mantissa as 1.31 + /// \param e exponent + HALF_CONSTEXPR f31(uint32 mant, int e) : m(mant), exp(e) {} + + /// Constructor. + /// \param abs unsigned half-precision value + f31(unsigned int abs) : exp(-15) + { + for(; abs < 0x400; abs <<= 1, --exp) + ; + m = static_cast((abs & 0x3FF) | 0x400) << 21; + exp += (abs >> 10); + } + + /// Addition operator. + /// \param a first operand + /// \param b second operand + /// \return \a a + \a b + friend f31 operator+(f31 a, f31 b) + { + if(b.exp > a.exp) + std::swap(a, b); + int d = a.exp - b.exp; + uint32 m = a.m + ((d < 32) ? (b.m >> d) : 0); + int i = (m & 0xFFFFFFFF) < a.m; + return f31(((m + i) >> i) | 0x80000000, a.exp + i); + } + + /// Subtraction operator. + /// \param a first operand + /// \param b second operand + /// \return \a a - \a b + friend f31 operator-(f31 a, f31 b) + { + int d = a.exp - b.exp, exp = a.exp; + uint32 m = a.m - ((d < 32) ? (b.m >> d) : 0); + if(!m) + return f31(0, -32); + for(; m < 0x80000000; m <<= 1, --exp) + ; + return f31(m, exp); + } + + /// Multiplication operator. + /// \param a first operand + /// \param b second operand + /// \return \a a * \a b + friend f31 operator*(f31 a, f31 b) + { + uint32 m = multiply64(a.m, b.m); + int i = m >> 31; + return f31(m << (1 - i), a.exp + b.exp + i); + } + + /// Division operator. + /// \param a first operand + /// \param b second operand + /// \return \a a / \a b + friend f31 operator/(f31 a, f31 b) + { + int i = a.m >= b.m, s; + uint32 m = divide64((a.m + i) >> i, b.m, s); + return f31(m, a.exp - b.exp + i - 1); + } + + uint32 m; ///< mantissa as 1.31. + int exp; ///< exponent. +}; + +/// Error function and postprocessing. +/// This computes the value directly in Q1.31 using the approximations given +/// [here](https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions). +/// \tparam R rounding mode to use +/// \tparam C `true` for comlementary error function, `false` else +/// \param arg half-precision function argument +/// \return approximated value of error function in half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int erf(unsigned int arg) +{ + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + f31 x(abs), x2 = x * x * f31(0xB8AA3B29, 0), + t = f31(0x80000000, 0) / (f31(0x80000000, 0) + f31(0xA7BA054A, -2) * x), t2 = t * t; + f31 e = ((f31(0x87DC2213, 0) * t2 + f31(0xB5F0E2AE, 0)) * t2 + f31(0x82790637, -2) - + (f31(0xBA00E2B8, 0) * t2 + f31(0x91A98E62, -2)) * t) * + t / + ((x2.exp < 0) ? f31(exp2((x2.exp > -32) ? (x2.m >> -x2.exp) : 0, 30), 0) + : f31(exp2((x2.m << x2.exp) & 0x7FFFFFFF, 22), x2.m >> (31 - x2.exp))); + return (!C || sign) + ? fixed2half( + 0x80000000 - (e.m >> (C - e.exp)), 14 + C, sign & (C - 1U)) + : (e.exp < -25) + ? underflow() + : fixed2half(e.m >> 1, e.exp + 14, 0, e.m & 1); +} + +/// Gamma function and postprocessing. +/// This approximates the value of either the gamma function or its logarithm directly in Q1.31. +/// \tparam R rounding mode to use +/// \tparam L `true` for lograithm of gamma function, `false` for gamma function +/// \param arg half-precision floating-point value +/// \return lgamma/tgamma(\a arg) in half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if \a arg is not a positive integer +template +unsigned int gamma(unsigned int arg) +{ + /* static const double p[] ={ 2.50662827563479526904, 225.525584619175212544, + -268.295973841304927459, 80.9030806934622512966, -5.00757863970517583837, + 0.0114684895434781459556 }; double t = arg + 4.65, s = p[0]; for(unsigned int i=0; i<5; ++i) + s += p[i+1] / (arg+i); + return std::log(s) + (arg-0.5)*std::log(t) - t; +*/ static const f31 pi(0xC90FDAA2, 1), lbe(0xB8AA3B29, 0); + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + bool bsign = sign != 0; + f31 z(abs), x = sign ? (z + f31(0x80000000, 0)) : z, t = x + f31(0x94CCCCCD, 2), + s = f31(0xA06C9901, 1) + f31(0xBBE654E2, -7) / (x + f31(0x80000000, 2)) + + f31(0xA1CE6098, 6) / (x + f31(0x80000000, 1)) + f31(0xE1868CB7, 7) / x - + f31(0x8625E279, 8) / (x + f31(0x80000000, 0)) - + f31(0xA03E158F, 2) / (x + f31(0xC0000000, 1)); + int i = (s.exp >= 2) + (s.exp >= 4) + (s.exp >= 8) + (s.exp >= 16); + s = f31((static_cast(s.exp) << (31 - i)) + (log2(s.m >> 1, 28) >> i), i) / lbe; + if(x.exp != -1 || x.m != 0x80000000) + { + i = (t.exp >= 2) + (t.exp >= 4) + (t.exp >= 8); + f31 l = f31((static_cast(t.exp) << (31 - i)) + (log2(t.m >> 1, 30) >> i), i) / lbe; + s = (x.exp < -1) ? (s - (f31(0x80000000, -1) - x) * l) + : (s + (x - f31(0x80000000, -1)) * l); + } + s = x.exp ? (s - t) : (t - s); + if(bsign) + { + if(z.exp >= 0) + { + sign &= (L | ((z.m >> (31 - z.exp)) & 1)) - 1; + for(z = f31((z.m << (1 + z.exp)) & 0xFFFFFFFF, -1); z.m < 0x80000000; + z.m <<= 1, --z.exp) + ; + } + if(z.exp == -1) + z = f31(0x80000000, 0) - z; + if(z.exp < -1) + { + z = z * pi; + z.m = sincos(z.m >> (1 - z.exp), 30).first; + for(z.exp = 1; z.m < 0x80000000; z.m <<= 1, --z.exp) + ; + } + else + z = f31(0x80000000, 0); + } + if(L) + { + if(bsign) + { + f31 l(0x92868247, 0); + if(z.exp < 0) + { + uint32 m = log2((z.m + 1) >> 1, 27); + z = f31(-((static_cast(z.exp) << 26) + (m >> 5)), 5); + for(; z.m < 0x80000000; z.m <<= 1, --z.exp) + ; + l = l + z / lbe; + } + sign = static_cast(x.exp && (l.exp < s.exp || (l.exp == s.exp && l.m < s.m))) + << 15; + s = sign ? (s - l) : x.exp ? (l - s) : (l + s); + } + else + { + sign = static_cast(x.exp == 0) << 15; + if(s.exp < -24) + return underflow(sign); + if(s.exp > 15) + return overflow(sign); + } + } + else + { + s = s * lbe; + uint32 m; + if(s.exp < 0) + { + m = s.m >> -s.exp; + s.exp = 0; + } + else + { + m = (s.m << s.exp) & 0x7FFFFFFF; + s.exp = (s.m >> (31 - s.exp)); + } + s.m = exp2(m, 27); + if(!x.exp) + s = f31(0x80000000, 0) / s; + if(bsign) + { + if(z.exp < 0) + s = s * z; + s = pi / s; + if(s.exp < -24) + return underflow(sign); + } + else if(z.exp > 0 && !(z.m & ((1 << (31 - z.exp)) - 1))) + return ((s.exp + 14) << 10) + (s.m >> 21); + if(s.exp > 15) + return overflow(sign); + } + return fixed2half(s.m, s.exp + 14, sign); +} +/// \} + +template +struct half_caster; +} // namespace detail + +/// Half-precision floating-point type. +/// This class implements an IEEE-conformant half-precision floating-point type with the usual +/// arithmetic +/// operators and conversions. It is implicitly convertible to single-precision floating-point, +/// which makes artihmetic +/// expressions and functions with mixed-type operands to be of the most precise operand type. +/// +/// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's +/// less strict and +/// extended definitions it is both a standard layout type and a trivially copyable type (even if +/// not a POD type), which +/// means it can be standard-conformantly copied using raw binary copies. But in this context some +/// more words about the +/// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not +/// neccessarily have to be of +/// exactly 16-bits size. But on any reasonable implementation the actual binary representation of +/// this type will most +/// probably not ivolve any additional "magic" or padding beyond the simple binary representation of +/// the underlying 16-bit +/// IEEE number, even if not strictly guaranteed by the standard. But even then it only has an +/// actual size of 16 bits if +/// your C++ implementation supports an unsigned integer type of exactly 16 bits width. But this +/// should be the case on +/// nearly any reasonable platform. +/// +/// So if your C++ implementation is not totally exotic or imposes special alignment requirements, +/// it is a reasonable +/// assumption that the data of a half is just comprised of the 2 bytes of the underlying IEEE +/// representation. +class half +{ + public: + /// \name Construction and assignment + /// \{ + + /// Default constructor. + /// This initializes the half to 0. Although this does not match the builtin types' + /// default-initialization semantics + /// and may be less efficient than no initialization, it is needed to provide proper + /// value-initialization semantics. + HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {} + + /// Conversion constructor. + /// \param rhs float to convert + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + explicit half(float rhs) + : data_(static_cast(detail::float2half(rhs))) + { + } + + /// Conversion to single-precision. + /// \return single precision value representing expression value + operator float() const { return detail::half2float(data_); } + + /// Assignment operator. + /// \param rhs single-precision value to copy from + /// \return reference to this half + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + half& operator=(float rhs) + { + data_ = static_cast(detail::float2half(rhs)); + return *this; + } + + /// \} + /// \name Arithmetic updates + /// \{ + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to add + /// \return reference to this half + /// \exception FE_... according to operator+(half,half) + half& operator+=(half rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to subtract + /// \return reference to this half + /// \exception FE_... according to operator-(half,half) + half& operator-=(half rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to multiply with + /// \return reference to this half + /// \exception FE_... according to operator*(half,half) + half& operator*=(half rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to divide by + /// \return reference to this half + /// \exception FE_... according to operator/(half,half) + half& operator/=(half rhs) { return *this = *this / rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to add + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator+=(float rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to subtract + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator-=(float rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to multiply with + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator*=(float rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to divide by + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator/=(float rhs) { return *this = *this / rhs; } + + /// \} + /// \name Increment and decrement + /// \{ + + /// Prefix increment. + /// \return incremented half value + /// \exception FE_... according to operator+(half,half) + half& operator++() { return *this = *this + half(detail::binary, 0x3C00); } + + /// Prefix decrement. + /// \return decremented half value + /// \exception FE_... according to operator-(half,half) + half& operator--() { return *this = *this + half(detail::binary, 0xBC00); } + + /// Postfix increment. + /// \return non-incremented half value + /// \exception FE_... according to operator+(half,half) + half operator++(int) + { + half out(*this); + ++*this; + return out; + } + + /// Postfix decrement. + /// \return non-decremented half value + /// \exception FE_... according to operator-(half,half) + half operator--(int) + { + half out(*this); + --*this; + return out; + } + /// \} + + private: + /// Rounding mode to use + static const std::float_round_style round_style = (std::float_round_style)(HALF_ROUND_STYLE); + + /// Constructor. + /// \param bits binary representation to set half to + HALF_CONSTEXPR half(detail::binary_t, unsigned int bits) HALF_NOEXCEPT + : data_(static_cast(bits)) + { + } + + /// Internal binary representation + detail::uint16 data_; + +#ifndef HALF_DOXYGEN_ONLY + friend HALF_CONSTEXPR_NOERR bool operator==(half, half); + friend HALF_CONSTEXPR_NOERR bool operator!=(half, half); + friend HALF_CONSTEXPR_NOERR bool operator<(half, half); + friend HALF_CONSTEXPR_NOERR bool operator>(half, half); + friend HALF_CONSTEXPR_NOERR bool operator<=(half, half); + friend HALF_CONSTEXPR_NOERR bool operator>=(half, half); + friend HALF_CONSTEXPR half operator-(half); + friend half operator+(half, half); + friend half operator-(half, half); + friend half operator*(half, half); + friend half operator/(half, half); + template + friend std::basic_ostream& operator<<(std::basic_ostream&, half); + template + friend std::basic_istream& operator>>(std::basic_istream&, half&); + friend HALF_CONSTEXPR half fabs(half); + friend half fmod(half, half); + friend half remainder(half, half); + friend half remquo(half, half, int*); + friend half fma(half, half, half); + friend HALF_CONSTEXPR_NOERR half fmax(half, half); + friend HALF_CONSTEXPR_NOERR half fmin(half, half); + friend half fdim(half, half); + friend half nanh(const char*); + friend half exp(half); + friend half exp2(half); + friend half expm1(half); + friend half log(half); + friend half log10(half); + friend half log2(half); + friend half log1p(half); + friend half sqrt(half); + friend half cbrt(half); + friend half hypot(half, half); + friend half hypot(half, half, half); + friend half pow(half, half); + friend void sincos(half, half*, half*); + friend half sin(half); + friend half cos(half); + friend half tan(half); + friend half asin(half); + friend half acos(half); + friend half atan(half); + friend half atan2(half, half); + friend half sinh(half); + friend half cosh(half); + friend half tanh(half); + friend half asinh(half); + friend half acosh(half); + friend half atanh(half); + friend half erf(half); + friend half erfc(half); + friend half lgamma(half); + friend half tgamma(half); + friend half ceil(half); + friend half floor(half); + friend half trunc(half); + friend half round(half); + friend long lround(half); + friend half rint(half); + friend long lrint(half); + friend half nearbyint(half); +#ifdef HALF_ENABLE_CPP11_LONG_LONG + friend long long llround(half); + friend long long llrint(half); +#endif + friend half frexp(half, int*); + friend half scalbln(half, long); + friend half modf(half, half*); + friend int ilogb(half); + friend half logb(half); + friend half nextafter(half, half); + friend half nexttoward(half, long double); + friend HALF_CONSTEXPR half copysign(half, half); + friend HALF_CONSTEXPR int fpclassify(half); + friend HALF_CONSTEXPR bool isfinite(half); + friend HALF_CONSTEXPR bool isinf(half); + friend HALF_CONSTEXPR bool isnan(half); + friend HALF_CONSTEXPR bool isnormal(half); + friend HALF_CONSTEXPR bool signbit(half); + friend HALF_CONSTEXPR bool isgreater(half, half); + friend HALF_CONSTEXPR bool isgreaterequal(half, half); + friend HALF_CONSTEXPR bool isless(half, half); + friend HALF_CONSTEXPR bool islessequal(half, half); + friend HALF_CONSTEXPR bool islessgreater(half, half); + template + friend struct detail::half_caster; + friend class std::numeric_limits; +#if HALF_ENABLE_CPP11_HASH + friend struct std::hash; +#endif +#if HALF_ENABLE_CPP11_USER_LITERALS + friend half literal::operator"" _h(long double); +#endif +#endif +}; + +#if HALF_ENABLE_CPP11_USER_LITERALS +namespace literal { +/// Half literal. +/// While this returns a properly rounded half-precision value, half literals can unfortunately not +/// be constant +/// expressions due to rather involved conversions. So don't expect this to be a literal literal +/// without involving +/// conversion operations at runtime. It is a convenience feature, not a performance optimization. +/// \param value literal value +/// \return half with of given value (possibly rounded) +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator"" _h(long double value) +{ + return half(detail::binary, detail::float2half(value)); +} +} // namespace literal +#endif + +namespace detail { +/// Helper class for half casts. +/// This class template has to be specialized for all valid cast arguments to define an appropriate +/// static +/// `cast` member function and a corresponding `type` member denoting its return type. +/// \tparam T destination type +/// \tparam U source type +/// \tparam R rounding mode to use +template +struct half_caster +{ +}; +template +struct half_caster +{ +#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, "half_cast from non-arithmetic type unsupported"); +#endif + + static half cast(U arg) { return cast_impl(arg, is_float()); }; + + private: + static half cast_impl(U arg, true_type) { return half(binary, float2half(arg)); } + static half cast_impl(U arg, false_type) { return half(binary, int2half(arg)); } +}; +template +struct half_caster +{ +#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, "half_cast to non-arithmetic type unsupported"); +#endif + + static T cast(half arg) { return cast_impl(arg, is_float()); } + + private: + static T cast_impl(half arg, true_type) { return half2float(arg.data_); } + static T cast_impl(half arg, false_type) { return half2int(arg.data_); } +}; +template +struct half_caster +{ + static half cast(half arg) { return arg; } +}; +} // namespace detail +} // namespace half_float + +/// Extensions to the C++ standard library. +namespace std { +/// Numeric limits for half-precision floats. +/// **See also:** Documentation for +/// [std::numeric_limits](https://en.cppreference.com/w/cpp/types/numeric_limits) +template <> +class numeric_limits +{ + public: + /// Is template specialization. + static HALF_CONSTEXPR_CONST bool is_specialized = true; + + /// Supports signed values. + static HALF_CONSTEXPR_CONST bool is_signed = true; + + /// Is not an integer type. + static HALF_CONSTEXPR_CONST bool is_integer = false; + + /// Is not exact. + static HALF_CONSTEXPR_CONST bool is_exact = false; + + /// Doesn't provide modulo arithmetic. + static HALF_CONSTEXPR_CONST bool is_modulo = false; + + /// Has a finite set of values. + static HALF_CONSTEXPR_CONST bool is_bounded = true; + + /// IEEE conformant. + static HALF_CONSTEXPR_CONST bool is_iec559 = true; + + /// Supports infinity. + static HALF_CONSTEXPR_CONST bool has_infinity = true; + + /// Supports quiet NaNs. + static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true; + + /// Supports signaling NaNs. + static HALF_CONSTEXPR_CONST bool has_signaling_NaN = true; + + /// Supports subnormal values. + static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present; + + /// Supports no denormalization detection. + static HALF_CONSTEXPR_CONST bool has_denorm_loss = false; + +#if HALF_ERRHANDLING_THROWS + static HALF_CONSTEXPR_CONST bool traps = true; +#else + /// Traps only if [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID) is + /// acitvated. + static HALF_CONSTEXPR_CONST bool traps = false; +#endif + + /// Does not support no pre-rounding underflow detection. + static HALF_CONSTEXPR_CONST bool tinyness_before = false; + + /// Rounding mode. + static HALF_CONSTEXPR_CONST float_round_style round_style = half_float::half::round_style; + + /// Significant digits. + static HALF_CONSTEXPR_CONST int digits = 11; + + /// Significant decimal digits. + static HALF_CONSTEXPR_CONST int digits10 = 3; + + /// Required decimal digits to represent all possible values. + static HALF_CONSTEXPR_CONST int max_digits10 = 5; + + /// Number base. + static HALF_CONSTEXPR_CONST int radix = 2; + + /// One more than smallest exponent. + static HALF_CONSTEXPR_CONST int min_exponent = -13; + + /// Smallest normalized representable power of 10. + static HALF_CONSTEXPR_CONST int min_exponent10 = -4; + + /// One more than largest exponent + static HALF_CONSTEXPR_CONST int max_exponent = 16; + + /// Largest finitely representable power of 10. + static HALF_CONSTEXPR_CONST int max_exponent10 = 4; + + /// Smallest positive normal value. + static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x0400); + } + + /// Smallest finite value. + static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0xFBFF); + } + + /// Largest finite value. + static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x7BFF); + } + + /// Difference between 1 and next representable value. + static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x1400); + } + + /// Maximum rounding error in ULP (units in the last place). + static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, + (round_style == std::round_to_nearest) ? 0x3800 : 0x3C00); + } + + /// Positive infinity. + static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x7C00); + } + + /// Quiet NaN. + static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x7FFF); + } + + /// Signaling NaN. + static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x7DFF); + } + + /// Smallest positive subnormal value. + static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x0001); + } +}; + +#if HALF_ENABLE_CPP11_HASH +/// Hash function for half-precision floats. +/// This is only defined if C++11 `std::hash` is supported and enabled. +/// +/// **See also:** Documentation for [std::hash](https://en.cppreference.com/w/cpp/utility/hash) +template <> +struct hash +{ + /// Type of function argument. + typedef half_float::half argument_type; + + /// Function return type. + typedef size_t result_type; + + /// Compute hash function. + /// \param arg half to hash + /// \return hash value + result_type operator()(argument_type arg) const + { + return hash()(arg.data_ & + -static_cast(arg.data_ != 0x8000)); + } +}; +#endif +} // namespace std + +namespace half_float { +/// \anchor compop +/// \name Comparison operators +/// \{ + +/// Comparison for equality. +/// \param x first operand +/// \param y second operand +/// \retval true if operands equal +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator==(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + (x.data_ == y.data_ || !((x.data_ | y.data_) & 0x7FFF)); +} + +/// Comparison for inequality. +/// \param x first operand +/// \param y second operand +/// \retval true if operands not equal +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator!=(half x, half y) +{ + return detail::compsignal(x.data_, y.data_) || + (x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF)); +} + +/// Comparison for less than. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less than \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator<(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) < + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// Comparison for greater than. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater than \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator>(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) > + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// Comparison for less equal. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less equal \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator<=(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// Comparison for greater equal. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater equal \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator>=(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// \} +/// \anchor arithmetics +/// \name Arithmetic operators +/// \{ + +/// Identity. +/// \param arg operand +/// \return unchanged operand +inline HALF_CONSTEXPR half operator+(half arg) { return arg; } + +/// Negation. +/// \param arg operand +/// \return negated operand +inline HALF_CONSTEXPR half operator-(half arg) { return half(detail::binary, arg.data_ ^ 0x8000); } + +/// Addition. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return sum of half expressions +/// \exception FE_INVALID if \a x and \a y are infinities with different signs or signaling NaNs +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator+(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) + + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF; + bool sub = ((x.data_ ^ y.data_) & 0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absy != 0x7C00) ? x.data_ + : (sub && absx == 0x7C00) ? detail::invalid() : y.data_); + if(!absx) + return absy ? y + : half(detail::binary, + (half::round_style == std::round_toward_neg_infinity) + ? (x.data_ | y.data_) + : (x.data_ & y.data_)); + if(!absy) + return x; + unsigned int sign = ((sub && absy > absx) ? y.data_ : x.data_) & 0x8000; + if(absy > absx) + std::swap(absx, absy); + int exp = (absx >> 10) + (absx <= 0x3FF), d = exp - (absy >> 10) - (absy <= 0x3FF), + mx = ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << 3, my; + if(d < 13) + { + my = ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << 3; + my = (my >> d) | ((my & ((1 << d) - 1)) != 0); + } + else + my = 1; + if(sub) + { + if(!(mx -= my)) + return half(detail::binary, + static_cast(half::round_style == std::round_toward_neg_infinity) + << 15); + for(; mx < 0x2000 && exp > 1; mx <<= 1, --exp) + ; + } + else + { + mx += my; + int i = mx >> 14; + if((exp += i) > 30) + return half(detail::binary, detail::overflow(sign)); + mx = (mx >> i) | (mx & i); + } + return half(detail::binary, + detail::rounded( + sign + ((exp - 1) << 10) + (mx >> 3), (mx >> 2) & 1, (mx & 0x3) != 0)); +#endif +} + +/// Subtraction. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return difference of half expressions +/// \exception FE_INVALID if \a x and \a y are infinities with equal signs or signaling NaNs +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator-(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) - + detail::half2float(y.data_))); +#else + return x + -y; +#endif +} + +/// Multiplication. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return product of half expressions +/// \exception FE_INVALID if multiplying 0 with infinity or if \a x or \a y is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator*(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) * + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -16; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : ((absx == 0x7C00 && !absy) || (absy == 0x7C00 && !absx)) + ? detail::invalid() + : (sign | 0x7C00)); + if(!absx || !absy) + return half(detail::binary, sign); + for(; absx < 0x400; absx <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, --exp) + ; + detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * + static_cast((absy & 0x3FF) | 0x400); + int i = m >> 21, s = m & i; + exp += (absx >> 10) + (absy >> 10) + i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + return half( + detail::binary, + detail::fixed2half(m >> i, exp, sign, s)); +#endif +} + +/// Division. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return quotient of half expressions +/// \exception FE_INVALID if dividing 0s or infinities with each other or if \a x or \a y is +/// signaling NaN +/// \exception FE_DIVBYZERO if dividing finite value by 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator/(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) / + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = 14; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == absy) ? detail::invalid() + : (sign | ((absx == 0x7C00) ? 0x7C00 : 0))); + if(!absx) + return half(detail::binary, absy ? sign : detail::invalid()); + if(!absy) + return half(detail::binary, detail::pole(sign)); + for(; absx < 0x400; absx <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, ++exp) + ; + detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + int i = mx < my; + exp += (absx >> 10) - (absy >> 10) - i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + mx <<= 12 + i; + my <<= 1; + return half(detail::binary, + detail::fixed2half( + mx / my, exp, sign, mx % my != 0)); +#endif +} + +/// \} +/// \anchor streaming +/// \name Input and output +/// \{ + +/// Output operator. +/// This uses the built-in functionality for streaming out floating-point numbers. +/// \param out output stream to write into +/// \param arg half expression to write +/// \return reference to output stream +template +std::basic_ostream& operator<<(std::basic_ostream& out, half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return out << detail::half2float(arg.data_); +#else + return out << detail::half2float(arg.data_); +#endif +} + +/// Input operator. +/// This uses the built-in functionality for streaming in floating-point numbers, specifically +/// double precision floating +/// point numbers (unless overridden with [HALF_ARITHMETIC_TYPE](\ref HALF_ARITHMETIC_TYPE)). So the +/// input string is first +/// rounded to double precision using the underlying platform's current floating-point rounding mode +/// before being rounded +/// to half-precision using the library's half-precision rounding mode. +/// \param in input stream to read from +/// \param arg half to read into +/// \return reference to input stream +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +template +std::basic_istream& operator>>(std::basic_istream& in, half& arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f; +#else + double f; +#endif + if(in >> f) + arg.data_ = detail::float2half(f); + return in; +} + +/// \} +/// \anchor basic +/// \name Basic mathematical operations +/// \{ + +/// Absolute value. +/// **See also:** Documentation for +/// [std::fabs](https://en.cppreference.com/w/cpp/numeric/math/fabs). +/// \param arg operand +/// \return absolute value of \a arg +inline HALF_CONSTEXPR half fabs(half arg) { return half(detail::binary, arg.data_ & 0x7FFF); } + +/// Absolute value. +/// **See also:** Documentation for [std::abs](https://en.cppreference.com/w/cpp/numeric/math/fabs). +/// \param arg operand +/// \return absolute value of \a arg +inline HALF_CONSTEXPR half abs(half arg) { return fabs(arg); } + +/// Remainder of division. +/// **See also:** Documentation for +/// [std::fmod](https://en.cppreference.com/w/cpp/numeric/math/fmod). +/// \param x first operand +/// \param y second operand +/// \return remainder of floating-point division. +/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN +inline half fmod(half x, half y) +{ + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(!absx) + return x; + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign | detail::mod(absx, absy)); +} + +/// Remainder of division. +/// **See also:** Documentation for +/// [std::remainder](https://en.cppreference.com/w/cpp/numeric/math/remainder). +/// \param x first operand +/// \param y second operand +/// \return remainder of floating-point division. +/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN +inline half remainder(half x, half y) +{ + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign ^ detail::mod(absx, absy)); +} + +/// Remainder of division. +/// **See also:** Documentation for +/// [std::remquo](https://en.cppreference.com/w/cpp/numeric/math/remquo). +/// \param x first operand +/// \param y second operand +/// \param quo address to store some bits of quotient at +/// \return remainder of floating-point division. +/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN +inline half remquo(half x, half y, int* quo) +{ + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, value = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() : (*quo = 0, x.data_)); + if(!absy) + return half(detail::binary, detail::invalid()); + bool qsign = ((value ^ y.data_) & 0x8000) != 0; + int q = 1; + if(absx != absy) + value ^= detail::mod(absx, absy, &q); + return *quo = qsign ? -q : q, half(detail::binary, value); +} + +/// Fused multiply add. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::fma](https://en.cppreference.com/w/cpp/numeric/math/fma). +/// \param x first operand +/// \param y second operand +/// \param z third operand +/// \return ( \a x * \a y ) + \a z rounded as one operation. +/// \exception FE_INVALID according to operator*() and operator+() unless any argument is a quiet +/// NaN and no argument is a signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding the final addition +inline half fma(half x, half y, half z) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_), + fz = detail::half2float(z.data_); +#if HALF_ENABLE_CPP11_CMATH && FP_FAST_FMA + return half(detail::binary, detail::float2half(std::fma(fx, fy, fz))); +#else + return half(detail::binary, detail::float2half(fx * fy + fz)); +#endif +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, exp = -15; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + bool sub = ((sign ^ z.data_) & 0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return (absx > 0x7C00 || absy > 0x7C00 || absz > 0x7C00) + ? half(detail::binary, detail::signal(x.data_, y.data_, z.data_)) + : (absx == 0x7C00) ? half(detail::binary, + (!absy || (sub && absz == 0x7C00)) ? detail::invalid() + : (sign | 0x7C00)) + : (absy == 0x7C00) ? half(detail::binary, + (!absx || (sub && absz == 0x7C00)) + ? detail::invalid() + : (sign | 0x7C00)) + : z; + if(!absx || !absy) + return absz + ? z + : half(detail::binary, + (half::round_style == std::round_toward_neg_infinity) ? (z.data_ | sign) + : (z.data_ & sign)); + for(; absx < 0x400; absx <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, --exp) + ; + detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * + static_cast((absy & 0x3FF) | 0x400); + int i = m >> 21; + exp += (absx >> 10) + (absy >> 10) + i; + m <<= 3 - i; + if(absz) + { + int expz = 0; + for(; absz < 0x400; absz <<= 1, --expz) + ; + expz += absz >> 10; + detail::uint32 mz = static_cast((absz & 0x3FF) | 0x400) << 13; + if(expz > exp || (expz == exp && mz > m)) + { + std::swap(m, mz); + std::swap(exp, expz); + if(sub) + sign = z.data_ & 0x8000; + } + int d = exp - expz; + mz = (d < 23) ? ((mz >> d) | ((mz & ((static_cast(1) << d) - 1)) != 0)) : 1; + if(sub) + { + m = m - mz; + if(!m) + return half( + detail::binary, + static_cast(half::round_style == std::round_toward_neg_infinity) + << 15); + for(; m < 0x800000; m <<= 1, --exp) + ; + } + else + { + m += mz; + i = m >> 24; + m = (m >> i) | (m & i); + exp += i; + } + } + if(exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow(sign)); + return half(detail::binary, + detail::fixed2half(m, exp - 1, sign)); +#endif +} + +/// Maximum of half expressions. +/// **See also:** Documentation for +/// [std::fmax](https://en.cppreference.com/w/cpp/numeric/math/fmax). +/// \param x first operand +/// \param y second operand +/// \return maximum of operands, ignoring quiet NaNs +/// \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR half fmax(half x, half y) +{ + return half(detail::binary, + (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) < + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) + ? detail::select(y.data_, x.data_) + : detail::select(x.data_, y.data_)); +} + +/// Minimum of half expressions. +/// **See also:** Documentation for +/// [std::fmin](https://en.cppreference.com/w/cpp/numeric/math/fmin). +/// \param x first operand +/// \param y second operand +/// \return minimum of operands, ignoring quiet NaNs +/// \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR half fmin(half x, half y) +{ + return half(detail::binary, + (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) > + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) + ? detail::select(y.data_, x.data_) + : detail::select(x.data_, y.data_)); +} + +/// Positive difference. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::fdim](https://en.cppreference.com/w/cpp/numeric/math/fdim). +/// \param x first operand +/// \param y second operand +/// \return \a x - \a y or 0 if difference negative +/// \exception FE_... according to operator-(half,half) +inline half fdim(half x, half y) +{ + if(isnan(x) || isnan(y)) + return half(detail::binary, detail::signal(x.data_, y.data_)); + return (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) <= + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + ? half(detail::binary, 0) + : (x - y); +} + +/// Get NaN value. +/// **See also:** Documentation for [std::nan](https://en.cppreference.com/w/cpp/numeric/math/nan). +/// \param arg string code +/// \return quiet NaN +inline half nanh(const char* arg) +{ + unsigned int value = 0x7FFF; + while(*arg) + value ^= static_cast(*arg++) & 0xFF; + return half(detail::binary, value); +} + +/// \} +/// \anchor exponential +/// \name Exponential functions +/// \{ + +/// Exponential function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::exp](https://en.cppreference.com/w/cpp/numeric/math/exp). +/// \param arg function argument +/// \return e raised to \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half exp(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::exp(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? (0x7C00 & ((arg.data_ >> 15) - 1U)) + : detail::signal(arg.data_)); + if(abs >= 0x4C80) + return half(detail::binary, + (arg.data_ & 0x8000) ? detail::underflow() + : detail::overflow()); + detail::uint32 m = detail::multiply64( + static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, 0xB8AA3B29); + int e = (abs >> 10) + (abs <= 0x3FF), exp; + if(e < 14) + { + exp = 0; + m >>= 14 - e; + } + else + { + exp = m >> (45 - e); + m = (m << (e - 14)) & 0x7FFFFFFF; + } + return half(detail::binary, + detail::exp2_post( + detail::exp2(m, 26), exp, (arg.data_ & 0x8000) != 0)); +#endif +} + +/// Binary exponential. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::exp2](https://en.cppreference.com/w/cpp/numeric/math/exp2). +/// \param arg function argument +/// \return 2 raised to \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half exp2(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::exp2(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? (0x7C00 & ((arg.data_ >> 15) - 1U)) + : detail::signal(arg.data_)); + if(abs >= 0x4E40) + return half(detail::binary, + (arg.data_ & 0x8000) ? detail::underflow() + : detail::overflow()); + int e = (abs >> 10) + (abs <= 0x3FF), exp = (abs & 0x3FF) + ((abs > 0x3FF) << 10); + detail::uint32 m = detail::exp2((static_cast(exp) << (6 + e)) & 0x7FFFFFFF, 28); + exp >>= 25 - e; + if(m == 0x80000000) + { + if(arg.data_ & 0x8000) + exp = -exp; + else if(exp > 15) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::fixed2half(m, exp + 14)); + } + return half(detail::binary, + detail::exp2_post(m, exp, (arg.data_ & 0x8000) != 0)); +#endif +} + +/// Exponential minus one. +/// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for +/// `std::round_to_nearest` +/// and in <1% of inputs for any other rounding mode. +/// +/// **See also:** Documentation for +/// [std::expm1](https://en.cppreference.com/w/cpp/numeric/math/expm1). +/// \param arg function argument +/// \return e raised to \a arg and subtracted by 1 +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half expm1(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::expm1(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? (0x7C00 + (sign >> 1)) : detail::signal(arg.data_)); + if(abs >= 0x4A00) + return half(detail::binary, + (arg.data_ & 0x8000) ? detail::rounded(0xBBFF, 1, 1) + : detail::overflow()); + detail::uint32 m = detail::multiply64( + static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, 0xB8AA3B29); + int e = (abs >> 10) + (abs <= 0x3FF), exp; + if(e < 14) + { + exp = 0; + m >>= 14 - e; + } + else + { + exp = m >> (45 - e); + m = (m << (e - 14)) & 0x7FFFFFFF; + } + m = detail::exp2(m); + if(sign) + { + int s = 0; + if(m > 0x80000000) + { + ++exp; + m = detail::divide64(0x80000000, m, s); + } + m = 0x80000000 - + ((m >> exp) | ((m & ((static_cast(1) << exp) - 1)) != 0) | s); + exp = 0; + } + else + m -= (exp < 31) ? (0x80000000 >> exp) : 1; + for(exp += 14; m < 0x80000000 && exp; m <<= 1, --exp) + ; + if(exp > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::rounded( + sign + (exp << 10) + (m >> 21), (m >> 20) & 1, (m & 0xFFFFF) != 0)); +#endif +} + +/// Natural logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::log](https://en.cppreference.com/w/cpp/numeric/math/log). +/// \param arg function argument +/// \return logarithm of \a arg to base e +/// \exception FE_INVALID for signaling NaN or negative argument +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::log(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, + (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + return half(detail::binary, + detail::log2_post( + detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, 27) + 8, + exp, + 17)); +#endif +} + +/// Common logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::log10](https://en.cppreference.com/w/cpp/numeric/math/log10). +/// \param arg function argument +/// \return logarithm of \a arg to base 10 +/// \exception FE_INVALID for signaling NaN or negative argument +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log10(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::log10(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, + (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + switch(abs) + { + case 0x4900: return half(detail::binary, 0x3C00); + case 0x5640: return half(detail::binary, 0x4000); + case 0x63D0: return half(detail::binary, 0x4200); + case 0x70E2: return half(detail::binary, 0x4400); + } + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + return half(detail::binary, + detail::log2_post( + detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, 27) + 8, + exp, + 16)); +#endif +} + +/// Binary logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::log2](https://en.cppreference.com/w/cpp/numeric/math/log2). +/// \param arg function argument +/// \return logarithm of \a arg to base 2 +/// \exception FE_INVALID for signaling NaN or negative argument +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log2(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::log2(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15, s = 0; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, + (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += (abs >> 10); + if(!(abs & 0x3FF)) + { + unsigned int value = static_cast(exp < 0) << 15, m = std::abs(exp) << 6; + for(exp = 18; m < 0x400; m <<= 1, --exp) + ; + return half(detail::binary, value + (exp << 10) + m); + } + detail::uint32 ilog = exp, sign = detail::sign_mask(ilog), + m = (((ilog << 27) + + (detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, + 28) >> + 4)) ^ + sign) - + sign; + if(!m) + return half(detail::binary, 0); + for(exp = 14; m < 0x8000000 && exp; m <<= 1, --exp) + ; + for(; m > 0xFFFFFFF; m >>= 1, ++exp) + s |= m & 1; + return half( + detail::binary, + detail::fixed2half(m, exp, sign & 0x8000, s)); +#endif +} + +/// Natural logarithm plus one. +/// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for +/// `std::round_to_nearest` +/// and in ~1% of inputs for any other rounding mode. +/// +/// **See also:** Documentation for +/// [std::log1p](https://en.cppreference.com/w/cpp/numeric/math/log1p). +/// \param arg function argument +/// \return logarithm of \a arg plus 1 to base e +/// \exception FE_INVALID for signaling NaN or argument <-1 +/// \exception FE_DIVBYZERO for -1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log1p(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::log1p(detail::half2float(arg.data_)))); +#else + if(arg.data_ >= 0xBC00) + return half(detail::binary, + (arg.data_ == 0xBC00) + ? detail::pole(0x8000) + : (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + detail::uint32 m = static_cast((abs & 0x3FF) | 0x400) << 20; + if(arg.data_ & 0x8000) + { + m = 0x40000000 - (m >> -exp); + for(exp = 0; m < 0x40000000; m <<= 1, --exp) + ; + } + else + { + if(exp < 0) + { + m = 0x40000000 + (m >> -exp); + exp = 0; + } + else + { + m += 0x40000000 >> exp; + int i = m >> 31; + m >>= i; + exp += i; + } + } + return half(detail::binary, + detail::log2_post(detail::log2(m), exp, 17)); +#endif +} + +/// \} +/// \anchor power +/// \name Power functions +/// \{ + +/// Square root. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::sqrt](https://en.cppreference.com/w/cpp/numeric/math/sqrt). +/// \param arg function argument +/// \return square root of \a arg +/// \exception FE_INVALID for signaling NaN and negative arguments +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half sqrt(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sqrt(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 15; + if(!abs || arg.data_ >= 0x7C00) + return half(detail::binary, + (abs > 0x7C00) ? detail::signal(arg.data_) + : (arg.data_ > 0x8000) ? detail::invalid() : arg.data_); + for(; abs < 0x400; abs <<= 1, --exp) + ; + detail::uint32 r = static_cast((abs & 0x3FF) | 0x400) << 10, + m = detail::sqrt<20>(r, exp += abs >> 10); + return half( + detail::binary, + detail::rounded((exp << 10) + (m & 0x3FF), r > m, r != 0)); +#endif +} + +/// Cubic root. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::cbrt](https://en.cppreference.com/w/cpp/numeric/math/cbrt). +/// \param arg function argument +/// \return cubic root of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half cbrt(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::cbrt(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs == 0x3C00 || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --exp) + ; + detail::uint32 ilog = exp + (abs >> 10), sign = detail::sign_mask(ilog), f, + m = (((ilog << 27) + + (detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, + 24) >> + 4)) ^ + sign) - + sign; + for(exp = 2; m < 0x80000000; m <<= 1, --exp) + ; + m = detail::multiply64(m, 0xAAAAAAAB); + int i = m >> 31, s; + exp += i; + m <<= 1 - i; + if(exp < 0) + { + f = m >> -exp; + exp = 0; + } + else + { + f = (m << exp) & 0x7FFFFFFF; + exp = m >> (31 - exp); + } + m = detail::exp2(f, (half::round_style == std::round_to_nearest) ? 29 : 26); + if(sign) + { + if(m > 0x80000000) + { + m = detail::divide64(0x80000000, m, s); + ++exp; + } + exp = -exp; + } + return half(detail::binary, + (half::round_style == std::round_to_nearest) + ? detail::fixed2half( + m, exp + 14, arg.data_ & 0x8000) + : detail::fixed2half( + (m + 0x80) >> 8, exp + 14, arg.data_ & 0x8000)); +#endif +} + +/// Hypotenuse function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). +/// \param x first argument +/// \param y second argument +/// \return square root of sum of squares without internal over- or underflows +/// \exception FE_INVALID if \a x or \a y is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root +inline half hypot(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_); +#if HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::hypot(fx, fy))); +#else + return half(detail::binary, + detail::float2half(std::sqrt(fx * fx + fy * fy))); +#endif +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, expx = 0, expy = 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx == 0x7C00) ? detail::select(0x7C00, y.data_) + : (absy == 0x7C00) ? detail::select(0x7C00, x.data_) + : detail::signal(x.data_, y.data_)); + if(!absx) + return half(detail::binary, absy ? detail::check_underflow(absy) : 0); + if(!absy) + return half(detail::binary, detail::check_underflow(absx)); + if(absy > absx) + std::swap(absx, absy); + for(; absx < 0x400; absx <<= 1, --expx) + ; + for(; absy < 0x400; absy <<= 1, --expy) + ; + detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + mx *= mx; + my *= my; + int ix = mx >> 21, iy = my >> 21; + expx = 2 * (expx + (absx >> 10)) - 15 + ix; + expy = 2 * (expy + (absy >> 10)) - 15 + iy; + mx <<= 10 - ix; + my <<= 10 - iy; + int d = expx - expy; + my = (d < 30) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1; + return half(detail::binary, detail::hypot_post(mx + my, expx)); +#endif +} + +/// Hypotenuse function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). +/// \param x first argument +/// \param y second argument +/// \param z third argument +/// \return square root of sum of squares without internal over- or underflows +/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root +inline half hypot(half x, half y, half z) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_), + fz = detail::half2float(z.data_); + return half(detail::binary, + detail::float2half(std::sqrt(fx * fx + fy * fy + fz * fz))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, expx = 0, + expy = 0, expz = 0; + if(!absx) + return hypot(y, z); + if(!absy) + return hypot(x, z); + if(!absz) + return hypot(x, y); + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return half(detail::binary, + (absx == 0x7C00) + ? detail::select(0x7C00, detail::select(y.data_, z.data_)) + : (absy == 0x7C00) + ? detail::select(0x7C00, detail::select(x.data_, z.data_)) + : (absz == 0x7C00) + ? detail::select(0x7C00, detail::select(x.data_, y.data_)) + : detail::signal(x.data_, y.data_, z.data_)); + if(absz > absy) + std::swap(absy, absz); + if(absy > absx) + std::swap(absx, absy); + if(absz > absy) + std::swap(absy, absz); + for(; absx < 0x400; absx <<= 1, --expx) + ; + for(; absy < 0x400; absy <<= 1, --expy) + ; + for(; absz < 0x400; absz <<= 1, --expz) + ; + detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400, + mz = (absz & 0x3FF) | 0x400; + mx *= mx; + my *= my; + mz *= mz; + int ix = mx >> 21, iy = my >> 21, iz = mz >> 21; + expx = 2 * (expx + (absx >> 10)) - 15 + ix; + expy = 2 * (expy + (absy >> 10)) - 15 + iy; + expz = 2 * (expz + (absz >> 10)) - 15 + iz; + mx <<= 10 - ix; + my <<= 10 - iy; + mz <<= 10 - iz; + int d = expy - expz; + mz = (d < 30) ? ((mz >> d) | ((mz & ((static_cast(1) << d) - 1)) != 0)) : 1; + my += mz; + if(my & 0x80000000) + { + my = (my >> 1) | (my & 1); + if(++expy > expx) + { + std::swap(mx, my); + std::swap(expx, expy); + } + } + d = expx - expy; + my = (d < 30) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1; + return half(detail::binary, detail::hypot_post(mx + my, expx)); +#endif +} + +/// Power function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in +/// ~0.00025% of inputs. +/// +/// **See also:** Documentation for [std::pow](https://en.cppreference.com/w/cpp/numeric/math/pow). +/// \param x base +/// \param y exponent +/// \return \a x raised to \a y +/// \exception FE_INVALID if \a x or \a y is signaling NaN or if \a x is finite an negative and \a y +/// is finite and not integral +/// \exception FE_DIVBYZERO if \a x is 0 and \a y is negative +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half pow(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::pow(detail::half2float(x.data_), + detail::half2float(y.data_)))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -15; + if(!absy || x.data_ == 0x3C00) + return half(detail::binary, + detail::select(0x3C00, (x.data_ == 0x3C00) ? y.data_ : x.data_)); + bool is_int = absy >= 0x6400 || (absy >= 0x3C00 && !(absy & ((1 << (25 - (absy >> 10))) - 1))); + unsigned int sign = + x.data_ & + (static_cast((absy < 0x6800) && is_int && ((absy >> (25 - (absy >> 10))) & 1)) + << 15); + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absy == 0x7C00) + ? ((absx == 0x3C00) + ? 0x3C00 + : (!absx && y.data_ == 0xFC00) + ? detail::pole() + : (0x7C00 & -((y.data_ >> 15) ^ (absx > 0x3C00)))) + : (sign | (0x7C00 & ((y.data_ >> 15) - 1U)))); + if(!absx) + return half(detail::binary, (y.data_ & 0x8000) ? detail::pole(sign) : sign); + if((x.data_ & 0x8000) && !is_int) + return half(detail::binary, detail::invalid()); + if(x.data_ == 0xBC00) + return half(detail::binary, sign | 0x3C00); + if(y.data_ == 0x3800) + return sqrt(x); + if(y.data_ == 0x3C00) + return half(detail::binary, detail::check_underflow(x.data_)); + if(y.data_ == 0x4000) + return x * x; + for(; absx < 0x400; absx <<= 1, --exp) + ; + detail::uint32 ilog = exp + (absx >> 10), msign = detail::sign_mask(ilog), f, + m = (((ilog << 27) + + ((detail::log2(static_cast((absx & 0x3FF) | 0x400) << 20) + + 8) >> + 4)) ^ + msign) - + msign; + for(exp = -11; m < 0x80000000; m <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, --exp) + ; + m = detail::multiply64(m, static_cast((absy & 0x3FF) | 0x400) << 21); + int i = m >> 31; + exp += (absy >> 10) + i; + m <<= 1 - i; + if(exp < 0) + { + f = m >> -exp; + exp = 0; + } + else + { + f = (m << exp) & 0x7FFFFFFF; + exp = m >> (31 - exp); + } + return half(detail::binary, + detail::exp2_post( + detail::exp2(f), exp, ((msign & 1) ^ (y.data_ >> 15)) != 0, sign)); +#endif +} + +/// \} +/// \anchor trigonometric +/// \name Trigonometric functions +/// \{ + +/// Compute sine and cosine simultaneously. +/// This returns the same results as sin() and cos() but is faster than calling each function +/// individually. +/// +/// This function is exact to rounding for all rounding modes. +/// \param arg function argument +/// \param sin variable to take sine of \a arg +/// \param cos variable to take cosine of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline void sincos(half arg, half* sin, half* cos) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f = detail::half2float(arg.data_); + *sin = half(detail::binary, detail::float2half(std::sin(f))); + *cos = half(detail::binary, detail::float2half(std::cos(f))); +#else + int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15, k; + if(abs >= 0x7C00) + *sin = *cos = + half(detail::binary, (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + else if(!abs) + { + *sin = arg; + *cos = half(detail::binary, 0x3C00); + } + else if(abs < 0x2500) + { + *sin = half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + } + else + { + if(half::round_style != std::round_to_nearest) + { + switch(abs) + { + case 0x48B7: + *sin = half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x1D07, 1, 1)); + *cos = half(detail::binary, detail::rounded(0xBBFF, 1, 1)); + return; + case 0x598C: + *sin = half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x80FC, 1, 1)); + return; + case 0x6A64: + *sin = half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x27FF, 1, 1)); + return; + case 0x6D8C: + *sin = half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x0FE6, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + return; + } + } + std::pair sc = + detail::sincos(detail::angle_arg(abs, k), 28); + switch(k & 3) + { + case 1: sc = std::make_pair(sc.second, -sc.first); break; + case 2: sc = std::make_pair(-sc.first, -sc.second); break; + case 3: sc = std::make_pair(-sc.second, sc.first); break; + } + *sin = half(detail::binary, + detail::fixed2half( + (sc.first ^ -static_cast(sign)) + sign)); + *cos = half(detail::binary, + detail::fixed2half(sc.second)); + } +#endif +} + +/// Sine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::sin](https://en.cppreference.com/w/cpp/numeric/math/sin). +/// \param arg function argument +/// \return sine value of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half sin(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sin(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x48B7: + return half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x1D07, 1, 1)); + case 0x6A64: + return half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); + case 0x6D8C: + return half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x0FE6, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k >> 1) & 1) ^ (arg.data_ >> 15)); + return half(detail::binary, + detail::fixed2half( + (((k & 1) ? sc.second : sc.first) ^ sign) - sign)); +#endif +} + +/// Cosine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::cos](https://en.cppreference.com/w/cpp/numeric/math/cos). +/// \param arg function argument +/// \return cosine value of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half cos(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::cos(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2500) + return half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x598C) + return half(detail::binary, detail::rounded(0x80FC, 1, 1)); + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k >> 1) ^ k) & 1); + return half(detail::binary, + detail::fixed2half( + (((k & 1) ? sc.first : sc.second) ^ sign) - sign)); +#endif +} + +/// Tangent function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::tan](https://en.cppreference.com/w/cpp/numeric/math/tan). +/// \param arg function argument +/// \return tangent value of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half tan(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::tan(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 13, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x658C: + return half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x07E6, 1, 1)); + case 0x7330: + return half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x4B62, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 30); + if(k & 1) + sc = std::make_pair(-sc.second, sc.first); + detail::uint32 signy = detail::sign_mask(sc.first), signx = detail::sign_mask(sc.second); + detail::uint32 my = (sc.first ^ signy) - signy, mx = (sc.second ^ signx) - signx; + for(; my < 0x80000000; my <<= 1, --exp) + ; + for(; mx < 0x80000000; mx <<= 1, ++exp) + ; + return half( + detail::binary, + detail::tangent_post(my, mx, exp, (signy ^ signx ^ arg.data_) & 0x8000)); +#endif +} + +/// Arc sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::asin](https://en.cppreference.com/w/cpp/numeric/math/asin). +/// \param arg function argument +/// \return arc sine value of \a arg +/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half asin(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::asin(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, + (abs > 0x7C00) + ? detail::signal(arg.data_) + : (abs > 0x3C00) + ? detail::invalid() + : detail::rounded(sign | 0x3E48, 0, 1)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest && (abs == 0x2B44 || abs == 0x2DC3)) + return half(detail::binary, detail::rounded(arg.data_ + 1, 1, 1)); + std::pair sc = detail::atan2_args(abs); + detail::uint32 m = + detail::atan2(sc.first, sc.second, (half::round_style == std::round_to_nearest) ? 27 : 26); + return half(detail::binary, + detail::fixed2half(m, 14, sign)); +#endif +} + +/// Arc cosine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::acos](https://en.cppreference.com/w/cpp/numeric/math/acos). +/// \param arg function argument +/// \return arc cosine value of \a arg +/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half acos(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::acos(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15; + if(!abs) + return half(detail::binary, detail::rounded(0x3E48, 0, 1)); + if(abs >= 0x3C00) + return half(detail::binary, + (abs > 0x7C00) + ? detail::signal(arg.data_) + : (abs > 0x3C00) + ? detail::invalid() + : sign ? detail::rounded(0x4248, 0, 1) : 0); + std::pair cs = detail::atan2_args(abs); + detail::uint32 m = detail::atan2(cs.second, cs.first, 28); + return half(detail::binary, + detail::fixed2half( + sign ? (0xC90FDAA2 - m) : m, 15, 0, sign)); +#endif +} + +/// Arc tangent function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::atan](https://en.cppreference.com/w/cpp/numeric/math/atan). +/// \param arg function argument +/// \return arc tangent value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half atan(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::atan(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::rounded(sign | 0x3E48, 0, 1) + : detail::signal(arg.data_)); + if(abs <= 0x2700) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + int exp = (abs >> 10) + (abs <= 0x3FF); + detail::uint32 my = (abs & 0x3FF) | ((abs > 0x3FF) << 10); + detail::uint32 m = (exp > 15) + ? detail::atan2(my << 19, + 0x20000000 >> (exp - 15), + (half::round_style == std::round_to_nearest) ? 26 : 24) + : detail::atan2(my << (exp + 4), + 0x20000000, + (half::round_style == std::round_to_nearest) ? 30 : 28); + return half(detail::binary, + detail::fixed2half(m, 14, sign)); +#endif +} + +/// Arc tangent function. +/// This function may be 1 ULP off the correctly rounded exact result in ~0.005% of inputs for +/// `std::round_to_nearest`, +/// in ~0.1% of inputs for `std::round_toward_zero` and in ~0.02% of inputs for any other rounding +/// mode. +/// +/// **See also:** Documentation for +/// [std::atan2](https://en.cppreference.com/w/cpp/numeric/math/atan2). +/// \param y numerator +/// \param x denominator +/// \return arc tangent value +/// \exception FE_INVALID if \a x or \a y is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half atan2(half y, half x) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::atan2(detail::half2float(y.data_), + detail::half2float(x.data_)))); +#else + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, signx = x.data_ >> 15, + signy = y.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + { + if(absx > 0x7C00 || absy > 0x7C00) + return half(detail::binary, detail::signal(x.data_, y.data_)); + if(absy == 0x7C00) + return half(detail::binary, + (absx < 0x7C00) + ? detail::rounded(signy | 0x3E48, 0, 1) + : signx + ? detail::rounded(signy | 0x40B6, 0, 1) + : detail::rounded(signy | 0x3A48, 0, 1)); + return (x.data_ == 0x7C00) + ? half(detail::binary, signy) + : half(detail::binary, + detail::rounded(signy | 0x4248, 0, 1)); + } + if(!absy) + return signx ? half(detail::binary, + detail::rounded(signy | 0x4248, 0, 1)) + : y; + if(!absx) + return half(detail::binary, detail::rounded(signy | 0x3E48, 0, 1)); + int d = (absy >> 10) + (absy <= 0x3FF) - (absx >> 10) - (absx <= 0x3FF); + if(d > (signx ? 18 : 12)) + return half(detail::binary, detail::rounded(signy | 0x3E48, 0, 1)); + if(signx && d < -11) + return half(detail::binary, detail::rounded(signy | 0x4248, 0, 1)); + if(!signx && d < ((half::round_style == std::round_toward_zero) ? -15 : -9)) + { + for(; absy < 0x400; absy <<= 1, --d) + ; + detail::uint32 mx = ((absx << 1) & 0x7FF) | 0x800, my = ((absy << 1) & 0x7FF) | 0x800; + int i = my < mx; + d -= i; + if(d < -25) + return half(detail::binary, detail::underflow(signy)); + my <<= 11 + i; + return half(detail::binary, + detail::fixed2half( + my / mx, d + 14, signy, my % mx != 0)); + } + detail::uint32 m = detail::atan2( + ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << (19 + ((d < 0) ? d : (d > 0) ? 0 : -1)), + ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << (19 - ((d > 0) ? d : (d < 0) ? 0 : 1))); + return half(detail::binary, + detail::fixed2half( + signx ? (0xC90FDAA2 - m) : m, 15, signy, signx)); +#endif +} + +/// \} +/// \anchor hyperbolic +/// \name Hyperbolic functions +/// \{ + +/// Hyperbolic sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::sinh](https://en.cppreference.com/w/cpp/numeric/math/sinh). +/// \param arg function argument +/// \return hyperbolic sine value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half sinh(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sinh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + std::pair mm = + detail::hyperbolic_args(abs, exp, (half::round_style == std::round_to_nearest) ? 29 : 27); + detail::uint32 m = mm.first - mm.second; + for(exp += 13; m < 0x80000000 && exp; m <<= 1, --exp) + ; + unsigned int sign = arg.data_ & 0x8000; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + return half(detail::binary, + detail::fixed2half(m, exp, sign)); +#endif +} + +/// Hyperbolic cosine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::cosh](https://en.cppreference.com/w/cpp/numeric/math/cosh). +/// \param arg function argument +/// \return hyperbolic cosine value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half cosh(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::cosh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs > 0x7C00) ? detail::signal(arg.data_) : 0x7C00); + std::pair mm = + detail::hyperbolic_args(abs, exp, (half::round_style == std::round_to_nearest) ? 23 : 26); + detail::uint32 m = mm.first + mm.second, i = (~m & 0xFFFFFFFF) >> 31; + m = (m >> i) | (m & i) | 0x80000000; + if((exp += 13 + i) > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::fixed2half(m, exp)); +#endif +} + +/// Hyperbolic tangent. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::tanh](https://en.cppreference.com/w/cpp/numeric/math/tanh). +/// \param arg function argument +/// \return hyperbolic tangent value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half tanh(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::tanh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs > 0x7C00) ? detail::signal(arg.data_) : (arg.data_ - 0x4000)); + if(abs >= 0x4500) + return half(detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x2D3F) + return half(detail::binary, detail::rounded(arg.data_ - 3, 0, 1)); + std::pair mm = detail::hyperbolic_args(abs, exp, 27); + detail::uint32 my = mm.first - mm.second - (half::round_style != std::round_to_nearest), + mx = mm.first + mm.second, i = (~mx & 0xFFFFFFFF) >> 31; + for(exp = 13; my < 0x80000000; my <<= 1, --exp) + ; + mx = (mx >> i) | 0x80000000; + return half(detail::binary, + detail::tangent_post(my, mx, exp - i, arg.data_ & 0x8000)); +#endif +} + +/// Hyperbolic area sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::asinh](https://en.cppreference.com/w/cpp/numeric/math/asinh). +/// \param arg function argument +/// \return area sine value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half asinh(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::asinh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x32D4: + return half(detail::binary, + detail::rounded(arg.data_ - 13, 1, 1)); + case 0x3B5B: + return half(detail::binary, + detail::rounded(arg.data_ - 197, 1, 1)); + } + return half(detail::binary, detail::area(arg.data_)); +#endif +} + +/// Hyperbolic area cosine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::acosh](https://en.cppreference.com/w/cpp/numeric/math/acosh). +/// \param arg function argument +/// \return area cosine value of \a arg +/// \exception FE_INVALID for signaling NaN or arguments <1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half acosh(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::acosh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if((arg.data_ & 0x8000) || abs < 0x3C00) + return half(detail::binary, + (abs <= 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + if(arg.data_ >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + return half(detail::binary, detail::area(arg.data_)); +#endif +} + +/// Hyperbolic area tangent. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::atanh](https://en.cppreference.com/w/cpp/numeric/math/atanh). +/// \param arg function argument +/// \return area tangent value of \a arg +/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 +/// \exception FE_DIVBYZERO for +/-1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half atanh(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::atanh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 0; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, + (abs == 0x3C00) + ? detail::pole(arg.data_ & 0x8000) + : (abs <= 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + detail::uint32 m = static_cast((abs & 0x3FF) | ((abs > 0x3FF) << 10)) + << ((abs >> 10) + (abs <= 0x3FF) + 6), + my = 0x80000000 + m, mx = 0x80000000 - m; + for(; mx < 0x80000000; mx <<= 1, ++exp) + ; + int i = my >= mx, s; + return half(detail::binary, + detail::log2_post( + detail::log2((detail::divide64(my >> i, mx, s) + 1) >> 1, 27) + 0x10, + exp + i - 1, + 16, + arg.data_ & 0x8000)); +#endif +} + +/// \} +/// \anchor special +/// \name Error and gamma functions +/// \{ + +/// Error function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% +/// of inputs. +/// +/// **See also:** Documentation for [std::erf](https://en.cppreference.com/w/cpp/numeric/math/erf). +/// \param arg function argument +/// \return error function value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half erf(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::erf(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs >= 0x7C00) + ? half(detail::binary, + (abs == 0x7C00) ? (arg.data_ - 0x4000) : detail::signal(arg.data_)) + : arg; + if(abs >= 0x4200) + return half(detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + return half(detail::binary, detail::erf(arg.data_)); +#endif +} + +/// Complementary error function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% +/// of inputs. +/// +/// **See also:** Documentation for +/// [std::erfc](https://en.cppreference.com/w/cpp/numeric/math/erfc). +/// \param arg function argument +/// \return 1 minus error function value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half erfc(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::erfc(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(abs >= 0x7C00) + return (abs >= 0x7C00) + ? half(detail::binary, (abs == 0x7C00) ? (sign >> 1) : detail::signal(arg.data_)) + : arg; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x4400) + return half( + detail::binary, + detail::rounded((sign >> 1) - (sign >> 15), sign >> 15, 1)); + return half(detail::binary, detail::erf(arg.data_)); +#endif +} + +/// Natural logarithm of gamma function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in +/// ~0.025% of inputs. +/// +/// **See also:** Documentation for +/// [std::lgamma](https://en.cppreference.com/w/cpp/numeric/math/lgamma). +/// \param arg function argument +/// \return natural logarith of gamma function for \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_DIVBYZERO for 0 or negative integer arguments +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half lgamma(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::lgamma(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + if(!abs || arg.data_ >= 0xE400 || + (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) + return half(detail::binary, detail::pole()); + if(arg.data_ == 0x3C00 || arg.data_ == 0x4000) + return half(detail::binary, 0); + return half(detail::binary, detail::gamma(arg.data_)); +#endif +} + +/// Gamma function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in +/// <0.25% of inputs. +/// +/// **See also:** Documentation for +/// [std::tgamma](https://en.cppreference.com/w/cpp/numeric/math/tgamma). +/// \param arg function argument +/// \return gamma function value of \a arg +/// \exception FE_INVALID for signaling NaN, negative infinity or negative integer arguments +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half tgamma(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::tgamma(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, detail::pole(arg.data_)); + if(abs >= 0x7C00) + return (arg.data_ == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(arg.data_ >= 0xE400 || (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) + return half(detail::binary, detail::invalid()); + if(arg.data_ >= 0xCA80) + return half( + detail::binary, + detail::underflow((1 - ((abs >> (25 - (abs >> 10))) & 1)) << 15)); + if(arg.data_ <= 0x100 || (arg.data_ >= 0x4900 && arg.data_ < 0x8000)) + return half(detail::binary, detail::overflow()); + if(arg.data_ == 0x3C00) + return arg; + return half(detail::binary, detail::gamma(arg.data_)); +#endif +} + +/// \} +/// \anchor rounding +/// \name Rounding +/// \{ + +/// Nearest integer not less than half value. +/// **See also:** Documentation for +/// [std::ceil](https://en.cppreference.com/w/cpp/numeric/math/ceil). +/// \param arg half to round +/// \return nearest integer not less than \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half ceil(half arg) +{ + return half(detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer not greater than half value. +/// **See also:** Documentation for +/// [std::floor](https://en.cppreference.com/w/cpp/numeric/math/floor). +/// \param arg half to round +/// \return nearest integer not greater than \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half floor(half arg) +{ + return half(detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer not greater in magnitude than half value. +/// **See also:** Documentation for +/// [std::trunc](https://en.cppreference.com/w/cpp/numeric/math/trunc). +/// \param arg half to round +/// \return nearest integer not greater in magnitude than \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half trunc(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} + +/// Nearest integer. +/// **See also:** Documentation for +/// [std::round](https://en.cppreference.com/w/cpp/numeric/math/round). +/// \param arg half to round +/// \return nearest integer, rounded away from zero in half-way cases +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half round(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} + +/// Nearest integer. +/// **See also:** Documentation for +/// [std::lround](https://en.cppreference.com/w/cpp/numeric/math/round). +/// \param arg half to round +/// \return nearest integer, rounded away from zero in half-way cases +/// \exception FE_INVALID if value is not representable as `long` +inline long lround(half arg) +{ + return detail::half2int(arg.data_); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::rint](https://en.cppreference.com/w/cpp/numeric/math/rint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half rint(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::lrint](https://en.cppreference.com/w/cpp/numeric/math/rint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID if value is not representable as `long` +/// \exception FE_INEXACT if value had to be rounded +inline long lrint(half arg) +{ + return detail::half2int(arg.data_); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::nearbyint](https://en.cppreference.com/w/cpp/numeric/math/nearbyint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID for signaling NaN +inline half nearbyint(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} +#if HALF_ENABLE_CPP11_LONG_LONG +/// Nearest integer. +/// **See also:** Documentation for +/// [std::llround](https://en.cppreference.com/w/cpp/numeric/math/round). +/// \param arg half to round +/// \return nearest integer, rounded away from zero in half-way cases +/// \exception FE_INVALID if value is not representable as `long long` +inline long long llround(half arg) +{ + return detail::half2int(arg.data_); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::llrint](https://en.cppreference.com/w/cpp/numeric/math/rint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID if value is not representable as `long long` +/// \exception FE_INEXACT if value had to be rounded +inline long long llrint(half arg) +{ + return detail::half2int(arg.data_); +} +#endif + +/// \} +/// \anchor float +/// \name Floating point manipulation +/// \{ + +/// Decompress floating-point number. +/// **See also:** Documentation for +/// [std::frexp](https://en.cppreference.com/w/cpp/numeric/math/frexp). +/// \param arg number to decompress +/// \param exp address to store exponent at +/// \return significant in range [0.5, 1) +/// \exception FE_INVALID for signaling NaN +inline half frexp(half arg, int* exp) +{ + *exp = 0; + unsigned int abs = arg.data_ & 0x7FFF; + if(abs >= 0x7C00 || !abs) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --*exp) + ; + *exp += (abs >> 10) - 14; + return half(detail::binary, (arg.data_ & 0x8000) | 0x3800 | (abs & 0x3FF)); +} + +/// Multiply by power of two. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn). +/// \param arg number to modify +/// \param exp power of two to multiply with +/// \return \a arg multplied by 2 raised to \a exp +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half scalbln(half arg, long exp) +{ + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(abs >= 0x7C00 || !abs) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + if(exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow(sign)); + else if(exp > 0) + return half(detail::binary, sign | (exp << 10) | (abs & 0x3FF)); + unsigned int m = (abs & 0x3FF) | 0x400; + return half(detail::binary, + detail::rounded( + sign | (m >> (1 - exp)), (m >> -exp) & 1, (m & ((1 << -exp) - 1)) != 0)); +} + +/// Multiply by power of two. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn). +/// \param arg number to modify +/// \param exp power of two to multiply with +/// \return \a arg multplied by 2 raised to \a exp +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half scalbn(half arg, int exp) { return scalbln(arg, exp); } + +/// Multiply by power of two. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp). +/// \param arg number to modify +/// \param exp power of two to multiply with +/// \return \a arg multplied by 2 raised to \a exp +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half ldexp(half arg, int exp) { return scalbln(arg, exp); } + +/// Extract integer and fractional parts. +/// **See also:** Documentation for +/// [std::modf](https://en.cppreference.com/w/cpp/numeric/math/modf). +/// \param arg number to decompress +/// \param iptr address to store integer part at +/// \return fractional part +/// \exception FE_INVALID for signaling NaN +inline half modf(half arg, half* iptr) +{ + unsigned int abs = arg.data_ & 0x7FFF; + if(abs > 0x7C00) + { + arg = half(detail::binary, detail::signal(arg.data_)); + return *iptr = arg, arg; + } + if(abs >= 0x6400) + return *iptr = arg, half(detail::binary, arg.data_ & 0x8000); + if(abs < 0x3C00) + return iptr->data_ = arg.data_ & 0x8000, arg; + unsigned int exp = abs >> 10, mask = (1 << (25 - exp)) - 1, m = arg.data_ & mask; + iptr->data_ = arg.data_ & ~mask; + if(!m) + return half(detail::binary, arg.data_ & 0x8000); + for(; m < 0x400; m <<= 1, --exp) + ; + return half(detail::binary, (arg.data_ & 0x8000) | (exp << 10) | (m & 0x3FF)); +} + +/// Extract exponent. +/// **See also:** Documentation for +/// [std::ilogb](https://en.cppreference.com/w/cpp/numeric/math/ilogb). +/// \param arg number to query +/// \return floating-point exponent +/// \retval FP_ILOGB0 for zero +/// \retval FP_ILOGBNAN for NaN +/// \retval INT_MAX for infinity +/// \exception FE_INVALID for 0 or infinite values +inline int ilogb(half arg) +{ + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) + { + detail::raise(FE_INVALID); + return !abs ? FP_ILOGB0 : (abs == 0x7C00) ? INT_MAX : FP_ILOGBNAN; + } + for(exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) + ; + return exp; +} + +/// Extract exponent. +/// **See also:** Documentation for +/// [std::logb](https://en.cppreference.com/w/cpp/numeric/math/logb). +/// \param arg number to query +/// \return floating-point exponent +/// \exception FE_INVALID for signaling NaN +/// \exception FE_DIVBYZERO for 0 +inline half logb(half arg) +{ + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + for(exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) + ; + unsigned int value = static_cast(exp < 0) << 15; + if(exp) + { + unsigned int m = std::abs(exp) << 6; + for(exp = 18; m < 0x400; m <<= 1, --exp) + ; + value |= (exp << 10) + m; + } + return half(detail::binary, value); +} + +/// Next representable value. +/// **See also:** Documentation for +/// [std::nextafter](https://en.cppreference.com/w/cpp/numeric/math/nextafter). +/// \param from value to compute next representable value for +/// \param to direction towards which to compute next value +/// \return next representable value after \a from in direction towards \a to +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW for infinite result from finite argument +/// \exception FE_UNDERFLOW for subnormal result +inline half nextafter(half from, half to) +{ + int fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; + if(fabs > 0x7C00 || tabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_, to.data_)); + if(from.data_ == to.data_ || !(fabs | tabs)) + return to; + if(!fabs) + { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (to.data_ & 0x8000) + 1); + } + unsigned int out = + from.data_ + + (((from.data_ >> 15) ^ + static_cast((from.data_ ^ (0x8000 | (0x8000 - (from.data_ >> 15)))) < + (to.data_ ^ (0x8000 | (0x8000 - (to.data_ >> 15)))))) + << 1) - + 1; + detail::raise(FE_OVERFLOW, fabs < 0x7C00 && (out & 0x7C00) == 0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out & 0x7C00) < 0x400); + return half(detail::binary, out); +} + +/// Next representable value. +/// **See also:** Documentation for +/// [std::nexttoward](https://en.cppreference.com/w/cpp/numeric/math/nexttoward). +/// \param from value to compute next representable value for +/// \param to direction towards which to compute next value +/// \return next representable value after \a from in direction towards \a to +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW for infinite result from finite argument +/// \exception FE_UNDERFLOW for subnormal result +inline half nexttoward(half from, long double to) +{ + int fabs = from.data_ & 0x7FFF; + if(fabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_)); + long double lfrom = static_cast(from); + if(detail::builtin_isnan(to) || lfrom == to) + return half(static_cast(to)); + if(!fabs) + { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (static_cast(detail::builtin_signbit(to)) << 15) + 1); + } + unsigned int out = + from.data_ + (((from.data_ >> 15) ^ static_cast(lfrom < to)) << 1) - 1; + detail::raise(FE_OVERFLOW, (out & 0x7FFF) == 0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out & 0x7FFF) < 0x400); + return half(detail::binary, out); +} + +/// Take sign. +/// **See also:** Documentation for +/// [std::copysign](https://en.cppreference.com/w/cpp/numeric/math/copysign). +/// \param x value to change sign for +/// \param y value to take sign from +/// \return value equal to \a x in magnitude and to \a y in sign +inline HALF_CONSTEXPR half copysign(half x, half y) +{ + return half(detail::binary, x.data_ ^ ((x.data_ ^ y.data_) & 0x8000)); +} + +/// \} +/// \anchor classification +/// \name Floating point classification +/// \{ + +/// Classify floating-point value. +/// **See also:** Documentation for +/// [std::fpclassify](https://en.cppreference.com/w/cpp/numeric/math/fpclassify). +/// \param arg number to classify +/// \retval FP_ZERO for positive and negative zero +/// \retval FP_SUBNORMAL for subnormal numbers +/// \retval FP_INFINITY for positive and negative infinity +/// \retval FP_NAN for NaNs +/// \retval FP_NORMAL for all other (normal) values +inline HALF_CONSTEXPR int fpclassify(half arg) +{ + return !(arg.data_ & 0x7FFF) + ? FP_ZERO + : ((arg.data_ & 0x7FFF) < 0x400) + ? FP_SUBNORMAL + : ((arg.data_ & 0x7FFF) < 0x7C00) + ? FP_NORMAL + : ((arg.data_ & 0x7FFF) == 0x7C00) ? FP_INFINITE : FP_NAN; +} + +/// Check if finite number. +/// **See also:** Documentation for +/// [std::isfinite](https://en.cppreference.com/w/cpp/numeric/math/isfinite). +/// \param arg number to check +/// \retval true if neither infinity nor NaN +/// \retval false else +inline HALF_CONSTEXPR bool isfinite(half arg) { return (arg.data_ & 0x7C00) != 0x7C00; } + +/// Check for infinity. +/// **See also:** Documentation for +/// [std::isinf](https://en.cppreference.com/w/cpp/numeric/math/isinf). +/// \param arg number to check +/// \retval true for positive or negative infinity +/// \retval false else +inline HALF_CONSTEXPR bool isinf(half arg) { return (arg.data_ & 0x7FFF) == 0x7C00; } + +/// Check for NaN. +/// **See also:** Documentation for +/// [std::isnan](https://en.cppreference.com/w/cpp/numeric/math/isnan). +/// \param arg number to check +/// \retval true for NaNs +/// \retval false else +inline HALF_CONSTEXPR bool isnan(half arg) { return (arg.data_ & 0x7FFF) > 0x7C00; } + +/// Check if normal number. +/// **See also:** Documentation for +/// [std::isnormal](https://en.cppreference.com/w/cpp/numeric/math/isnormal). +/// \param arg number to check +/// \retval true if normal number +/// \retval false if either subnormal, zero, infinity or NaN +inline HALF_CONSTEXPR bool isnormal(half arg) +{ + return ((arg.data_ & 0x7C00) != 0) & ((arg.data_ & 0x7C00) != 0x7C00); +} + +/// Check sign. +/// **See also:** Documentation for +/// [std::signbit](https://en.cppreference.com/w/cpp/numeric/math/signbit). +/// \param arg number to check +/// \retval true for negative number +/// \retval false for positive number +inline HALF_CONSTEXPR bool signbit(half arg) { return (arg.data_ & 0x8000) != 0; } + +/// \} +/// \anchor compfunc +/// \name Comparison +/// \{ + +/// Quiet comparison for greater than. +/// **See also:** Documentation for +/// [std::isgreater](https://en.cppreference.com/w/cpp/numeric/math/isgreater). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater than \a y +/// \retval false else +inline HALF_CONSTEXPR bool isgreater(half x, half y) +{ + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) > + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for greater equal. +/// **See also:** Documentation for +/// [std::isgreaterequal](https://en.cppreference.com/w/cpp/numeric/math/isgreaterequal). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater equal \a y +/// \retval false else +inline HALF_CONSTEXPR bool isgreaterequal(half x, half y) +{ + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for less than. +/// **See also:** Documentation for +/// [std::isless](https://en.cppreference.com/w/cpp/numeric/math/isless). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less than \a y +/// \retval false else +inline HALF_CONSTEXPR bool isless(half x, half y) +{ + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) < + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for less equal. +/// **See also:** Documentation for +/// [std::islessequal](https://en.cppreference.com/w/cpp/numeric/math/islessequal). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less equal \a y +/// \retval false else +inline HALF_CONSTEXPR bool islessequal(half x, half y) +{ + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comarison for less or greater. +/// **See also:** Documentation for +/// [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater). +/// \param x first operand +/// \param y second operand +/// \retval true if either less or greater +/// \retval false else +inline HALF_CONSTEXPR bool islessgreater(half x, half y) +{ + return x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF) && !isnan(x) && !isnan(y); +} + +/// Quiet check if unordered. +/// **See also:** Documentation for +/// [std::isunordered](https://en.cppreference.com/w/cpp/numeric/math/isunordered). +/// \param x first operand +/// \param y second operand +/// \retval true if unordered (one or two NaN operands) +/// \retval false else +inline HALF_CONSTEXPR bool isunordered(half x, half y) { return isnan(x) || isnan(y); } + +/// \} +/// \anchor casting +/// \name Casting +/// \{ + +/// Cast to or from half-precision floating-point number. +/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values +/// are converted +/// directly using the default rounding mode, without any roundtrip over `float` that a +/// `static_cast` would otherwise do. +/// +/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any +/// of the two types +/// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) +/// results in a compiler +/// error and casting between [half](\ref half_float::half)s returns the argument unmodified. +/// \tparam T destination type (half or built-in arithmetic type) +/// \tparam U source type (half or built-in arithmetic type) +/// \param arg value to cast +/// \return \a arg converted to destination type +/// \exception FE_INVALID if \a T is integer type and result is not representable as \a T +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +template +T half_cast(U arg) +{ + return detail::half_caster::cast(arg); +} + +/// Cast to or from half-precision floating-point number. +/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values +/// are converted +/// directly using the specified rounding mode, without any roundtrip over `float` that a +/// `static_cast` would otherwise do. +/// +/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any +/// of the two types +/// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) +/// results in a compiler +/// error and casting between [half](\ref half_float::half)s returns the argument unmodified. +/// \tparam T destination type (half or built-in arithmetic type) +/// \tparam R rounding mode to use. +/// \tparam U source type (half or built-in arithmetic type) +/// \param arg value to cast +/// \return \a arg converted to destination type +/// \exception FE_INVALID if \a T is integer type and result is not representable as \a T +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +template +T half_cast(U arg) +{ + return detail::half_caster::cast(arg); +} +/// \} + +/// \} +/// \anchor errors +/// \name Error handling +/// \{ + +/// Clear exception flags. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::feclearexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feclearexcept). +/// \param excepts OR of exceptions to clear +/// \retval 0 all selected flags cleared successfully +inline int feclearexcept(int excepts) +{ + detail::errflags() &= ~excepts; + return 0; +} + +/// Test exception flags. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::fetestexcept](https://en.cppreference.com/w/cpp/numeric/fenv/fetestexcept). +/// \param excepts OR of exceptions to test +/// \return OR of selected exceptions if raised +inline int fetestexcept(int excepts) { return detail::errflags() & excepts; } + +/// Raise exception flags. +/// This raises the specified floating point exceptions and also invokes any additional automatic +/// exception handling as +/// configured with the [HALF_ERRHANDLIG_...](\ref HALF_ERRHANDLING_ERRNO) preprocessor symbols. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::feraiseexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feraiseexcept). +/// \param excepts OR of exceptions to raise +/// \retval 0 all selected exceptions raised successfully +inline int feraiseexcept(int excepts) +{ + detail::errflags() |= excepts; + detail::raise(excepts); + return 0; +} + +/// Save exception flags. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). +/// \param flagp adress to store flag state at +/// \param excepts OR of flags to save +/// \retval 0 for success +inline int fegetexceptflag(int* flagp, int excepts) +{ + *flagp = detail::errflags() & excepts; + return 0; +} + +/// Restore exception flags. +/// This only copies the specified exception state (including unset flags) without incurring any +/// additional exception handling. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). +/// \param flagp adress to take flag state from +/// \param excepts OR of flags to restore +/// \retval 0 for success +inline int fesetexceptflag(const int* flagp, int excepts) +{ + detail::errflags() = (detail::errflags() | (*flagp & excepts)) & (*flagp | ~excepts); + return 0; +} + +/// Throw C++ exceptions based on set exception flags. +/// This function manually throws a corresponding C++ exception if one of the specified flags is +/// set, +/// no matter if automatic throwing (via [HALF_ERRHANDLING_THROW_...](\ref +/// HALF_ERRHANDLING_THROW_INVALID)) is enabled or not. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// \param excepts OR of exceptions to test +/// \param msg error message to use for exception description +/// \throw std::domain_error if `FE_INVALID` or `FE_DIVBYZERO` is selected and set +/// \throw std::overflow_error if `FE_OVERFLOW` is selected and set +/// \throw std::underflow_error if `FE_UNDERFLOW` is selected and set +/// \throw std::range_error if `FE_INEXACT` is selected and set +inline void fethrowexcept(int excepts, const char* msg = "") +{ + excepts &= detail::errflags(); + if(excepts & (FE_INVALID | FE_DIVBYZERO)) + throw std::domain_error(msg); + if(excepts & FE_OVERFLOW) + throw std::overflow_error(msg); + if(excepts & FE_UNDERFLOW) + throw std::underflow_error(msg); + if(excepts & FE_INEXACT) + throw std::range_error(msg); +} +/// \} +} // namespace half_float + +#undef HALF_UNUSED_NOERR +#undef HALF_CONSTEXPR +#undef HALF_CONSTEXPR_CONST +#undef HALF_CONSTEXPR_NOERR +#undef HALF_NOEXCEPT +#undef HALF_NOTHROW +#undef HALF_THREAD_LOCAL +#undef HALF_TWOS_COMPLEMENT_INT +#ifdef HALF_POP_WARNINGS +#pragma warning(pop) +#undef HALF_POP_WARNINGS +#endif + +#endif diff --git a/external/rocm/include/bfloat16_dev.hpp b/external/rocm/include/bfloat16_dev.hpp deleted file mode 100644 index 52d00346cfc..00000000000 --- a/external/rocm/include/bfloat16_dev.hpp +++ /dev/null @@ -1,125 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2019 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef BFLOAT16_DEVICE_HPP -#define BFLOAT16_DEVICE_HPP - -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef __HIP_PLATFORM_HCC__ -#define EXECUTION_SPECIFIER __device__ -#else -#define EXECUTION_SPECIFIER -#endif // MIOPEN_BACKEND_HIP - -typedef union -{ - uint u32; - ushort2 ushortx2; - -// Composable kernels are written in HIP language. The language doesnt support -// ushort2.hi or ushort2.low. -#ifdef __HIP_PLATFORM_HCC__ - ushort ushortvec[2]; -#endif // MIOPEN_BACKEND_HIP - float f32; -} cvt_bf16_fp32_t; - -EXECUTION_SPECIFIER float bfloat16_to_float(ushort src_val) -{ - cvt_bf16_fp32_t target_val; - -#ifdef __HIP_PLATFORM_HCC__ - target_val.ushortx2 = make_ushort2(0, src_val); -#else - target_val.ushortx2 = (ushort2)(0, src_val); -#endif - - return target_val.f32; -} - -EXECUTION_SPECIFIER ushort float_to_bfloat16(float src_val) -{ - cvt_bf16_fp32_t target_val; - target_val.f32 = src_val; - // BF16 round and NaN preservation code matches - // https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/include/rocblas_bfloat16.h - if((~target_val.u32 & 0x7f800000) == 0) // Inf or NaN - { - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bloat16's mantissa bits are all 0. - if((target_val.u32 & 0xffff) != 0) - { - target_val.u32 |= 0x10000; // Preserve signaling NaN - } - } - else - { -#ifdef MIOPEN_USE_RNE_BFLOAT16 -// When the exponent bits are not all 1s, then the value is zero, normal, -// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus -// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). -// This causes the bfloat16's mantissa to be incremented by 1 if the 16 -// least significant bits of the float mantissa are greater than 0x8000, -// or if they are equal to 0x8000 and the least significant bit of the -// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when -// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already -// has the value 0x7f, then incrementing it causes it to become 0x00 and -// the exponent is incremented by one, which is the next higher FP value -// to the unrounded bfloat16 value. When the bfloat16 value is subnormal -// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up -// to a normal value with an exponent of 0x01 and a mantissa of 0x00. -// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, -// incrementing it causes it to become an exponent of 0xFF and a mantissa -// of 0x00, which is Inf, the next higher value to the unrounded value. -#ifdef __HIP_PLATFORM_HCC__ - target_val.u32 += (0x7fff + (target_val.ushortvec[1] & 1)); -#else - target_val.u32 += - (0x7fff + (target_val.ushortx2.hi & 1)); // Round to nearest, round to even -#endif // MIOPEN_BACKEND_HIP -#endif // MIOPEN_USE_RNE_BFLOAT16 - } - -#ifdef __HIP_PLATFORM_HCC__ - return target_val.ushortvec[1]; -#else - return target_val.ushortx2.hi; -#endif // MIOPEN_BACKEND_HIP -} - -#ifdef __cplusplus -} -#endif - -#endif // BFLOAT16_DEVICE_HPP diff --git a/host/CMakeLists.txt b/host/CMakeLists.txt deleted file mode 100644 index 30cc14d8caf..00000000000 --- a/host/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(host_tensor) -add_subdirectory(driver_offline) diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index b5e5f91d593..00000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,190 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" -#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp" - -template -void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t /* nrepeat */) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = out_n_k_ho_wo_lengths[I0]; - const auto K = out_n_k_ho_wo_lengths[I1]; - const auto C = wei_k_c_y_x_lengths[I1]; - - const auto Hi = in_n_c_hi_wi_lengths[I2]; - const auto Wi = in_n_c_hi_wi_lengths[I3]; - - const auto Ho = out_n_k_ho_wo_lengths[I2]; - const auto Wo = out_n_k_ho_wo_lengths[I3]; - - const auto Y = wei_k_c_y_x_lengths[I2]; - const auto X = wei_k_c_y_x_lengths[I3]; - - const auto C0 = C / Number{}; - const auto C1 = Number{}; - - const auto K0 = K / Number{}; - const auto K1 = Number{}; - - Tensor in_n_c0_hi_wi_c1( - HostTensorDescriptor(std::initializer_list{N, C0, Hi, Wi, C1})); - Tensor wei_k_c0_y_x_c1( - HostTensorDescriptor(std::initializer_list{K, C0, Y, X, C1})); - Tensor out_n_k0_ho_wo_k1( - HostTensorDescriptor(std::initializer_list{N, K0, Ho, Wo, K1})); - - auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) { - in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) = - in_n_c_hi_wi(n, c, hi, wi); - }; - - auto f_kcyx2kc0yxc1 = [&](auto k, auto y, auto x, auto c) { - wei_k_c0_y_x_c1(k, c / InWeiVectorSize, y, x, c % InWeiVectorSize) = - wei_k_c_y_x(k, c, y, x); - }; - - make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)(); - make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)(); - - DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) * - in_n_c0_hi_wi_c1.mDesc.GetElementSpace()); - DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace()); - DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) * - out_n_k0_ho_wo_k1.mDesc.GetElementSpace()); - - in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); - wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); - - const auto in_n_c0_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi)); - const auto wei_k_c0_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X)); - const auto out_n_k0_ho_wo_k1_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)); - -#if 1 - // cdata = 64, BlockSize = 64, 16x8x32x4 - constexpr index_t BlockSize = 64; - - constexpr index_t KPerBlock = 16; - constexpr index_t HoPerBlock = 8; - constexpr index_t WoPerBlock = 32; - constexpr index_t EPerBlock = 1; - - constexpr index_t KPerThread = KPerBlock; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - constexpr index_t EPerThread = EPerBlock; - - using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>; - using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>; - - constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K = 1; - - constexpr index_t BThreadTransferSrcScalarPerVector_W = 1; - - constexpr index_t CThreadTransferDstScalarPerVector_W = 16; - - static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, ""); -#else - constexpr index_t BlockSize = 64; - - constexpr index_t KPerBlock = 16; - constexpr index_t HoPerBlock = 8; - constexpr index_t WoPerBlock = 32; - constexpr index_t EPerBlock = 1; - - constexpr index_t KPerThread = 16; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - constexpr index_t EPerThread = EPerBlock; - - using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>; - using ABlockTransferThreadClusterLengths_E_K = Sequence; - - constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K = 1; - - constexpr index_t BThreadTransferSrcScalarPerVector_W = 1; - - constexpr index_t CThreadTransferDstScalarPerVector_W = K1; - - static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, ""); -#endif - - constexpr auto conv_driver = -#if 0 - DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad -#else - DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad -#endif - ::type, - TAcc, - TOut, - KPerBlock, - HoPerBlock, - WoPerBlock, - EPerBlock, - KPerThread, - HoPerThread, - WoPerThread, - EPerThread, - ABlockTransferThreadSliceLengths_E_K, - ABlockTransferThreadClusterLengths_E_K, - ABlockTransferSrcScalarPerVector_E, - ABlockTransferDstScalarPerVector_K, - BThreadTransferSrcScalarPerVector_W, - CThreadTransferDstScalarPerVector_W>{}; - - conv_driver.Run(wei_k_c0_y_x_desc, - in_n_c0_hi_wi_desc, - out_n_k0_ho_wo_k1_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - static_cast::type*>( - wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), - static_cast::type*>( - in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), - static_cast(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer())); - - out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); - - auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) { - out_n_k_ho_wo(n, k, ho, wo) = - out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize); - }; - - make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)(); -} diff --git a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index efd4ce6a196..00000000000 --- a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,349 +0,0 @@ -#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP -#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_dlops_v2.hpp" -#include "gridwise_operation_wrapper.hpp" - -template -struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad -{ - template - __host__ void Run(const ck::TensorDescriptor& wei_k_c_y_x_global_desc, - const ck::TensorDescriptor& in_n_c_hi_wi_global_desc, - const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const FloatAB* __restrict__ p_wei_global, - const FloatAB* __restrict__ p_in_global, - FloatC* __restrict__ p_out_global) const - { - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - - const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); - const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); - - const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); - - const auto K = wei_k_c_y_x_global_desc.GetLength(I0); - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - // weight tensor - const auto wei_e_k_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple( - make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor( - in_n_c_y_ho_x_wo_global_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_pass_through_transform(N), - make_pass_through_transform(Ho), - make_pass_through_transform(Wo)), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - // output tensor - const auto out_k_n_ho_wo_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_pass_through_transform(Ho), - make_pass_through_transform(Wo)), - make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto E = C * Y * X; - - if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 && - (E % EPerBlock) == 0)) - { - throw std::runtime_error("wrong! GEMM size no divisible"); - } - - // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto a_e_k_global_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - - constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{}; - - constexpr auto b_e_n_ho_wo_global_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - - constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; - - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - // hack for NKHW format - constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = - make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); - -#if 1 - // GEMM - using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperationEnum_t::Set, - decltype(wei_e_k_global_desc), - decltype(in_e_n_ho_wo_global_desc), - decltype(out_k_n_ho_wo_global_desc), - KPerBlock, - HoPerBlock, - WoPerBlock, - EPerBlock, - KPerThread, - HoPerThread, - WoPerThread, - EPerThread, - ABlockTransferThreadSliceLengths_E_K, - ABlockTransferThreadClusterLengths_E_K, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - ABlockTransferSrcScalarPerVector_E, - ABlockTransferDstScalarPerVector_K, - false, // don't move back src coordinate after threadwise copy - Sequence<0, 2, 3, 1>, - 3, - BThreadTransferSrcScalarPerVector_W, - false, // don't move back src coordinate after threadwise copy, which will be fused with - // MoveSrcSliceWindow() to save addr computation - Sequence<0, 2, 3, 1>, - 0, - CThreadTransferDstScalarPerVector_W, - decltype(a_e_k_global_step_hacks), - decltype(b_e_n_ho_wo_global_step_hacks), - decltype(c_k_n_ho_wo_global_tensor_step_hacks), - decltype(a_e_k_global_move_slice_window_step_hack), - decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>; - - const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N; - - const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1; - - const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0; - - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." << std::endl; - - KernelTimer timer; - timer.Start(); - std::cout << "has_main_k_block_loop: " << has_main_k_block_loop - << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop - << std::endl; - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_ho_wo_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_ho_wo_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_ho_wo_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else - { - const auto kernel = run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_ho_wo_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = - static_cast(calculate_convolution_flops(in_n_c_hi_wi_global_desc, - wei_k_c_y_x_global_desc, - out_n_k0_ho_wo_k1_global_desc)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } -#endif - } -}; -#endif diff --git a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp deleted file mode 100644 index 70f73cbf4a3..00000000000 --- a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp +++ /dev/null @@ -1,364 +0,0 @@ -#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP -#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_dlops_v2.hpp" -#include "gridwise_operation_wrapper.hpp" - -template -struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad -{ - template - __host__ void Run(const ck::TensorDescriptor& wei_k_c_y_x_global_desc, - const ck::TensorDescriptor& in_n_c_hi_wi_global_desc, - const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const FloatAB* __restrict__ p_wei_global, - const FloatAB* __restrict__ p_in_global, - FloatC* __restrict__ p_out_global) const - { - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - - const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); - const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); - - const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); - - const auto K = wei_k_c_y_x_global_desc.GetLength(I0); - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; - const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; - - const auto OutRightPadH = Hop - Ho; - const auto OutRightPadW = Wop - Wo; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH; - const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW; - - std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW - << std::endl; - std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW - << std::endl; - - // weight tensor - const auto wei_e_k_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple( - make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor( - in_n_c_y_ho_x_wo_global_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_pass_through_transform(N), - make_pass_through_transform(Hop), - make_pass_through_transform(Wop)), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - // output tensor - const auto out_k_n_hop_wop_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_pad_transform(Ho, 0, OutRightPadH), - make_pad_transform(Wo, 0, OutRightPadW)), - make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto E = C * Y * X; - - std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; - - if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && - (E % EPerBlock) == 0)) - { - throw std::runtime_error("wrong! GEMM size no divisible"); - } - - // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto a_e_k_global_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - - constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{}; - - constexpr auto b_e_n_ho_wo_global_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - - constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; - - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - // hack for NKHW format - constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = - make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); - - // GEMM - using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperationEnum_t::Set, - decltype(wei_e_k_global_desc), - decltype(in_e_n_ho_wo_global_desc), - decltype(out_k_n_hop_wop_global_desc), - KPerBlock, - HoPerBlock, - WoPerBlock, - EPerBlock, - KPerThread, - HoPerThread, - WoPerThread, - EPerThread, - ABlockTransferThreadSliceLengths_E_K, - ABlockTransferThreadClusterLengths_E_K, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - ABlockTransferSrcScalarPerVector_E, - ABlockTransferDstScalarPerVector_K, - false, // don't move back src coordinate after threadwise copy - Sequence<0, 2, 3, 1>, - 3, - BThreadTransferSrcScalarPerVector_W, - false, // don't move back src coordinate after threadwise copy, which will be fused with - // MoveSrcSliceWindow() to save addr computation - Sequence<0, 2, 3, 1>, - 0, - CThreadTransferDstScalarPerVector_W, - decltype(a_e_k_global_step_hacks), - decltype(b_e_n_ho_wo_global_step_hacks), - decltype(c_k_n_ho_wo_global_tensor_step_hacks), - decltype(a_e_k_global_move_slice_window_step_hack), - decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>; - - const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; - - const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1; - - const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0; - - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." << std::endl; - - KernelTimer timer; - timer.Start(); - std::cout << "has_main_k_block_loop: " << has_main_k_block_loop - << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop - << std::endl; - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_hop_wop_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_hop_wop_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_hop_wop_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_hop_wop_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = - static_cast(calculate_convolution_flops(in_n_c_hi_wi_global_desc, - wei_k_c_y_x_global_desc, - out_n_k0_ho_wo_k1_global_desc)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } -}; -#endif diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp deleted file mode 100644 index 4ccfbaab0aa..00000000000 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ /dev/null @@ -1,275 +0,0 @@ -#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP -#define DRIVER_GEMM_XDLOPS_V2R3_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" - -template -__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - ck::index_t M01, - ck::index_t N01, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks, - ck::index_t nrepeat) - -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - using GridwiseGemm = - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - { - std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " - << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", " - << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " - << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity( - a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01)) - { - throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); - } - - const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = - GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); - - using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - - const auto c_block_cluster_adaptor = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01); - - using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); - - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - float ave_time = 0; - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); - } - else - { - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); - } -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); - DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); - DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc)); - DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); - - a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); - b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); - - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else - { - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } -} -#endif - return ave_time; -} -#endif diff --git a/host/host_tensor/CMakeLists.txt b/host/host_tensor/CMakeLists.txt deleted file mode 100644 index 3dcecf64e1b..00000000000 --- a/host/host_tensor/CMakeLists.txt +++ /dev/null @@ -1,21 +0,0 @@ -include_directories(BEFORE - include -) - -set(HOST_TENSOR_SOURCE - src/host_tensor.cpp; - src/device.cpp; -) - -## the library target -add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE}) - -target_include_directories(host_tensor SYSTEM PUBLIC $) - -target_link_libraries(host_tensor PRIVATE hip::device) -target_link_libraries(host_tensor INTERFACE hip::host) - -target_compile_features(host_tensor PUBLIC) -set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON) - -install(TARGETS host_tensor LIBRARY DESTINATION lib) diff --git a/host/host_tensor/include/device.hpp b/host/host_tensor/include/device.hpp deleted file mode 100644 index cb1a6effa17..00000000000 --- a/host/host_tensor/include/device.hpp +++ /dev/null @@ -1,84 +0,0 @@ -#ifndef DEVICE_HPP -#define DEVICE_HPP - -#include -#include -#include -#include -#include "hip/hip_runtime.h" -#include "hip/hip_fp16.h" - -struct DeviceMem -{ - DeviceMem() = delete; - DeviceMem(std::size_t mem_size); - void* GetDeviceBuffer(); - void ToDevice(const void* p); - void FromDevice(void* p); - ~DeviceMem(); - - void* mpDeviceBuf; - std::size_t mMemSize; -}; - -struct KernelTimerImpl; - -struct KernelTimer -{ - KernelTimer(); - ~KernelTimer(); - void Start(); - void End(); - float GetElapsedTime() const; - - std::unique_ptr impl; -}; - -using device_stream_t = hipStream_t; - -template -void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) -{ - hipStream_t stream_id = nullptr; - - hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); -} - -template -float launch_and_time_kernel( - F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) -{ - KernelTimer timer; - - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up\n"); - - hipStream_t stream_id = nullptr; - - // warm up - hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); - - printf("Start running %d times...\n", nrepeat); - - timer.Start(); - - for(int i = 0; i < nrepeat; ++i) - { - hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); - } - - timer.End(); - - // std::this_thread::sleep_for (std::chrono::microseconds(10)); - - return timer.GetElapsedTime() / nrepeat; -} -#endif diff --git a/host/host_tensor/include/gemm_common.hpp b/host/host_tensor/include/gemm_common.hpp deleted file mode 100644 index f6c0d6f930a..00000000000 --- a/host/host_tensor/include/gemm_common.hpp +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef GEMM_COMMON_HPP -#define GEMM_COMMON_HPP - -enum GemmMatrixLayout -{ - MK_KN_MN, // 0 - MK_NK_MN, // 1 - KM_KN_MN, // 2 - KM_NK_MN, // 3 - MK_KN_NM, // 4 - MK_NK_NM, // 5 - KM_KN_NM, // 6 - KM_NK_NM, // 7 -}; - -#endif diff --git a/host/host_tensor/include/host_conv.hpp b/host/host_tensor/include/host_conv.hpp deleted file mode 100644 index c1228f4832b..00000000000 --- a/host/host_tensor/include/host_conv.hpp +++ /dev/null @@ -1,324 +0,0 @@ -#pragma once -#include "host_tensor.hpp" - -template -void host_direct_convolution(const Tensor& in, - const Tensor& wei, - Tensor& out, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&, - const ConvTensorLayout layout = ConvTensorLayout::NCHW) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - double v = 0; - for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) - { - for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - v += static_cast(in(n, c, hi, wi)) * - static_cast(wei(k, c, y, x)); - } - } - } - } - out(n, k, ho, wo) = v; - }; - - auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) { - double v = 0; - for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c) - { - for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && - wi < in.mDesc.GetLengths()[2]) - { - v += static_cast(in(n, hi, wi, c)) * - static_cast(wei(k, y, x, c)); - } - } - } - } - out(n, ho, wo, k) = v; - }; - - if(layout == ConvTensorLayout::NCHW) - { - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else if(layout == ConvTensorLayout::NHWC) - { - make_ParallelTensorFunctor(f_nhwc, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! not supported layout"); - } -} - -template -void host_winograd_3x3_convolution(const Tensor& in_nchw, - const Tensor& wei_kcyx, - Tensor& out_nkhw, - InLeftPads, - InRightPads) -{ - using namespace ck; - - constexpr std::size_t HoPerTile = 2; - constexpr std::size_t WoPerTile = 2; - - std::size_t N = in_nchw.mDesc.GetLengths()[0]; - std::size_t C = in_nchw.mDesc.GetLengths()[1]; - - std::size_t K = wei_kcyx.mDesc.GetLengths()[0]; - std::size_t Y = wei_kcyx.mDesc.GetLengths()[2]; - std::size_t X = wei_kcyx.mDesc.GetLengths()[3]; - - std::size_t Ho = out_nkhw.mDesc.GetLengths()[2]; - std::size_t Wo = out_nkhw.mDesc.GetLengths()[3]; - - index_t h_pad_low = InLeftPads{}.Get(Number<0>{}); - index_t w_pad_low = InLeftPads{}.Get(Number<1>{}); - - std::size_t HiPerTile = HoPerTile + Y - 1; - std::size_t WiPerTile = WoPerTile + X - 1; - - std::size_t HTile = (Ho + HoPerTile - 1) / HoPerTile; - std::size_t WTile = (Wo + WoPerTile - 1) / WoPerTile; - - Tensor in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile}); - Tensor in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile}); - Tensor wei_transform({K, C, HiPerTile, WiPerTile}); - Tensor out_transform({N, K, HTile, WTile, HiPerTile, HiPerTile}); - Tensor out_hold({N, K, HTile, WTile, HoPerTile, WoPerTile}); - - auto f_in_hold = [&](auto n, auto c, auto htile, auto wtile) { - for(int j = 0; j < HiPerTile; ++j) - { - int hi = HoPerTile * htile + j - h_pad_low; - for(int i = 0; i < WiPerTile; ++i) - { - int wi = WoPerTile * wtile + i - w_pad_low; - - if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && - wi < in_nchw.mDesc.GetLengths()[3]) - { - in_hold(n, c, htile, wtile, j, i) = in_nchw(n, c, hi, wi); - } - else - { - in_hold(n, c, htile, wtile, j, i) = TIn(0); - } - } - } - }; - - auto f_in_transform = [&](auto n, auto c, auto htile, auto wtile) { - in_transform(n, c, htile, wtile, 0, 0) = - in_hold(n, c, htile, wtile, 0, 0) - in_hold(n, c, htile, wtile, 0, 2) - - in_hold(n, c, htile, wtile, 2, 0) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 0, 1) = - in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) - - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 0, 2) = - -in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) + - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 0, 3) = - in_hold(n, c, htile, wtile, 0, 1) - in_hold(n, c, htile, wtile, 0, 3) - - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 3); - - in_transform(n, c, htile, wtile, 1, 0) = - in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 1, 1) = - in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 1, 2) = - -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 1, 3) = - in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) + - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3); - - in_transform(n, c, htile, wtile, 2, 0) = - -in_hold(n, c, htile, wtile, 1, 0) + in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 2, 1) = - -in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 2, 2) = - in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) - - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 2, 3) = - -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 3) + - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3); - - in_transform(n, c, htile, wtile, 3, 0) = - in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) - - in_hold(n, c, htile, wtile, 3, 0) + in_hold(n, c, htile, wtile, 3, 2); - in_transform(n, c, htile, wtile, 3, 1) = - in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - - in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2); - in_transform(n, c, htile, wtile, 3, 2) = - -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2); - in_transform(n, c, htile, wtile, 3, 3) = - in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) - - in_hold(n, c, htile, wtile, 3, 1) + in_hold(n, c, htile, wtile, 3, 3); - }; - - auto f_wei_transform = [&](auto k, auto c) { - wei_transform(k, c, 0, 0) = double(wei_kcyx(k, c, 0, 0)); - wei_transform(k, c, 0, 1) = 0.5 * double(wei_kcyx(k, c, 0, 0)) + - 0.5 * double(wei_kcyx(k, c, 0, 1)) + - 0.5 * double(wei_kcyx(k, c, 0, 2)); - wei_transform(k, c, 0, 2) = 0.5 * double(wei_kcyx(k, c, 0, 0)) - - 0.5 * double(wei_kcyx(k, c, 0, 1)) + - 0.5 * double(wei_kcyx(k, c, 0, 2)); - wei_transform(k, c, 0, 3) = double(wei_kcyx(k, c, 0, 2)); - - wei_transform(k, c, 1, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) + - 0.5 * double(wei_kcyx(k, c, 1, 0)) + - 0.5 * double(wei_kcyx(k, c, 2, 0)); - wei_transform(k, c, 1, 1) = - 0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) + - 0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) + - 0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) + - 0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) + - 0.25 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 1, 2) = - 0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) + - 0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) - - 0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) + - 0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) + - 0.25 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 1, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) + - 0.5 * double(wei_kcyx(k, c, 1, 2)) + - 0.5 * double(wei_kcyx(k, c, 2, 2)); - - wei_transform(k, c, 2, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) - - 0.5 * double(wei_kcyx(k, c, 1, 0)) + - 0.5 * double(wei_kcyx(k, c, 2, 0)); - wei_transform(k, c, 2, 1) = - 0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) + - 0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) - - 0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) + - 0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) + - 0.25 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 2, 2) = - 0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) + - 0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) + - 0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) + - 0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) + - 0.25 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 2, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) - - 0.5 * double(wei_kcyx(k, c, 1, 2)) + - 0.5 * double(wei_kcyx(k, c, 2, 2)); - - wei_transform(k, c, 3, 0) = double(wei_kcyx(k, c, 2, 0)); - wei_transform(k, c, 3, 1) = 0.5 * double(wei_kcyx(k, c, 2, 0)) + - 0.5 * double(wei_kcyx(k, c, 2, 1)) + - 0.5 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 3, 2) = 0.5 * double(wei_kcyx(k, c, 2, 0)) - - 0.5 * double(wei_kcyx(k, c, 2, 1)) + - 0.5 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 3, 3) = double(wei_kcyx(k, c, 2, 2)); - }; - - auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) { - for(int j = 0; j < HiPerTile; ++j) - { - for(int i = 0; i < WiPerTile; ++i) - { - double v = 0; - for(int c = 0; c < C; ++c) - { - v += in_transform(n, c, htile, wtile, j, i) * wei_transform(k, c, j, i); - } - - out_transform(n, k, htile, wtile, j, i) = v; - } - } - }; - - auto f_out_hold = [&](auto n, auto k, auto htile, auto wtile) { - out_hold(n, k, htile, wtile, 0, 0) = - out_transform(n, k, htile, wtile, 0, 0) + out_transform(n, k, htile, wtile, 0, 1) + - out_transform(n, k, htile, wtile, 0, 2) + out_transform(n, k, htile, wtile, 1, 0) + - out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) + - out_transform(n, k, htile, wtile, 2, 0) + out_transform(n, k, htile, wtile, 2, 1) + - out_transform(n, k, htile, wtile, 2, 2); - out_hold(n, k, htile, wtile, 0, 1) = - out_transform(n, k, htile, wtile, 0, 1) - out_transform(n, k, htile, wtile, 0, 2) - - out_transform(n, k, htile, wtile, 0, 3) + out_transform(n, k, htile, wtile, 1, 1) - - out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) + - out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - - out_transform(n, k, htile, wtile, 2, 3); - out_hold(n, k, htile, wtile, 1, 0) = - out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) + - out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 2, 0) - - out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - - out_transform(n, k, htile, wtile, 3, 0) - out_transform(n, k, htile, wtile, 3, 1) - - out_transform(n, k, htile, wtile, 3, 2); - out_hold(n, k, htile, wtile, 1, 1) = - out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, htile, wtile, 1, 2) - - out_transform(n, k, htile, wtile, 1, 3) - out_transform(n, k, htile, wtile, 2, 1) + - out_transform(n, k, htile, wtile, 2, 2) + out_transform(n, k, htile, wtile, 2, 3) - - out_transform(n, k, htile, wtile, 3, 1) + out_transform(n, k, htile, wtile, 3, 2) + - out_transform(n, k, htile, wtile, 3, 3); - }; - - auto f_out = [&](auto n, auto k, auto htile, auto wtile) { - for(int j = 0; j < HoPerTile; ++j) - { - std::size_t ho = HoPerTile * htile + j; - for(int i = 0; i < WoPerTile; ++i) - { - std::size_t wo = WoPerTile * wtile + i; - out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i); - } - } - }; - - std::size_t num_thread = std::thread::hardware_concurrency(); - - make_ParallelTensorFunctor(f_in_hold, N, C, HTile, WTile)(num_thread); - make_ParallelTensorFunctor(f_in_transform, N, C, HTile, WTile)(num_thread); - make_ParallelTensorFunctor(f_wei_transform, K, C)(num_thread); - make_ParallelTensorFunctor(f_out_transform, N, K, HTile, WTile)(num_thread); - make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread); - make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread); -} diff --git a/host/host_tensor/include/host_conv_bwd_data.hpp b/host/host_tensor/include/host_conv_bwd_data.hpp deleted file mode 100644 index ca23422e232..00000000000 --- a/host/host_tensor/include/host_conv_bwd_data.hpp +++ /dev/null @@ -1,135 +0,0 @@ -#pragma once -#include "host_tensor.hpp" - -template -void host_direct_convolution_backward_data(Tensor& in, - const Tensor& wei, - const Tensor& out, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& /* in_right_pads */, - const ConvTensorLayout layout = ConvTensorLayout::NCHW) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { - std::size_t K = wei.mDesc.GetLengths()[I0]; - std::size_t Y = wei.mDesc.GetLengths()[I2]; - std::size_t X = wei.mDesc.GetLengths()[I3]; - - std::size_t Ho = out.mDesc.GetLengths()[I2]; - std::size_t Wo = out.mDesc.GetLengths()[I3]; - - double v = 0; - - for(int y = 0; y < Y; ++y) - { - int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; - - if(h_tmp % conv_strides[I0] == 0) - { - int ho = h_tmp / conv_strides[I0]; - - if(ho >= 0 && ho < Ho) - { - for(int x = 0; x < X; ++x) - { - int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; - - if(w_tmp % conv_strides[I1] == 0) - { - int wo = w_tmp / conv_strides[I1]; - - if(wo >= 0 && wo < Wo) - { - for(int k = 0; k < K; ++k) - { - v += out(n, k, ho, wo) * wei(k, c, y, x); - } - } - } - } - } - } - } - - in(n, c, hi, wi) = v; - }; - - auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) { - std::size_t K = wei.mDesc.GetLengths()[I0]; - std::size_t Y = wei.mDesc.GetLengths()[I1]; - std::size_t X = wei.mDesc.GetLengths()[I2]; - - std::size_t Ho = out.mDesc.GetLengths()[I1]; - std::size_t Wo = out.mDesc.GetLengths()[I2]; - - double v = 0; - - for(int y = 0; y < Y; ++y) - { - int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; - - if(h_tmp % conv_strides[I0] == 0) - { - int ho = h_tmp / conv_strides[I0]; - - if(ho >= 0 && ho < Ho) - { - for(int x = 0; x < X; ++x) - { - int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; - - if(w_tmp % conv_strides[I1] == 0) - { - int wo = w_tmp / conv_strides[I1]; - - if(wo >= 0 && wo < Wo) - { - for(int k = 0; k < K; ++k) - { - v += out(n, ho, wo, k) * wei(k, y, x, c); - } - } - } - } - } - } - } - - in(n, hi, wi, c) = v; - }; - - if(layout == ConvTensorLayout::NCHW) - { - make_ParallelTensorFunctor(f_nchw, - in.mDesc.GetLengths()[0], - in.mDesc.GetLengths()[1], - in.mDesc.GetLengths()[2], - in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else if(layout == ConvTensorLayout::NHWC) - { - make_ParallelTensorFunctor(f_nhwc, - in.mDesc.GetLengths()[0], - in.mDesc.GetLengths()[1], - in.mDesc.GetLengths()[2], - in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! not supported layout"); - } -} diff --git a/host/host_tensor/include/host_conv_bwd_weight.hpp b/host/host_tensor/include/host_conv_bwd_weight.hpp deleted file mode 100644 index ed3e8c3042e..00000000000 --- a/host/host_tensor/include/host_conv_bwd_weight.hpp +++ /dev/null @@ -1,89 +0,0 @@ -#pragma once -#include "host_tensor.hpp" - -template -void host_direct_convolution_backward_weights( - const Tensor& out, - const Tensor& in, - Tensor& wei, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&, - const ConvTensorLayout layout = ConvTensorLayout::NCHW) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - auto f_kcyx = [&](auto k, auto c, auto y, auto x) { - double v = 0; - for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) - { - for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - v += static_cast(in(n, c, hi, wi)) * - static_cast(out(n, k, ho, wo)); - } - } - } - } - wei(k, c, y, x) = v; - }; - - auto f_kyxc = [&](auto k, auto y, auto x, auto c) { - double v = 0; - for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) - { - for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && - wi < in.mDesc.GetLengths()[2]) - { - v += static_cast(in(n, hi, wi, c)) * - static_cast(out(n, ho, wo, k)); - } - } - } - } - wei(k, y, x, c) = v; - }; - - if(layout == ConvTensorLayout::NCHW) - { - make_ParallelTensorFunctor(f_kcyx, - wei.mDesc.GetLengths()[0], - wei.mDesc.GetLengths()[1], - wei.mDesc.GetLengths()[2], - wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else if(layout == ConvTensorLayout::NHWC) - { - make_ParallelTensorFunctor(f_kyxc, - wei.mDesc.GetLengths()[0], - wei.mDesc.GetLengths()[1], - wei.mDesc.GetLengths()[2], - wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! not supported layout"); - } -} diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp deleted file mode 100644 index c582a342585..00000000000 --- a/host/host_tensor/include/host_gemm.hpp +++ /dev/null @@ -1,159 +0,0 @@ -#pragma once -#include "host_tensor.hpp" -#include "gemm_common.hpp" - -template -void host_gemm(const Tensor& a, - const Tensor& b, - Tensor& c, - const GemmMatrixLayout layout) -{ - if(layout == GemmMatrixLayout::MK_KN_MN) - { - auto f_mk_kn_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(m, k)) * static_cast(b(k, n)); - } - - c(m, n) = v; - }; - - make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::MK_NK_MN) - { - auto f_mk_nk_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(m, k)) * static_cast(b(n, k)); - } - - c(m, n) = v; - }; - - make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_KN_MN) - { - auto f_km_kn_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(k, m)) * static_cast(b(k, n)); - } - - c(m, n) = v; - }; - - make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_NK_MN) - { - auto f_km_nk_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(k, m)) * static_cast(b(n, k)); - } - - c(m, n) = v; - }; - - make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::MK_KN_NM) - { - auto f_mk_kn_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(m, k)) * static_cast(b(k, n)); - } - - c(n, m) = v; - }; - - make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::MK_NK_NM) - { - auto f_mk_nk_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(m, k)) * static_cast(b(n, k)); - } - - c(n, m) = v; - }; - - make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_KN_NM) - { - auto f_km_kn_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(k, m)) * static_cast(b(k, n)); - } - - c(n, m) = v; - }; - - make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_NK_NM) - { - auto f_km_nk_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(k, m)) * static_cast(b(n, k)); - } - - c(n, m) = v; - }; - - make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! not supported layout"); - } -} diff --git a/host/host_tensor/include/host_tensor_generator.hpp b/host/host_tensor/include/host_tensor_generator.hpp deleted file mode 100644 index b0d53995ede..00000000000 --- a/host/host_tensor/include/host_tensor_generator.hpp +++ /dev/null @@ -1,71 +0,0 @@ -#ifndef HOST_TENSOR_GENERATOR_HPP -#define HOST_TENSOR_GENERATOR_HPP - -#include -#include "config.hpp" - -struct GeneratorTensor_1 -{ - int value = 1; - - template - float operator()(Is...) - { - return value; - } -}; - -struct GeneratorTensor_0 -{ - int value = 0; - - template - float operator()(Is...) - { - return value; - } -}; - -struct GeneratorTensor_2 -{ - int min_value = 0; - int max_value = 1; - - template - float operator()(Is...) - { - return (std::rand() % (max_value - min_value)) + min_value; - } -}; - -template -struct GeneratorTensor_3 -{ - T min_value = 0; - T max_value = 1; - - template - float operator()(Is...) - { - float tmp = float(std::rand()) / float(RAND_MAX); - - return min_value + tmp * (max_value - min_value); - } -}; - -struct GeneratorTensor_Checkboard -{ - template - float operator()(Ts... Xs) const - { - std::array dims = {{static_cast(Xs)...}}; - return std::accumulate(dims.begin(), - dims.end(), - true, - [](bool init, ck::index_t x) -> int { return init != (x % 2); }) - ? 1 - : -1; - } -}; - -#endif diff --git a/host/host_tensor/src/device.cpp b/host/host_tensor/src/device.cpp deleted file mode 100644 index 0d1b3d6883b..00000000000 --- a/host/host_tensor/src/device.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include "device.hpp" - -DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) -{ - hipGetErrorString(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); -} - -void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; } - -void DeviceMem::ToDevice(const void* p) -{ - hipGetErrorString( - hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); -} - -void DeviceMem::FromDevice(void* p) -{ - hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); -} - -DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); } - -struct KernelTimerImpl -{ - KernelTimerImpl() - { - hipGetErrorString(hipEventCreate(&mStart)); - hipGetErrorString(hipEventCreate(&mEnd)); - } - - ~KernelTimerImpl() - { - hipGetErrorString(hipEventDestroy(mStart)); - hipGetErrorString(hipEventDestroy(mEnd)); - } - - void Start() - { - hipGetErrorString(hipDeviceSynchronize()); - hipGetErrorString(hipEventRecord(mStart, nullptr)); - } - - void End() - { - hipGetErrorString(hipEventRecord(mEnd, nullptr)); - hipGetErrorString(hipEventSynchronize(mEnd)); - } - - float GetElapsedTime() const - { - float time; - hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd)); - return time; - } - - hipEvent_t mStart, mEnd; -}; - -KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {} - -KernelTimer::~KernelTimer() {} - -void KernelTimer::Start() { impl->Start(); } - -void KernelTimer::End() { impl->End(); } - -float KernelTimer::GetElapsedTime() const { return impl->GetElapsedTime(); } diff --git a/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 2b645e3c3bc..00000000000 --- a/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,689 +0,0 @@ -#ifndef CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP -#define CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP - -#include -#include - -namespace ck { -namespace driver { - -struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw -{ - auto GetCompileParameterString() const - { - auto param = std::stringstream(); - - // clang-format off - param << - " -DCK_PARAM_ABDataTypeEnum=" << - ABDataTypeEnum << - " -DCK_PARAM_AccDataTypeEnum=" << - AccDataTypeEnum << - " -DCK_PARAM_CDataTypeEnum=" << - CDataTypeEnum << - " -DCK_PARAM_BlockSize=" << - BlockSize << - " -DCK_PARAM_GN0=" << - GN0 << - " -DCK_PARAM_GK1=" << - GK1 << - " -DCK_PARAM_GM1PerBlockGM11=" - << GM1PerBlockGM11 << - " -DCK_PARAM_GN1PerBlockGN11=" << - GN1PerBlockGN11 << - " -DCK_PARAM_GK0PerBlock=" << - GK0PerBlock << - " -DCK_PARAM_BM1PerThreadBM11=" << - BM1PerThreadBM11 << - " -DCK_PARAM_BN1PerThreadBN11=" << - BN1PerThreadBN11 << - " -DCK_PARAM_BK0PerThread=" << - BK0PerThread << - " -DCK_PARAM_BM10BN10ThreadClusterBM10Xs=" << - BM10BN10ThreadClusterBM10Xs[0] << "," << - BM10BN10ThreadClusterBM10Xs[1] << - " -DCK_PARAM_BM10BN10ThreadClusterBN10Xs=" << - BM10BN10ThreadClusterBN10Xs[0] << "," << - BM10BN10ThreadClusterBN10Xs[1] << - " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=" << - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0] << "," << - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1] << "," << - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2] << "," << - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3] << "," << - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4] << - " -DCK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1=" << - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0] << "," << - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1] << "," << - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2] << "," << - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3] << "," << - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4] << - " -DCK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" << - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0] << "," << - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1] << "," << - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2] << "," << - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3] << "," << - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4] << - " -DCK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" << - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0] << "," << - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1] << "," << - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2] << "," << - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3] << "," << - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4] << - " -DCK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1=" << - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0] << "," << - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1] << "," << - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2] << "," << - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3] << "," << - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4] << - " -DCK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1=" << - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0] << "," << - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1] << "," << - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2] << "," << - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3] << "," << - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4] << - " -DCK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" << - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0] << "," << - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1] << "," << - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2] << "," << - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3] << "," << - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4] << - " -DCK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" << - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0] << "," << - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1] << "," << - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2] << "," << - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3] << "," << - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4] << - " -DCK_PARAM_CThreadTransferDstScalarPerVector=" << - CThreadTransferDstScalarPerVector << - " -DCK_PARAM_HasMainKBlockLoop=" << - static_cast(HasMainKBlockLoop) << - " -DCK_PARAM_HasDoubleTailKBlockLoop=" << - static_cast(HasDoubleTailKBlockLoop); - // clang-format on - - return param.str(); - } - - ck::DataTypeEnum_t ABDataTypeEnum = ck::DataTypeEnum_t::Unknown; - ck::DataTypeEnum_t AccDataTypeEnum = ck::DataTypeEnum_t::Unknown; - ck::DataTypeEnum_t CDataTypeEnum = ck::DataTypeEnum_t::Unknown; - - int BlockSize = -1; - - int GN0 = -1; - int GK1 = -1; - - int GM1PerBlockGM11 = -1; - int GN1PerBlockGN11 = -1; - int GK0PerBlock = -1; - - int BM1PerThreadBM11 = -1; - int BN1PerThreadBN11 = -1; - int BK0PerThread = -1; - - std::array BM10BN10ThreadClusterBM10Xs = {-1, -1}; - std::array BM10BN10ThreadClusterBN10Xs = {-1, -1}; - - std::array ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = { - -1, -1, -1, -1, -1}; - std::array ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = { - -1, -1, -1, -1, -1}; - std::array ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = { - -1, -1, -1, -1, -1}; - std::array ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = { - -1, -1, -1, -1, -1}; - - std::array BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = { - -1, -1, -1, -1, -1}; - std::array BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = { - -1, -1, -1, -1, -1}; - std::array BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = { - -1, -1, -1, -1, -1}; - std::array BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = { - -1, -1, -1, -1, -1}; - - int CThreadTransferDstScalarPerVector = -1; - - bool HasMainKBlockLoop = false; - bool HasDoubleTailKBlockLoop = false; -}; - -struct TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw -{ - ck::DataTypeEnum_t ABDataTypeEnum; - ck::DataTypeEnum_t CDataTypeEnum; - - int BlockSize; - - int GN0; - int GK1; - - int GM1PerBlockGM11; - int GN1PerBlockGN11; - int GK0PerBlock; - - int BM1PerThreadBM11; - int BN1PerThreadBN11; - int BK0PerThread; - - std::array BM10BN10ThreadClusterBM10Xs; - std::array BM10BN10ThreadClusterBN10Xs; - - std::array ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; - std::array ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; - std::array ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; - std::array ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; - - std::array BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; - std::array BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; - std::array BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; - std::array BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; -}; - -inline static auto generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw() -{ - constexpr auto f32 = ck::DataTypeEnum_t::Float; - constexpr auto f16 = ck::DataTypeEnum_t::Half; - constexpr auto i8 = ck::DataTypeEnum_t::Int8; - - return std::vector{ - // clang-format off - // fp32 - {f32, f32, 256, 1, 1, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 1}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, - - {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, - {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, - {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, - - {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 1}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - {f32, f32, 256, 2, 1, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 1}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - {f32, f32, 256, 4, 1, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - {f32, f32, 256, 8, 1, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 1}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - {f32, f32, 128, 1, 1, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 1}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - // fp16 - {f16, f16, 256, 1, 2, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 2}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, - - {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, - {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, - {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, - - {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 2}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - {f16, f16, 256, 2, 2, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 2}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - {f16, f16, 256, 4, 2, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - {f16, f16, 256, 8, 2, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 2}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - {f16, f16, 128, 1, 2, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 2}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - // i8 - { i8, i8, 256, 1, 4, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 4}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, - - { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, - { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, - { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, - - { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 4}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - { i8, i8, 256, 2, 4, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 4}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - { i8, i8, 256, 4, 4, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - { i8, i8, 256, 8, 4, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 4}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - { i8, i8, 128, 1, 4, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 4}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}} - // clang-format on - }; -} - -// TODO make this common interface and write specs for it -struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw -{ - static auto - CalculateCompileParameterBasedOnTunable(const ConvolutionProblemDescriptor& conv_problem_desc, - const TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw& tunable) - { - const int C = conv_problem_desc.C; - const int Y = conv_problem_desc.Y; - const int X = conv_problem_desc.X; - const int Ho = conv_problem_desc.Ho; - const int Wo = conv_problem_desc.Wo; - - if(!(conv_problem_desc.InDataTypeEnum == tunable.ABDataTypeEnum && - conv_problem_desc.WeiDataTypeEnum == tunable.ABDataTypeEnum && - conv_problem_desc.OutDataTypeEnum == tunable.CDataTypeEnum)) - return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); - - const auto ABDataTypeEnum = conv_problem_desc.InDataTypeEnum; - const auto CDataTypeEnum = conv_problem_desc.OutDataTypeEnum; - - DataTypeEnum_t AccDataTypeEnum; - - if(ABDataTypeEnum == DataTypeEnum_t::Float || ABDataTypeEnum == DataTypeEnum_t::Half) - { - AccDataTypeEnum = DataTypeEnum_t::Float; - } - else if(ABDataTypeEnum == DataTypeEnum_t::Int8) - { - AccDataTypeEnum = DataTypeEnum_t::Int32; - } - else - { - return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); - } - - const int BlockSize = tunable.BlockSize; - - const int GN0 = tunable.GN0; - const int GK1 = tunable.GK1; - - const int GM11 = tunable.GM1PerBlockGM11; - const int GN11 = tunable.GN1PerBlockGN11; - const int GK0PerBlock = tunable.GK0PerBlock; - - const int BM11 = tunable.BM1PerThreadBM11; - const int BN11 = tunable.BN1PerThreadBN11; - const int BK0PerThread = tunable.BK0PerThread; - - const auto BM10BN10ThreadClusterBM10Xs = tunable.BM10BN10ThreadClusterBM10Xs; - const auto BM10BN10ThreadClusterBN10Xs = tunable.BM10BN10ThreadClusterBN10Xs; - - const auto ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = - tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; - const auto ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = - tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; - const auto ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = - tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; - const auto ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = - tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; - - const auto BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = - tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; - const auto BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = - tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; - const auto BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = - tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; - const auto BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = - tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; - - // C threadwise copy: {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim - const int CThreadTransferDstScalarPerVector = gcd(4, GN11, BN11, Ho * Wo); - - const int C0 = GK1; - - if(!(C % C0 == 0)) - return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); - - const int C1 = C / C0; - - const int GK0 = C1 * Y * X; - - if(!(GK0 % GK0PerBlock == 0)) - return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); - - const bool HasMainKBlockLoop = ((GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1); - - const bool HasDoubleTailKBlockLoop = ((GK0 / GK0PerBlock) % 2 == 0); - - return std::make_tuple( - CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{ - ABDataTypeEnum, - AccDataTypeEnum, - CDataTypeEnum, - BlockSize, - GN0, - GK1, - GM11, - GN11, - GK0PerBlock, - BM11, - BN11, - BK0PerThread, - BM10BN10ThreadClusterBM10Xs, - BM10BN10ThreadClusterBN10Xs, - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - CThreadTransferDstScalarPerVector, - HasMainKBlockLoop, - HasDoubleTailKBlockLoop}, - true); - } - - static auto GetDefaultCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc) - { - for(const auto& tunable : generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw()) - { - CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param{}; - bool found = false; - - std::tie(compile_param, found) = - CalculateCompileParameterBasedOnTunable(conv_problem_desc, tunable); - - if(found && IsValidCompileParameter(conv_problem_desc, compile_param)) - return std::make_tuple(compile_param, true); - } - - return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); - } - - static bool IsApplicable(const ConvolutionProblemDescriptor& conv_problem_desc) - { - bool found = false; - - std::tie(std::ignore, found) = GetDefaultCompileParameter(conv_problem_desc); - - return found; - } - - static bool - IsValidCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc, - const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) - { - const int N = conv_problem_desc.N; - const int K = conv_problem_desc.K; - const int C = conv_problem_desc.C; - const int Y = conv_problem_desc.Y; - const int X = conv_problem_desc.X; - const int Ho = conv_problem_desc.Ho; - const int Wo = conv_problem_desc.Wo; - - const int GK1 = compile_param.GK1; - const int GN0 = compile_param.GN0; - const int GM11 = compile_param.GM1PerBlockGM11; - const int GN11 = compile_param.GN1PerBlockGN11; - - const int BM11 = compile_param.BM1PerThreadBM11; - const int BN11 = compile_param.BN1PerThreadBN11; - - const int C0 = GK1; - const int N0 = GN0; - - if(!(C % C0 == 0)) - return false; - - const int C1 = C / C0; - - if(!(N % N0 == 0)) - return false; - - const int N1 = N / N0; - - const int GM0 = 1; - const int GM1 = K; - const int GN1 = N1 * Ho * Wo; - const int GK0 = C1 * Y * X; - - // check data type - { - if(!(conv_problem_desc.InDataTypeEnum == conv_problem_desc.WeiDataTypeEnum && - conv_problem_desc.InDataTypeEnum == compile_param.ABDataTypeEnum)) - return false; - - if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Float || - compile_param.ABDataTypeEnum == DataTypeEnum_t::Half) - { - if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Float)) - return false; - } - else if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Int8) - { - if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Int32)) - return false; - } - } - - // check gridwise contraction - { - if(!(GM1 % GM11 == 0 && GN1 % GN11 == 0 && GK0 % compile_param.GK0PerBlock == 0)) - return false; - - const bool has_main_k_block_loop = - ((GK0 + compile_param.GK0PerBlock) / (2 * compile_param.GK0PerBlock) > 1); - - const bool has_double_tail_k_block_loop = ((GK0 / compile_param.GK0PerBlock) % 2 == 0); - - if(!(has_main_k_block_loop == compile_param.HasMainKBlockLoop && - has_double_tail_k_block_loop == compile_param.HasDoubleTailKBlockLoop)) - return false; - } - - // check A blockwise copy - { - const auto block_slice_lengths = - std::array{compile_param.GK0PerBlock, GM0, 1, GM11, GK1}; - const auto& cluster_lengths = - compile_param.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; - const auto& thread_slice_lengths = - compile_param.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; - const auto& src_vector_lengths = - compile_param.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; - const auto& dst_vector_lengths = - compile_param.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; - - // check number of working thread - const int num_work_thread = std::accumulate( - cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies{}); - - if(!(compile_param.BlockSize >= num_work_thread)) - return false; - - // check block slice lengths vs thread slice lengths vs cluster lengths - for(int i = 0; i < 5; ++i) - { - if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i])) - return false; - } - - // check thread slice lengths vs vector lengths - for(int i = 0; i < 5; ++i) - { - if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0)) - return false; - - if(!(thread_slice_lengths[i] % dst_vector_lengths[i] == 0)) - return false; - } - - // check Src vectorization, GK0 is global mem vector dim - if(!(src_vector_lengths[1] == 1 && src_vector_lengths[2] == 1 && - src_vector_lengths[3] == 1 && src_vector_lengths[4] == 1)) - return false; - - // check Dst vectorization, {GM11, GK1} are LDS vector dims - if(dst_vector_lengths[4] == GK1) - { // vectorize on {GM11, GK1} - if(!(GM11 % dst_vector_lengths[3] == 0)) - return false; - } - else - { // vectorize on {GK1} only - if(!(GK1 % dst_vector_lengths[4] == 0)) - return false; - - if(!(dst_vector_lengths[3] == 1)) - return false; - } - } - - // check B blockwise copy - { - const auto block_slice_lengths = - std::array{compile_param.GK0PerBlock, GN0, 1, GN11, GK1}; - const auto& cluster_lengths = - compile_param.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; - const auto& thread_slice_lengths = - compile_param.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; - const auto& src_vector_lengths = - compile_param.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; - const auto& dst_vector_lengths = - compile_param.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; - - // check number of working thread - const int num_work_thread = std::accumulate( - cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies{}); - - if(!(compile_param.BlockSize >= num_work_thread)) - return false; - - // check block slice lengths vs thread slice lengths vs cluster lengths - for(int i = 0; i < 5; ++i) - { - if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i])) - return false; - } - - // check thread slice lengths vs vector lengths - for(int i = 0; i < 5; ++i) - { - if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0 && - thread_slice_lengths[i] % dst_vector_lengths[i] == 0)) - return false; - } - - // check Src vectorization: {GN11} is global mem vector dim - if(!(src_vector_lengths[0] == 1 && src_vector_lengths[1] == 1 && - src_vector_lengths[2] == 1 && src_vector_lengths[4] == 1)) - return false; - - // check Src tensor layout related vectorization - if(Y == 1 && X == 1 && conv_problem_desc.ConvStrideH == 1 && - conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadH == 0 && - conv_problem_desc.InLeftPadW == 0 && conv_problem_desc.InRightPadH == 0 && - conv_problem_desc.InRightPadW == 0) - { - if(!((Ho * Wo) % src_vector_lengths[3] == 0)) - return false; - } - else if(conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadW == 0 && - conv_problem_desc.InRightPadW == 0) - { - if(!(Wo % src_vector_lengths[3] == 0)) - return false; - } - else - { - if(!(src_vector_lengths[3] == 1)) - return false; - } - - // check Dst vectorization: {GN11, GK1} are LDS vector dims - if(dst_vector_lengths[4] == GK1) - { // vectorize on {GN11, GK1} - if(!(GN11 % dst_vector_lengths[3] == 0)) - return false; - } - else - { // vectorize on {GK1} only - if(!(dst_vector_lengths[3] == 1)) - return false; - - if(!(GK1 % dst_vector_lengths[4] == 0)) - return false; - } - } - - // check blockwise GEMM - { - const int BM10 = std::accumulate(compile_param.BM10BN10ThreadClusterBM10Xs.begin(), - compile_param.BM10BN10ThreadClusterBM10Xs.end(), - 1, - std::multiplies{}); - - const int BN10 = std::accumulate(compile_param.BM10BN10ThreadClusterBN10Xs.begin(), - compile_param.BM10BN10ThreadClusterBN10Xs.end(), - 1, - std::multiplies{}); - - if(!(compile_param.BlockSize == BM10 * BN10)) - return false; - - const int BM = GM0 * GM11; - const int BN = GN0 * GN11; - - const int BM1 = BM10 * BM11; - const int BN1 = BN10 * BN11; - - if(!(BM % BM1 == 0 && BN % BN1 == 0)) - return false; - - const int BM0 = BM / BM1; - const int BN0 = BN / BN1; - - // blockwise GEMM currently only support BM0 == 2 && BN0 == 2 - if(!(BM0 == 2 && BN0 == 2)) - return false; - - if(!(compile_param.GK0PerBlock % compile_param.BK0PerThread == 0)) - return false; - } - - // check C threadwise copy - { - // {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim - const int dst_vector_len_gn11 = compile_param.CThreadTransferDstScalarPerVector; - - // check slice length vs Dst vector length: - if(!(BN11 % dst_vector_len_gn11 == 0 && GN11 % dst_vector_len_gn11 == 0)) - return false; - - // check Dst memory layout related vectorization: - if(!((Ho * Wo) % compile_param.CThreadTransferDstScalarPerVector == 0)) - return false; - } - - return true; - }; - - static int GetBlockSize(const ConvolutionProblemDescriptor&, - const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) - { - return compile_param.BlockSize; - } - - static int GetGridSize(const ConvolutionProblemDescriptor& conv_problem_desc, - const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) - { - const int N = conv_problem_desc.N; - const int K = conv_problem_desc.K; - const int Ho = conv_problem_desc.Ho; - const int Wo = conv_problem_desc.Wo; - - const int N0 = compile_param.GN0; - const int N1 = N / N0; - - const int GM1 = K; - const int GN1 = N1 * Ho * Wo; - - const int GM11 = compile_param.GM1PerBlockGM11; - const int GN11 = compile_param.GN1PerBlockGN11; - - const int GM10 = GM1 / GM11; - const int GN10 = GN1 / GN11; - - return GM10 * GN10; - } - - static std::size_t GetWorkSpaceSize(const ConvolutionProblemDescriptor&, - const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw&) - { - // workspace is used for save transformed tensor descritpors created by prepare kernel - return 4096L; - } - - static std::size_t GetMaxWorkSpaceSize(const ConvolutionProblemDescriptor&) { return 4096L; } - - static auto GetTunableList() - { - return generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw(); - } -}; - -} // namespace driver -} // namespace ck -#endif diff --git a/host/solver/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp b/host/solver/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 58fe588ad98..00000000000 --- a/host/solver/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,51 +0,0 @@ -#ifndef CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP -#define CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP - -struct tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw -{ - int BlockSize; - - int MPerBlock; - int NPerBlock; - int KPerBlock; - - int M1PerThread; - int N1PerThread; - int KPerThread; - - int M1N1ThreadClusterM10; - int M1N1ThreadClusterN10; - int M1N1ThreadClusterM11; - int M1N1ThreadClusterN11; - - std::array ABlockTransferThreadSliceLengths_K_M0_M1; - std::array ABlockTransferThreadClusterLengths_K_M0_M1; - std::array ABlockTransferThreadClusterArrangeOrder; - std::array ABlockTransferSrcAccessOrder; - int ABlockTransferSrcVectorDim; - int ABlockTransferSrcScalarPerVector; - int ABlockTransferDstScalarPerVector_M1; - bool AThreadTransferSrcResetCoordinateAfterRun; - - std::array BBlockTransferThreadSliceLengths_K_N0_N1; - std::array BBlockTransferThreadClusterLengths_K_N0_N1; - std::array BBlockTransferThreadClusterArrangeOrder; - std::array BBlockTransferSrcAccessOrder; - int BBlockTransferSrcVectorDim; - int BBlockTransferSrcScalarPerVector; - int BBlockTransferDstScalarPerVector_N1; - bool BThreadTransferSrcResetCoordinateAfterRun; - - std::array CThreadTransferSrcDstAccessOrder; - int CThreadTransferSrcDstVectorDim; - int CThreadTransferDstScalarPerVector; -}; - -static tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw - default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw = { - 256, 128, 128, 8, 4, 4, 1, - 8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0}, - {2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128}, - {0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2}, - 5, 1}; -#endif diff --git a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 361f6e4a26e..00000000000 --- a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,73 +0,0 @@ -#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP -#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP - -struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw -{ - int BlockSize; - - int MPerBlock; - int NPerBlock; - int KPerBlock; - - int MPerXDL; - int NPerXDL; - int K1; - - int MRepeat; - int NRepeat; - - std::array ABlockTransferThreadSliceLengths_K0_M_K1; - std::array ABlockTransferThreadClusterLengths_K0_M_K1; - std::array ABlockTransferThreadClusterArrangeOrder; - std::array ABlockTransferSrcAccessOrder; - int ABlockTransferSrcVectorDim; - int ABlockTransferSrcScalarPerVector; - int ABlockTransferDstScalarPerVector_K1; - bool AThreadTransferSrcResetCoordinateAfterRun; - - std::array BBlockTransferThreadSliceLengths_K0_N_K1; - std::array BBlockTransferThreadClusterLengths_K0_N_K1; - std::array BBlockTransferThreadClusterArrangeOrder; - std::array BBlockTransferSrcAccessOrder; - int BBlockTransferSrcVectorDim; - int BBlockTransferSrcScalarPerVector; - int BBlockTransferDstScalarPerVector_K1; - bool BThreadTransferSrcResetCoordinateAfterRun; - - std::array CThreadTransferSrcDstAccessOrder; - int CThreadTransferSrcDstVectorDim; - int CThreadTransferDstScalarPerVector; -}; - -static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw - default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = { - 256, // BlockSize - 128, // MPerBlock, - 128, // NPerBlock, - 4, // KPerBlock, - 32, // MPerXDL, - 32, // NPerXDL, - 4, // K1, - 2, // MRepeat, - 2, // NRepeat, - {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1, - {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1, - {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, - {1, 0, 2}, // ABlockTransferSrcAccessOrder, - 2, // ABlockTransferSrcVectorDim - 1, // ABlockTransferSrcScalarPerVector, - 4, // ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1, - {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, - {0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder, - {1, 0, 2}, // BBlockTransferSrcAccessOrder, - 1, // BBlockTransferSrcVectorDim - 1, // BBlockTransferSrcScalarPerVector - 4, // BBlockTransferDstScalarPerVector_K1 - false, // BThreadTransferSrcResetCoordinateAfterRun - {3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder - 7, // CThreadTransferSrcDstVectorDim, - 1 // CThreadTransferDstScalarPerVector -}; -#endif diff --git a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 263c21a13b8..00000000000 --- a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,73 +0,0 @@ -#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP -#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP - -struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk -{ - int BlockSize; - - int MPerBlock; - int NPerBlock; - int KPerBlock; - - int MPerWave; - int NPerWave; - int K1; - - int MRepeat; - int NRepeat; - - std::array ABlockTransferThreadSliceLengths_K0_M_K1; - std::array ABlockTransferThreadClusterLengths_K0_M_K1; - std::array ABlockTransferThreadClusterArrangeOrder; - std::array ABlockTransferSrcAccessOrder; - int ABlockTransferSrcVectorDim; - int ABlockTransferSrcScalarPerVector; - int ABlockTransferDstScalarPerVector_K1; - bool AThreadTransferSrcResetCoordinateAfterRun; - - std::array BBlockTransferThreadSliceLengths_K0_N_K1; - std::array BBlockTransferThreadClusterLengths_K0_N_K1; - std::array BBlockTransferThreadClusterArrangeOrder; - std::array BBlockTransferSrcAccessOrder; - int BBlockTransferSrcVectorDim; - int BBlockTransferSrcScalarPerVector; - int BBlockTransferDstScalarPerVector_K1; - bool BThreadTransferSrcResetCoordinateAfterRun; - - std::array CThreadTransferSrcDstAccessOrder; - int CThreadTransferSrcDstVectorDim; - int CThreadTransferDstScalarPerVector; -}; - -static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk - default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = { - 256, // BlockSize - 128, // MPerBlock, - 128, // NPerBlock, - 4, // KPerBlock, - 32, // MPerWave, - 32, // NPerWave, - 4, // K1, - 2, // MRepeat, - 2, // NRepeat, - {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1, - {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1, - {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, - {1, 0, 2}, // ABlockTransferSrcAccessOrder, - 2, // ABlockTransferSrcVectorDim - 4, // ABlockTransferSrcScalarPerVector, - 4, // ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1, - {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, - {1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder, - {1, 0, 2}, // BBlockTransferSrcAccessOrder, - 2, // BBlockTransferSrcVectorDim - 4, // BBlockTransferSrcScalarPerVector - 4, // BBlockTransferDstScalarPerVector_K1 - false, // BThreadTransferSrcResetCoordinateAfterRun - {2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder - 7, // CThreadTransferSrcDstVectorDim, - 1 // CThreadTransferDstScalarPerVector -}; -#endif diff --git a/host/solver/include/convolution_problem_descriptor.hpp b/host/solver/include/convolution_problem_descriptor.hpp deleted file mode 100644 index 8c0ecbee80b..00000000000 --- a/host/solver/include/convolution_problem_descriptor.hpp +++ /dev/null @@ -1,81 +0,0 @@ -#ifndef CONVOLUTION_PROBLEM_DESCRIPTOR -#define CONVOLUTION_PROBLEM_DESCRIPTOR - -namespace ck { -namespace driver { - -struct ConvolutionProblemDescriptor -{ - ConvolutionProblemDescriptor() = default; - - ConvolutionProblemDescriptor(int N_, - int K_, - int C_, - int Y_, - int X_, - int Hi_, - int Wi_, - int Ho_, - int Wo_, - int ConvStrideH_, - int ConvStrideW_, - int ConvDilationH_, - int ConvDilationW_, - int InLeftPadH_, - int InLeftPadW_, - int InRightPadH_, - int InRightPadW_, - ck::DataTypeEnum_t InDataTypeEnum_, - ck::DataTypeEnum_t WeiDataTypeEnum_, - ck::DataTypeEnum_t OutDataTypeEnum_) - : N{N_}, - K{K_}, - C{C_}, - Y{Y_}, - X{X_}, - Hi{Hi_}, - Wi{Wi_}, - Ho{Ho_}, - Wo{Wo_}, - ConvStrideH{ConvStrideH_}, - ConvStrideW{ConvStrideW_}, - ConvDilationH{ConvDilationH_}, - ConvDilationW{ConvDilationW_}, - InLeftPadH{InLeftPadH_}, - InLeftPadW{InLeftPadW_}, - InRightPadH{InRightPadH_}, - InRightPadW{InRightPadW_}, - InDataTypeEnum{InDataTypeEnum_}, - WeiDataTypeEnum{WeiDataTypeEnum_}, - OutDataTypeEnum{OutDataTypeEnum_} - { - } - - int N; - int K; - int C; - int Y; - int X; - int Hi; - int Wi; - int Ho; - int Wo; - int ConvStrideH; - int ConvStrideW; - int ConvDilationH; - int ConvDilationW; - int InLeftPadH; - int InLeftPadW; - int InRightPadH; - int InRightPadW; - - ck::DataTypeEnum_t InDataTypeEnum; - ck::DataTypeEnum_t WeiDataTypeEnum; - ck::DataTypeEnum_t OutDataTypeEnum; - - std::size_t CalculateFlop() const { return 2L * N * K * C * Y * X * Ho * Wo; } -}; - -} // namespace driver -} // namespace ck -#endif diff --git a/host/solver/include/solver_common.hpp b/host/solver/include/solver_common.hpp deleted file mode 100644 index d1792f7681a..00000000000 --- a/host/solver/include/solver_common.hpp +++ /dev/null @@ -1,46 +0,0 @@ -#ifndef CK_SOLVER_COMMON_HPP -#define CK_SOLVER_COMMON_HPP - -namespace ck { -namespace driver { - -// greatest common divisor, aka highest common factor -inline int gcd(int x, int y) -{ - if(x < 0) - { - return gcd(-x, y); - } - else if(y < 0) - { - return gcd(x, -y); - } - else if(x == y || x == 0) - { - return y; - } - else if(y == 0) - { - return x; - } - else if(x > y) - { - return gcd(x % y, y); - } - else - { - return gcd(x, y % x); - } -} - -template = 2, bool>::type = false> -auto gcd(X x, Ys... ys) -{ - return gcd(x, gcd(ys...)); -} - -} // namespace driver -} // namespace ck -#endif diff --git a/include/ck/config.hpp b/include/ck/config.hpp new file mode 100644 index 00000000000..66996404241 --- /dev/null +++ b/include/ck/config.hpp @@ -0,0 +1,182 @@ +#ifndef CK_CONFIG_AMD_HPP +#define CK_CONFIG_AMD_HPP + +#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS +#include "hip/hip_runtime.h" +#include "hip/hip_fp16.h" +#endif + +// constant address space for kernel parameter +// https://llvm.org/docs/AMDGPUUsage.html#address-spaces +#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4))) + +// launch bounds +#define CK_USE_LAUNCH_BOUNDS 1 + +#ifdef CK_USE_LAUNCH_BOUNDS +#define CK_MAX_THREAD_PER_BLOCK 256 +#define CK_MIN_BLOCK_PER_CU 2 +#endif + +// check GPU target +#ifdef __HIP_DEVICE_COMPILE__ +#if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx1030__)) +#error Not supported target +#endif +#endif + +// buffer resource +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_BUFFER_RESOURCE_3RD_DWORD -1 +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) // for GPU code +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 +#elif defined(__gfx1030__) // for GPU code +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +#endif + +// FMA instruction +#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing +#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code +#define CK_USE_AMD_V_MAC_F32 +#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx1030__) // for GPU code +#define CK_USE_AMD_V_FMAC_F32 +#define CK_USE_AMD_V_DOT2_F32_F16 +#define CK_USE_AMD_V_DOT4_I32_I8 +#endif + +// MFMA instruction +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_USE_AMD_MFMA +#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code +#define CK_USE_AMD_MFMA +#endif + +#if defined(__gfx90a__) +#define CK_USE_AMD_MFMA_BF16_1K_OP +#endif + +// buffer load +#define CK_USE_AMD_BUFFER_LOAD 1 + +// buffer store +#define CK_USE_AMD_BUFFER_STORE 1 + +// buffer atomic add: integer +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1 + +// buffer atomic add: floating point +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 +#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 +#else // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 +#endif + +#if defined(__gfx90a__) // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 +#else +#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 +#endif + +// inline asm +#define CK_USE_AMD_INLINE_ASM 1 + +// inner product (DLOP) +#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1 + +// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) +#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 + +// experimental feature: multi index implemented as array +#define CK_EXPERIMENTAL_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0 + +// experimental feature: static tensor descriptor +#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0 + +// experimental feature: buffer load/store/atomic-add/ OOB trick +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 +#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1 + +// experimental feature: in-regsiter sub-dword transpose +#define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1 + +// experimental feature: merge transformation use magic number division +#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1 + +// experimental feature: use __builtin_memcpy instead of pointer cast to access a vector from +// pointer of scalar +#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0 + +// experimental feature: use __builtin_memcpy instead of union to do bit_cast +#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1 + +// experimental feature: optimize for inter-wave scheduling policy +#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING 0 +#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS 1 + +// hack: have underlying assumption that need to be satsified, otherwise it's a bug +// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be +// thread-invariant, otherwise it's a bug +// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" +#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 + +// workaround: compiler crash when compiling recursive lambda +#define CK_WORKAROUND_SWDEV_275126 1 + +// workaround: compiler crash when using buffer load/store for i8 +#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1 + +// workaround: compiler gnerating inefficient ds_write instructions +#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 + +// workaround: verifaction failure, due to compiler regression, for conv bwd-data fp16 using some +// tuning parameter +#define CK_WORKAROUND_SWDEV_325164 1 + +// workaround for verification failure ConvNd forward +// https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/135 +#define CK_WORKAROUND_GITHUB_135 1 + +namespace ck { + +enum struct InMemoryDataOperationEnum +{ + Set, + AtomicAdd, + AtomicMax, + Add +}; + +template +struct InMemoryDataOperationEnumSequence +{ + static constexpr int mSize = sizeof...(Is); + + __host__ __device__ static constexpr InMemoryDataOperationEnum At(int I) + { + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + const InMemoryDataOperationEnum mData[mSize + 1] = {Is..., InMemoryDataOperationEnum::Set}; + return mData[I]; + } +}; + +// TODO: no longer needed, remove this +enum struct ActivTypeEnum +{ + None, + LeakyRelu, + Sigmoid +}; + +// index type +using index_t = int32_t; +using long_index_t = int64_t; + +} // namespace ck +#endif diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp new file mode 100644 index 00000000000..74b20acecd3 --- /dev/null +++ b/include/ck/host_utility/device_prop.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include +#include + +namespace ck { + +inline std::string get_device_name() +{ + hipDeviceProp_t props{}; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return std::string(); + } + + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return std::string(); + } + const std::string raw_name(props.gcnArchName); + + // https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 + static std::map device_name_map = { + {"Ellesmere", "gfx803"}, + {"Baffin", "gfx803"}, + {"RacerX", "gfx803"}, + {"Polaris10", "gfx803"}, + {"Polaris11", "gfx803"}, + {"Tonga", "gfx803"}, + {"Fiji", "gfx803"}, + {"gfx800", "gfx803"}, + {"gfx802", "gfx803"}, + {"gfx804", "gfx803"}, + {"Vega10", "gfx900"}, + {"gfx901", "gfx900"}, + {"10.3.0 Sienna_Cichlid 18", "gfx1030"}, + }; + + const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. + + auto match = device_name_map.find(name); + if(match != device_name_map.end()) + return match->second; + return name; +} + +} // namespace ck diff --git a/include/ck/options.hpp b/include/ck/options.hpp new file mode 100644 index 00000000000..82c604f82ba --- /dev/null +++ b/include/ck/options.hpp @@ -0,0 +1,3 @@ +#pragma once + +#define CK_TIME_KERNEL 1 diff --git a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp similarity index 75% rename from composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp index 09ea16fa239..af682ecfa7e 100644 --- a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp @@ -7,9 +7,9 @@ namespace ck { -// Number of GEMMs = YTilda * XTilda +// Number of GEMMs = YTilde * XTilde // GemmM = C -// GemmN = N * HTildaSlice * WTildaSlice +// GemmN = N * HTildeSlice * WTildeSlice // GemmK = K * YDotSlice * XDotSlice template __host__ __device__ constexpr auto transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( @@ -30,8 +30,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, const InRightPads& in_right_pads, - Number, - Number, + Number, + Number, Number) { constexpr auto I0 = Number<0>{}; @@ -40,8 +40,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( constexpr auto I3 = Number<3>{}; constexpr auto GemmK1 = Number{}; - constexpr auto IYTilda = Number{}; - constexpr auto IXTilda = Number{}; + constexpr auto IYTilde = Number{}; + constexpr auto IXTilde = Number{}; const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); @@ -71,55 +71,55 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - const auto YTilda = ConvStrideH / GcdStrideDilationH; - const auto XTilda = ConvStrideW / GcdStrideDilationW; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; - const auto YDot = math::integer_divide_ceil(Y, YTilda); - const auto XDot = math::integer_divide_ceil(X, XTilda); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); - const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); - const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - // only work on HTilda and WTilda that contribute to non-padding area of input tensor - const auto IHTildaSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH); - const auto IWTildaSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); - const auto IHTildaSliceEnd = - math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); - const auto IWTildaSliceEnd = - math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + const auto IHTildeSliceEnd = + math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = + math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); - const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; - const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; // GemmK is different for each GEMM - const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda); - const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda); + const auto YDotSlice = math::integer_divide_ceil(Y - IYTilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - IXTilde, XTilde); const auto K1 = GemmK1; const auto K0 = K / K1; // weight tensor - const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor( + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( wei_k_y_x_c_grid_desc, make_tuple(make_pass_through_transform(K), - make_embed_transform(make_tuple(YDot, YTilda), + make_embed_transform(make_tuple(YDot, YTilde), make_tuple(ConvStrideH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, XTilda), + make_embed_transform(make_tuple(XDot, XTilde), make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(IYTilda), - make_freeze_transform(IXTilda), + make_freeze_transform(IYTilde), + make_freeze_transform(IXTilde), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -163,25 +163,25 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor( + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( out_n_hop_wop_k_grid_desc, make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YDot, HTilda), + make_embed_transform(make_tuple(YDot, HTilde), make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, WTilda), + make_embed_transform(make_tuple(XDot, WTilde), make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor( - out_n_ydot_htilda_xdot_wtilda_k_grid_desc, + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, make_tuple(make_pass_through_transform(N), make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(XDot, I0, XDotSlice), - make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_unmerge_transform(make_tuple(K0, K1))), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -198,17 +198,17 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( #if 1 const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_pass_through_transform(K1)), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); #else const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), - make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_pass_through_transform(K1)), make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); @@ -224,24 +224,24 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor( + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc, make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YTilda, HTilda), + make_embed_transform(make_tuple(YTilde, HTilde), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(XTilda, WTilda), + make_embed_transform(make_tuple(XTilde, WTilde), make_tuple(ConvDilationW, ConvStrideW)), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, make_tuple(make_pass_through_transform(N), - make_freeze_transform(IYTilda), - make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), - make_freeze_transform(IXTilda), - make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_freeze_transform(IYTilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IXTilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -257,9 +257,9 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( Sequence<3>{})); const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - in_n_htildaslice_wtildaslice_c_grid_desc, + in_n_htildeslice_wtildeslice_c_grid_desc, make_tuple(make_pass_through_transform(C), - make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))), make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); diff --git a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp similarity index 81% rename from composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp index fa78d769653..6693c0756b9 100644 --- a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp @@ -10,8 +10,8 @@ namespace ck { // A: out // B: wei // C: in -// Number of GEMMs = YTilda * XTilda -// GemmM = N * HTildaSlice * WTildaSlice +// Number of GEMMs = YTilde * XTilde +// GemmM = N * HTildeSlice * WTildeSlice // GemmN = C // GemmK = K * YDotSlice * XDotSlice template __host__ __device__ constexpr auto transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( @@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, const InRightPads& in_right_pads, - IYTilda i_ytilda, - IXTilda i_xtilda, + IYTilde i_ytilde, + IXTilde i_xtilde, Number) { constexpr auto I0 = Number<0>{}; @@ -72,32 +72,32 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - const auto YTilda = ConvStrideH / GcdStrideDilationH; - const auto XTilda = ConvStrideW / GcdStrideDilationW; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; - const auto YDot = math::integer_divide_ceil(Y, YTilda); - const auto XDot = math::integer_divide_ceil(X, XTilda); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); - const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); - const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - // only work on HTilda and WTilda that contribute to non-padding area of input tensor - const auto IHTildaSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH); - const auto IWTildaSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); - const auto IHTildaSliceEnd = - math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); - const auto IWTildaSliceEnd = - math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + const auto IHTildeSliceEnd = + math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = + math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); - const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; - const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; // GemmK is different for each GEMM - const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda); - const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto K1 = GemmK1; const auto K0 = K / K1; @@ -113,25 +113,25 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor( + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( out_n_hop_wop_k_grid_desc, make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YDot, HTilda), + make_embed_transform(make_tuple(YDot, HTilde), make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, WTilda), + make_embed_transform(make_tuple(XDot, WTilde), make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor( - out_n_ydot_htilda_xdot_wtilda_k_grid_desc, + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, make_tuple(make_pass_through_transform(N), make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(XDot, I0, XDotSlice), - make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_unmerge_transform(make_tuple(K0, K1))), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -148,41 +148,41 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( #if 1 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_pass_through_transform(K1)), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); #else const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), - make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_pass_through_transform(K1)), make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); #endif // B: weight tensor - const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor( + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( wei_k_y_x_c_grid_desc, make_tuple(make_pass_through_transform(K), - make_embed_transform(make_tuple(YDot, YTilda), + make_embed_transform(make_tuple(YDot, YTilde), make_tuple(ConvStrideH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, XTilda), + make_embed_transform(make_tuple(XDot, XTilde), make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(i_ytilda), - make_freeze_transform(i_xtilda), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -225,24 +225,24 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor( + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc, make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YTilda, HTilda), + make_embed_transform(make_tuple(YTilde, HTilde), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(XTilda, WTilda), + make_embed_transform(make_tuple(XTilde, WTilde), make_tuple(ConvDilationW, ConvStrideW)), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, make_tuple(make_pass_through_transform(N), - make_freeze_transform(i_ytilda), - make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), - make_freeze_transform(i_xtilda), - make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -258,8 +258,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( Sequence<3>{})); const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - in_n_htildaslice_wtildaslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_pass_through_transform(C)), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp similarity index 100% rename from composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp rename to include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp similarity index 100% rename from composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp rename to include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp similarity index 100% rename from composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp similarity index 100% rename from composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp similarity index 100% rename from composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp diff --git a/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp b/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp new file mode 100644 index 00000000000..7544289b218 --- /dev/null +++ b/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,150 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Do * Ho * Wo +// GemmN = K +// GemmK = Z * Y * X * C +template +__host__ __device__ constexpr auto +transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad( + const TensorDescriptor& in_grid_desc_n_di_hi_wi_c, + const TensorDescriptor& wei_k_z_y_x_c_grid_desc, + const TensorDescriptor& out_n_do_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_grid_desc_n_di_hi_wi_c.GetLength(I0); + const auto K = out_n_do_ho_wo_k_grid_desc.GetLength(I4); + const auto C = in_grid_desc_n_di_hi_wi_c.GetLength(I4); + + const auto Di = in_grid_desc_n_di_hi_wi_c.GetLength(I1); + const auto Hi = in_grid_desc_n_di_hi_wi_c.GetLength(I2); + const auto Wi = in_grid_desc_n_di_hi_wi_c.GetLength(I3); + + const auto Do = out_n_do_ho_wo_k_grid_desc.GetLength(I1); + const auto Ho = out_n_do_ho_wo_k_grid_desc.GetLength(I2); + const auto Wo = out_n_do_ho_wo_k_grid_desc.GetLength(I3); + + const auto Z = wei_k_z_y_x_c_grid_desc.GetLength(I1); + const auto Y = wei_k_z_y_x_c_grid_desc.GetLength(I2); + const auto X = wei_k_z_y_x_c_grid_desc.GetLength(I3); + + const auto ConvStrideD = conv_strides[I0]; + const auto ConvStrideH = conv_strides[I1]; + const auto ConvStrideW = conv_strides[I2]; + + const auto ConvDilationD = conv_dilations[I0]; + const auto ConvDilationH = conv_dilations[I1]; + const auto ConvDilationW = conv_dilations[I2]; + + const auto InLeftPadD = in_left_pads[I0]; + const auto InLeftPadH = in_left_pads[I1]; + const auto InLeftPadW = in_left_pads[I2]; + + const auto InRightPadD = in_right_pads[I0]; + const auto InRightPadH = in_right_pads[I1]; + const auto InRightPadW = in_right_pads[I2]; + + const auto GemmM = N * Do * Ho * Wo; + const auto GemmN = K; + const auto GemmK = Z * Y * X * C; + const auto GemmK0 = GemmK / GemmK1; + + // A: input tensor + const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor( + in_grid_desc_n_di_hi_wi_c, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor( + in_grid_desc_n_dip_hip_wip_c, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{})); + + const auto in_grid_desc_gemmk_gemmm = + transform_tensor_descriptor(in_grid_desc_n_z_do_y_ho_x_wo_c, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_grid_desc_gemmk0_gemmm_gemmk1 = + transform_tensor_descriptor(in_grid_desc_gemmk_gemmm, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_grid_desc_gemmk_gemmn = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Z * Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_grid_desc_gemmk0_gemmn_gemmk1 = + transform_tensor_descriptor(wei_grid_desc_gemmk_gemmn, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_grid_desc_gemmm_gemmn = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Do * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto out_grid_desc_gemmm_gemmn = transform_tensor_descriptor( + // out_n_do_ho_wo_k_grid_desc, + // make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + // make_pass_through_transform(K)), + // make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<3>{}), + // make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_grid_desc_gemmk0_gemmm_gemmk1, + wei_grid_desc_gemmk0_gemmn_gemmk1, + out_grid_desc_gemmm_gemmn); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp similarity index 100% rename from composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp rename to include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp similarity index 100% rename from composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp similarity index 100% rename from composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp rename to include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp similarity index 100% rename from composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp similarity index 98% rename from composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp index b0b07505e5e..ac90e8a6ffa 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -21,8 +21,7 @@ template -__host__ __device__ constexpr auto -transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk( const TensorDescriptor& in_n_hi_wi_c_grid_desc, const TensorDescriptor& wei_k_y_x_c_grid_desc, const TensorDescriptor& out_n_ho_wo_k_grid_desc, diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp similarity index 100% rename from composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp rename to include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp new file mode 100644 index 00000000000..3e80b4c8920 --- /dev/null +++ b/include/ck/stream_config.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include +#include + +struct StreamConfig +{ + hipStream_t stream_id_ = nullptr; + bool time_kernel_ = false; +}; diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp new file mode 100644 index 00000000000..2ca920df9d4 --- /dev/null +++ b/include/ck/tensor/static_tensor.hpp @@ -0,0 +1,270 @@ +#ifndef CK_STATIC_TENSOR_HPP +#define CK_STATIC_TENSOR_HPP + +namespace ck { + +// StaticTensor for Scalar +template ::type = false> +struct StaticTensor +{ + static constexpr auto desc_ = TensorDesc{}; + static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension(); + static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize(); + + __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {} + + __host__ __device__ constexpr StaticTensor(T invalid_element_value) + : invalid_element_scalar_value_{invalid_element_value} + { + } + + // read access + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr const T& operator[](Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_[Number{}]; + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return zero_scalar_value_; + } + else + { + return invalid_element_scalar_value_; + } + } + } + + // write access + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr T& operator()(Idx) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_(Number{}); + } + else + { + return ignored_element_scalar_; + } + } + + StaticBuffer data_; + static constexpr T zero_scalar_value_ = T{0}; + const T invalid_element_scalar_value_; + T ignored_element_scalar_; +}; + +// StaticTensor for vector +template ::type = false> +struct StaticTensorTupleOfVectorBuffer +{ + static constexpr auto desc_ = TensorDesc{}; + static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension(); + static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize(); + + static constexpr index_t num_of_vector_ = + math::integer_divide_ceil(element_space_size_, ScalarPerVector); + + using V = vector_type; + + __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() + : invalid_element_scalar_value_{0} + { + } + + __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value) + : invalid_element_scalar_value_{invalid_element_value} + { + } + + // Get S + // Idx is for S, not V + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr const S& operator[](Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_[Number{}]; + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return zero_scalar_value_; + } + else + { + return invalid_element_scalar_value_; + } + } + } + + // Set S + // Idx is for S, not V + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr S& operator()(Idx) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_(Number{}); + } + else + { + return ignored_element_scalar_; + } + } + + // Get X + // Idx is for S, not X. Idx should be aligned with X + template ::value && + is_known_at_compile_time::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr X GetAsType(Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_.template GetAsType(Number{}); + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + // TODO: is this right way to initialize a vector? + return X{0}; + } + else + { + // TODO: is this right way to initialize a vector? + return X{invalid_element_scalar_value_}; + } + } + } + + // Set X + // Idx is for S, not X. Idx should be aligned with X + template ::value && + is_known_at_compile_time::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr void SetAsType(Idx, X x) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + data_.template SetAsType(Number{}, x); + } + } + + // Get read access to V. No is_valid check + // Idx is for S, not V. Idx should be aligned with V + template + __host__ __device__ constexpr const V& GetVectorTypeReference(Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + return data_.GetVectorTypeReference(Number{}); + } + + // Get read access to V. No is_valid check + // Idx is for S, not V. Idx should be aligned with V + template + __host__ __device__ constexpr V& GetVectorTypeReference(Idx) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + return data_.GetVectorTypeReference(Number{}); + } + + StaticBufferTupleOfVector data_; + static constexpr S zero_scalar_value_ = S{0}; + const S invalid_element_scalar_value_ = S{0}; + S ignored_element_scalar_; +}; + +template ::type = false> +__host__ __device__ constexpr auto make_static_tensor(TensorDesc) +{ + return StaticTensor{}; +} + +template < + AddressSpaceEnum AddressSpace, + typename T, + typename TensorDesc, + typename X, + typename enable_if::type = false, + typename enable_if, remove_cvref_t>::value, bool>::type = false> +__host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_element_value) +{ + return StaticTensor{invalid_element_value}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/cluster_descriptor.hpp b/include/ck/tensor_description/cluster_descriptor.hpp similarity index 100% rename from composable_kernel/include/tensor_description/cluster_descriptor.hpp rename to include/ck/tensor_description/cluster_descriptor.hpp diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp similarity index 95% rename from composable_kernel/include/tensor_description/multi_index_transform.hpp rename to include/ck/tensor_description/multi_index_transform.hpp index 1a25e99f3bb..fa705cc3fee 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -30,7 +30,8 @@ struct PassThrough __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } template - __host__ __device__ static void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) + __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) { static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, "wrong! inconsistent # of dimension"); @@ -1708,7 +1709,8 @@ struct Vectorize __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } template - __host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const { static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, "wrong! inconsistent # of dimension"); @@ -1860,5 +1862,92 @@ struct Slice } }; +/* + * \brief lower_idx = upper_idx % modulus. + * TODO: Need an improved implementation since the modulo operation is expensive. + */ +template +struct Modulo +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + using UpLengths = decltype(make_tuple(UpLength{})); + + Modulus modulus_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Modulo() = default; + + __host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length) + : modulus_{modulus}, up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& up_idx, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + const auto idx_low_old = idx_low; + idx_low(I0) = (up_idx(I0) + idx_diff_up(I0)) % modulus_; + idx_diff_low(I0) = idx_low - idx_low_old; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Modulus, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; } // namespace ck #endif diff --git a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp similarity index 90% rename from composable_kernel/include/tensor_description/multi_index_transform_helper.hpp rename to include/ck/tensor_description/multi_index_transform_helper.hpp index 9a737991735..bc360714b99 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp +++ b/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -98,6 +98,12 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i return Freeze{low_idx}; } +template +__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx) +{ + return Insert{up_idx}; +} + template __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length, const SliceBegin& slice_begin, @@ -113,5 +119,11 @@ __host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& ve return Vectorize{vector_size, up_length}; } +template +__host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus, + const UpLength& up_length) +{ + return Modulo{modulus, up_length}; +} } // namespace ck #endif diff --git a/composable_kernel/include/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp similarity index 98% rename from composable_kernel/include/tensor_description/tensor_adaptor.hpp rename to include/ck/tensor_description/tensor_adaptor.hpp index 50a8088bbab..8787abd6ba6 100644 --- a/composable_kernel/include/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -151,6 +151,20 @@ struct TensorAdaptor __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } +#if 0 // debug + template + __host__ __device__ constexpr index_t GetTopDimensionLength(Number idim) const + { + // TODO: not implemented + } + + template + __host__ __device__ constexpr index_t GetBottomDimensionLength(Number idim) const + { + // TODO: not implemented + } +#endif + template __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const { diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp similarity index 98% rename from composable_kernel/include/tensor_description/tensor_descriptor.hpp rename to include/ck/tensor_description/tensor_descriptor.hpp index 8f6a5a3e43c..9cd51c61d66 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -307,6 +307,10 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, { // sanity check { + static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() && + NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(), + "wrong! inconsitent number of transform"); + constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, NewLowerDimensionOldVisibleIdss{}); diff --git a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp b/include/ck/tensor_description/tensor_descriptor_helper.hpp similarity index 88% rename from composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp rename to include/ck/tensor_description/tensor_descriptor_helper.hpp index ad75f9245ee..ddc0ede404d 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp +++ b/include/ck/tensor_description/tensor_descriptor_helper.hpp @@ -1,6 +1,4 @@ -#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP -#define CK_TENSOR_DESCRIPTOR_HELPER_HPP - +#pragma once #include "common_header.hpp" #include "tensor_descriptor.hpp" #include "multi_index_transform_helper.hpp" @@ -35,6 +33,12 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt } #endif +// Lengths..., Strides... could be: +// 1) index_t, which is known at run-time, or +// 2) Number<>, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) LongNumber<> template ::type = false> @@ -68,10 +72,10 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple{}, Number<1>{}); + const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{}); #else const auto element_space_size = - calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{}); + calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{}); #endif return TensorDescriptor, @@ -82,9 +86,12 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) LongNumber<> template __host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple& lengths) @@ -100,7 +107,7 @@ make_naive_tensor_descriptor_packed(const Tuple& lengths) constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; - const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{}); + const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{}); return TensorDescriptor, remove_cv_t, @@ -110,6 +117,12 @@ make_naive_tensor_descriptor_packed(const Tuple& lengths) element_space_size}; } +// Lengths... could be: +// 1) index_t, which is known at run-time, or +// 2) Number<>, which is known at compile-time +// align could be: +// 1) index_t, or +// 2) Number<> template __host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple& lengths, Align align) @@ -146,4 +159,3 @@ make_naive_tensor_descriptor_aligned(const Tuple& lengths, Align ali } } // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp similarity index 96% rename from composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp rename to include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp index 26ca0bf1115..f7fa867e162 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp @@ -1,10 +1,8 @@ -#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP -#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP - +#pragma once #include "common_header.hpp" #include "tensor_adaptor.hpp" -#include "threadwise_tensor_slice_transfer_v2.hpp" -#include "threadwise_contraction_dlops.hpp" +#include "threadwise_tensor_slice_transfer_v4r1.hpp" +#include "threadwise_contraction_dl.hpp" namespace ck { @@ -41,7 +39,7 @@ template ::type = false> -struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 +struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 { using AIndex = MultiIndex<3>; using BIndex = MultiIndex<3>; @@ -148,7 +146,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); public: - __device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() + __device__ BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( get_thread_local_1d_id())}, a_thread_copy_{ @@ -175,6 +173,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B "wrong!"); // TODO: remove this restriction + static_assert(BM0 == 2, "wrong"); static_assert(BM0 == 2 && BN0 == 2, "wrong"); } @@ -220,13 +219,13 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0, "wrong"); - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); constexpr auto threadwise_contraction = - ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< + ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< FloatA, FloatB, FloatC, @@ -407,4 +406,3 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B }; } // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp similarity index 99% rename from composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp rename to include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp index 35ff66a2b0e..2a8a4bc8b88 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp @@ -207,9 +207,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0, "wrong"); - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_k_m0_m1_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_k_n0_n1_thread_desc_.GetElementSpaceSize()); constexpr auto threadwise_gemm = diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp new file mode 100644 index 00000000000..78cfc1e0fbf --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp @@ -0,0 +1,175 @@ +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP + +#include "common_header.hpp" +#include "threadwise_gemm_dlops_v3.hpp" + +namespace ck { + +template +struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto E1 = ABlockDesc_E1_K1_E2{}.GetLength(I0); + static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1); + static constexpr auto E2 = ABlockDesc_E1_K1_E2{}.GetLength(I2); + + static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2); + static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3); + + static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0); + static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2); + static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3); + + static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + static constexpr auto b_thread_mtx_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number<1>{}, + Number{}, + Number{}, + Number{})); + + static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() + : c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())}, + a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)} + { + static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() && + BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() && + CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert( + ABlockDesc_E1_K1_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) && + ABlockDesc_E1_K1_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4), + "wrong! E dimension not consistent\n"); + + static_assert(E1 % EPerThreadLoop == 0, ""); + static_assert(KPerThread % KPerThreadLoop == 0, ""); + + static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 && + WoPerBlock % WoPerThread == 0, + "wrong! Cannot evenly divide work among\n"); + + constexpr auto KThreadCluster = KPerBlock / KPerThread; + constexpr auto HThreadCluster = HoPerBlock / HoPerThread; + constexpr auto WThreadCluster = WoPerBlock / WoPerThread; + + static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster, + "wrong! wrong blocksize\n"); + } + + __device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths() + { + return Sequence{}; + } + + __device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id) + { + constexpr auto K0 = KPerBlock / KPerThread; + constexpr auto N0 = I1; + constexpr auto H0 = HoPerBlock / HoPerThread; + constexpr auto W0 = WoPerBlock / WoPerThread; + + constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto c_k_n_h_w_thread_cluster_idx = + c_threadid_to_k_n_h_w_thread_cluster_adaptor.CalculateBottomIndex( + make_multi_index(thread_id)); + + return c_k_n_h_w_thread_cluster_idx; + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BThreadBuffer& b_thread_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{}; + + // thread A buffer for GEMM + StaticBuffer + a_thread_buf; + + constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3{}; + + static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) { + static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) { + a_thread_copy_.Run(a_block_mtx, + make_tuple(e_begin, k_begin, I0), + a_block_buf, + a_thread_mtx_, + make_tuple(I0, I0, I0), + a_thread_buf); + + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(e_begin, I0, I0, I0, I0), + c_thread_buf, + make_tuple(k_begin, I0, I0, I0)); + }); + }); + } + + template + __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx) + { + a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx); + } + + private: + using AThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + E2, + E2>; + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp new file mode 100644 index 00000000000..a989cb5297a --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -0,0 +1,585 @@ +#pragma once +#include "common_header.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "xdlops_gemm.hpp" +#include "tensor_adaptor.hpp" +#include "thread_group.hpp" + +namespace ck { + +enum struct LoopScheduler +{ + Default, + Interwave, +}; + +constexpr LoopScheduler make_default_loop_scheduler() +{ +#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING + return LoopScheduler::Interwave; +#else + return LoopScheduler::Default; +#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING +} + +template +struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + +// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro +// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in +// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the +// default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0 +template +struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +{ + using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + +#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING + using Base::a_block_desc_m0_m1_m2_k; + using Base::A_K1; + using Base::b_block_desc_n0_n1_n2_k; + using Base::B_K1; + using Base::c_thread_buf_; + using Base::c_thread_desc_; + using Base::CalculateAThreadOriginDataIndex; + using Base::CalculateBThreadOriginDataIndex; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + + // 2-wave optimized blockwise gemm + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, k), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, k), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + __builtin_amdgcn_sched_barrier(); + // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, but except + // the first, as we can shorten non-MAC cluster a bit and there's no observable negative + // impact. The desired effect is waves in a workgroup executing MAC in sync. This avoids + // some out-of-sync waves hijacking MAC resource from other workgroups and reducing the + // chance of latency hiding by waiting for the rest of the workgroup at the eventual + // sync point. + if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) + { + asm volatile("s_barrier" ::); + __builtin_amdgcn_sched_barrier(); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from blockwise_gemm is + // moved here B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts It is performed near the end of MAC cluster to + // minimize lgkmcnt penalty + if constexpr(k.value == KPerThread - KPerInnerLoop && + k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 && + n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(); + } + + // TODO: insert setprio in more precise manner since we + // could have more than >1 MFMA instructions in single call + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(); + }); + } + + protected: + // A[M0, M1, M2, KPerInnerLoop] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{})); + + // B[N0, N1, N2, KPerInnerLoop] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; + +#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING +}; + +template +constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } + else if constexpr(LoopSched == LoopScheduler::Interwave) + { + return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } +}; + +} // namespace ck diff --git a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp similarity index 88% rename from composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp rename to include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp index 6b2d2d52319..e8ec1643640 100644 --- a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -1,11 +1,11 @@ -#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP -#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP +#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP +#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP #include "common_header.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "cluster_descriptor.hpp" -#include "threadwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_transfer_v5r1.hpp" namespace ck { @@ -14,7 +14,7 @@ namespace ck { // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate template -struct BlockwiseTensorSliceTransfer_v4r1 +struct BlockwiseTensorSliceTransfer_v5r1 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); using Index = MultiIndex; - __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc, + __device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc, const Index& src_block_slice_origin, const DstDesc& dst_desc, const Index& dst_block_slice_origin) @@ -45,8 +45,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) { - static_assert(nDim == remove_reference_t>::GetNumOfDimension() && - nDim == remove_reference_t>::GetNumOfDimension() && + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() && @@ -75,14 +75,13 @@ struct BlockwiseTensorSliceTransfer_v4r1 } } - template - __device__ void - RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); + threadwise_transfer_.RunRead(src_desc, src_buf); } } @@ -134,7 +133,7 @@ struct BlockwiseTensorSliceTransfer_v4r1 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); using ThreadwiseTransfer = - ThreadwiseTensorSliceTransfer_v3r1 +struct PartitionedBlockwiseReduction +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements"); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using Accumulation = detail::AccumulateWithNanCheck; + + template + __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value) + { + static_assert(is_same{}, + "Buffer data type should be consistent as AccDataType!"); + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value; + + __syncthreads(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I()); + + if(thread_k_cluster_id < indOffset) + { + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + AccDataType opData1 = work_buffer[offset1]; + AccDataType opData2 = work_buffer[offset2]; + Accumulation::Calculate(opData1, opData2); + work_buffer(offset1) = opData1; + } + + __syncthreads(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + in_out_value = work_buffer[offset]; + }; +}; + +// clang-format off +// Assume: +// 1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data +// 2) work_val_buffer/work_idx_buffer has AccDataType/IndexDataType elements, and space size is no less than BlockSize +// 3) in_out_value/in_out_index is the input data in vgpr from each thread +// 4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread +// clang-format on +template +struct PartitionedBlockwiseReductionWithIndex +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements"); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using Accumulation = + detail::AccumulateWithIndexAndNanCheck; + + // This interface accumulates on both data values and indices + template + __device__ static void Reduce(BufferType& work_val_buffer, + IdxBufferType& work_idx_buffer, + AccDataType& in_out_value, + IndexDataType& in_out_index) + { + static_assert(is_same{}, + "Buffer data type should be consistent as AccDataType!"); + static_assert(is_same{}, + "Buffer data type should be consistent as IndexDataType!"); + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + work_val_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value; + work_idx_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index; + + __syncthreads(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << I(); + + if(thread_k_cluster_id % (indOffset * 2) == 0) + { + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + AccDataType opData1 = work_val_buffer[offset1]; + AccDataType opData2 = work_val_buffer[offset2]; + IndexDataType currIndex1 = work_idx_buffer[offset1]; + IndexDataType currIndex2 = work_idx_buffer[offset2]; + + Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2); + work_val_buffer(offset1) = opData1; + work_idx_buffer(offset1) = currIndex1; + } + + __syncthreads(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + in_out_value = work_val_buffer[offset]; + in_out_index = work_idx_buffer[offset]; + }; +}; + +}; // end of namespace ck + +#endif diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp new file mode 100644 index 00000000000..cbabbaf47df --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp @@ -0,0 +1,169 @@ +#pragma once +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include "threadwise_tensor_slice_transfer_v3r1.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v4r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const DstElementwiseOperation& dst_element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + src_element_op, + dst_desc, + make_zero_multi_index(), + dst_element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id) + { + RunRead(src_desc, src_buf, thread_scratch_id); + RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp new file mode 100644 index 00000000000..1f0ad3e35af --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp @@ -0,0 +1,130 @@ +#pragma once +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include "threadwise_tensor_slice_transfer_v6r1.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp new file mode 100644 index 00000000000..121ddf12ad9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp @@ -0,0 +1,154 @@ +#pragma once +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include "threadwise_tensor_slice_transfer_v6r2.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. It does not keep reference to tensor descriptor +// 3. Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r2 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src0_desc, + make_zero_multi_index(), + src1_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrc0SliceOrigin( + src0_desc, src0_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc1SliceOrigin( + src1_desc, src1_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); + } + } + + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp new file mode 100644 index 00000000000..ca5db90f307 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp @@ -0,0 +1,179 @@ +#pragma once +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include "threadwise_tensor_slice_transfer_v6r3.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r3 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const Src2Desc& src2_desc, + const Index& src2_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src0_desc, + make_zero_multi_index(), + src1_desc, + make_zero_multi_index(), + src2_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrc0SliceOrigin( + src0_desc, src0_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc1SliceOrigin( + src1_desc, src1_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc2SliceOrigin( + src2_desc, src2_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const Src2Desc& src2_desc, + const Src2Buffer& src2_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run( + src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); + } + } + + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); + } + } + + __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r3; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp new file mode 100644 index 00000000000..eae1bf9f8ee --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp @@ -0,0 +1,17 @@ +#ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION +#define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct ConvolutionBackwardDataSpecialization +{ + Default, + Filter1x1Stride1Pad0, +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp new file mode 100644 index 00000000000..60995e068ce --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -0,0 +1,17 @@ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct ConvolutionBackwardWeightSpecialization +{ + Default, + Filter1x1Stride1Pad0, + Filter1x1Pad0, + OddC, +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp new file mode 100644 index 00000000000..c9eaf64d667 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -0,0 +1,33 @@ +#ifndef CONVOLUTION_FORWARD_SPECIALIZATION +#define CONVOLUTION_FORWARD_SPECIALIZATION + +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct ConvolutionForwardSpecialization +{ + Default, + Filter1x1Pad0, + Filter1x1Stride1Pad0, + OddC, +}; + +inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization& s) +{ + switch(s) + { + case ConvolutionForwardSpecialization::Default: return "Default"; + case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; + case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; + case ConvolutionForwardSpecialization::OddC: return "OddC"; + default: return "Unrecognized specialization!"; + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp new file mode 100644 index 00000000000..9bc3cb1a02f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include + +#include "stream_config.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct BaseArgument +{ + BaseArgument() = default; + BaseArgument(const BaseArgument&) = default; + BaseArgument& operator=(const BaseArgument&) = default; + + virtual ~BaseArgument() {} +}; + +struct BaseInvoker +{ + BaseInvoker() = default; + BaseInvoker(const BaseInvoker&) = default; + BaseInvoker& operator=(const BaseInvoker&) = default; + + virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{}) + { + return float{0}; + } + + virtual ~BaseInvoker() {} +}; + +struct BaseOperator +{ + BaseOperator() = default; + BaseOperator(const BaseOperator&) = default; + BaseOperator& operator=(const BaseOperator&) = default; + + virtual bool IsSupportedArgument(const BaseArgument*) { return false; } + virtual std::string GetTypeString() const { return ""; } + + virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } + + virtual ~BaseOperator() {} +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp new file mode 100644 index 00000000000..6b3c2bf9c40 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -0,0 +1,923 @@ +#pragma once +#include +#include +#include "device.hpp" +#include "device_gemm_reduce.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" +#include "gemm_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + DPtrsGlobal p_ds_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const DxsInElementwiseOperation dxs_in_element_op, + const DxsOutElementwiseOperation dxs_out_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock, + const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); + + static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) { + const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In))); + p_ds_grid(In) = p_ds_grid(In) + d_batch_offset; + }); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_ds_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + dxs_in_element_op, + dxs_out_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + d_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_ds_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = dxs_in_element_op; + ignore = dxs_out_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = d_grid_desc_mblock_mperblock; + ignore = compute_base_ptr_of_batch_; + ignore = block_2_ctile_map; +#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__)) +} + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce +{ + using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume D is packed tensor + static auto MakeDGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideD) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideC_(BatchStrideC), + BatchStrideD_(BatchStrideD) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + template + __host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx, + Number reduction_idx) const + { + // TODO - Support sequence of StrideD in MakeArgument() + (void)reduction_idx; + return g_idx * static_cast(BatchStrideD_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + index_t BatchStrideD_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + ReduceAccDataType, + DPtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DxsReduceOperation, + DxsInElementwiseOperation, + DxsOutElementwiseOperation, + InMemoryDataOperationEnum::Set, + DGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + DGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + DPtrsGlobal p_ds_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsOutElementwiseOperation dxs_out_element_op, + index_t BatchCount) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_ds_grid_{p_ds_grid}, + BatchCount_(BatchCount), + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + d_grid_desc_mblock_mperblock_{}, + compute_base_ptr_of_batch_{ + type_convert(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()), + type_convert(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()), + type_convert(c_grid_desc_m_n_.GetElementSpaceSize()), + type_convert(d_grid_desc_m_.GetElementSpaceSize())}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + dxs_in_element_op_{dxs_in_element_op}, + dxs_out_element_op_{dxs_out_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + d_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + DPtrsGlobal p_ds_grid_; + index_t BatchCount_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + DGridDesc_M d_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + DxsInElementwiseOperation dxs_in_element_op_; + DxsOutElementwiseOperation dxs_out_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.BatchCount_ = " << arg.BatchCount_ << std::endl; + + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}" + << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float elapsed_time = 0.0f; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + DPtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DxsInElementwiseOperation, + DxsOutElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, + ComputeBasePtrOfStridedBatch, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_ds_grid_, + arg.BatchCount_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.dxs_in_element_op_, + arg.dxs_out_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + DPtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DxsInElementwiseOperation, + DxsOutElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, + ComputeBasePtrOfStridedBatch, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_ds_grid_, + arg.BatchCount_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.dxs_in_element_op_, + arg.dxs_out_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + auto casted_p_arg = dynamic_cast(p_arg); + if(casted_p_arg == nullptr) + { + return false; + } + else + { + return IsSupportedArgument(*casted_p_arg); + } + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + DPtrsGlobal p_dxs, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsOutElementwiseOperation dxs_out_element_op, + index_t BatchCount) + { + return Argument{p_a, + p_b, + p_c, + p_dxs, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + dxs_in_element_op, + dxs_out_element_op, + BatchCount}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + DPtrsGlobal p_dxs, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsOutElementwiseOperation dxs_out_element_op, + index_t BatchCount) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + p_dxs, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + dxs_in_element_op, + dxs_out_element_op, + BatchCount); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp new file mode 100644 index 00000000000..d1ffa9df147 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp @@ -0,0 +1,619 @@ +#ifndef DEVICE_BATCHED_GEMM_XDL_HPP +#define DEVICE_BATCHED_GEMM_XDL_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link + * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_xdlops_v2r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct DeviceBatchedGemmXdl + : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + const auto a_grid_desc_k0_mp_k1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_k0_mp_k1; + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + const auto b_grid_desc_k0_np_k1 = + transform_tensor_descriptor(b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_k0_np_k1; + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + const auto c_grid_desc_mp_np = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return c_grid_desc_mp_np; + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + }; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t BatchCount) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + BatchCount_(BatchCount), + a_grid_desc_k0_m_k1_{ + DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, + b_grid_desc_k0_n_k1_{ + DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, + c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + compute_ptr_offset_of_batch_{ + type_convert(a_grid_desc_k0_m_k1_.GetElementSpaceSize()), + type_convert(b_grid_desc_k0_n_k1_.GetElementSpaceSize()), + type_convert(c_grid_desc_m_n_.GetElementSpaceSize())}, + block_2_ctile_map_{ + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + index_t BatchCount_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceBatchedGemmXdl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_batched_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.BatchCount_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_batched_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.BatchCount_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t BatchCount) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + BatchCount}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t BatchCount) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + BatchCount); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp new file mode 100644 index 00000000000..8955aadc110 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp @@ -0,0 +1,193 @@ +#pragma once +#include +#include + +#include "device.hpp" +#include "device_base.hpp" +#include "gridwise_binary_elementwise_1d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBinaryElementwise : public BaseOperator +{ + static constexpr auto I0 = Number<0>{}; + + template + static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) + { + const auto m0 = desc_m0.GetLength(I0); + const index_t loop_step = gridSize * blockSize * ScalarPerVector; + const auto pad = math::integer_least_multiple(m0, loop_step) - m0; + const auto desc_m0_pad = + transform_tensor_descriptor(desc_m0, + make_tuple(make_right_pad_transform(m0, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m0_pad; + } + + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] + if constexpr(Dim > 1) + { + const auto desc_m0 = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize); + } + else + return PadDescriptor_M0_1d(desc, gridSize, blockSize); + } + + using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); + using GridwiseBinEltwise = GridwiseBinaryElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const std::vector& shape, + const std::vector& stride_a, + const std::vector& stride_b, + const std::vector& stride_c, + ElementwiseFunctor functor) + : p_a_(p_a), + p_b_(p_b), + p_c_(p_c), + shape_(shape), + functor_(functor), + blockSize_(256), + gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future + { + a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_); + b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_); + c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize_); + } + + const ADataType* p_a_; + const BDataType* p_b_; + CDataType* p_c_; + std::vector shape_; + GridDesc_M0 a_grid_desc_m0_; + GridDesc_M0 b_grid_desc_m0_; + GridDesc_M0 c_grid_desc_m0_; + ElementwiseFunctor functor_; + index_t blockSize_; + index_t gridSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel = kernel_binary_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize_), + dim3(arg.blockSize_), + 0, + arg.p_a_, + arg.p_b_, + arg.p_c_, + arg.a_grid_desc_m0_, + arg.b_grid_desc_m0_, + arg.c_grid_desc_m0_, + arg.functor_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg == nullptr) + return false; + + if(pArg->shape_.back() % ScalarPerVector != 0) + return false; + + return true; + }; + + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + std::vector shape, + std::vector stride_a, + std::vector stride_b, + std::vector stride_c, + ElementwiseFunctor functor) + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + shape, + stride_a, + stride_b, + stride_c, + functor); + } + + std::unique_ptr MakeInvokerPointer() { return std::make_unique(); } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBinaryElementwise" + << "<" + << "ScalarPerVector = " << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000000..8404f4c266e --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,776 @@ +#ifndef DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_backward_weight.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_bwd_weight.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvBwdWeight +{ + using DeviceOp = + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static constexpr auto N1Number = K1Number; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + // A: output tensor + const index_t N0 = N / N1Number; + const index_t GemmK0Total = N0 * Ho * Wo; + + const index_t GemmK0S = + math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) * K0PerBlock; + const index_t GemmK0Pad = GemmKBatch * GemmK0S; + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K)); + + const auto out_n0_ho_wo_k_n1_grid_desc = + transform_tensor_descriptor(out_n_ho_wo_k_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)), + make_pass_through_transform(Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmk0total_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(out_n0_ho_wo_k_n1_grid_desc, + make_tuple(make_merge_transform(make_tuple(N0, Ho * Wo)), + make_pass_through_transform(K), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmk0total_gemmm_gemmk1_grid_desc, + make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total), + make_pass_through_transform(GemmM), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmk0pad_gemmm_gemmk1_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)), + make_pass_through_transform(GemmM), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n0_y_ho_x_wo_c_n1_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)), + make_pass_through_transform(Y), + make_pass_through_transform(Ho), + make_pass_through_transform(X), + make_pass_through_transform(Wo), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 6>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_gemmk0total_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_n0_y_ho_x_wo_c_n1_grid_desc, + make_tuple(make_merge_transform(make_tuple(N0, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C)), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0total_gemmn_gemmk1_grid_desc, + make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total), + make_pass_through_transform(GemmN), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0pad_gemmn_gemmk1_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)), + make_pass_through_transform(GemmN), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1)); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + output_spatial_lengths_{output_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation a_element_op_; + OutElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector output_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + ShowInfo(arg); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight has invalid setting"); + } + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // unmerge N to N0 and N1, where N1 equals to K1 + if(!(arg.Conv_N_ % K1 == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000000..83953e59bd9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,824 @@ +#ifndef DEVICE_CONV2D_BWD_DATA_XDL_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV2D_BWD_DATA_XDL_NHWC_KYXC_NHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_bwd_data.hpp" +#include "convolution_backward_data_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvBwdData +{ + using DeviceOp = DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) % + ABlockTransferSrcScalarPerVector == + 0); + static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) % + BBlockTransferSrcScalarPerVector == + 0); + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + index_t i_ytilde, + index_t i_xtilde) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto K0 = K / K1; + + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + const auto wei_k_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 0, 0)); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_in_grid}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{wei_element_op}, + c_element_op_{in_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const index_t Y = filter_spatial_lengths_[0]; + const index_t X = filter_spatial_lengths_[1]; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + i_ytilde, + i_xtilde); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01); + + if(GridwiseGemm::CheckValidity( + descs[I0], descs[I1], descs[I2], block_2_ctile_map)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + } + } + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + std::vector a_grid_desc_k0_m_k1_container_; + std::vector b_grid_desc_k0_n_k1_container_; + std::vector c_grid_desc_m_n_container_; + std::vector + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; + std::vector block_2_ctile_map_container_; + index_t M01_; + index_t N01_; + OutElementwiseOperation a_element_op_; + WeiElementwiseOperation b_element_op_; + InElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m_n_container_{ " + << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " + << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) + << " ) " << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.c_grid_desc_m_n_container_[i]); + + const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * + arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + OutElementwiseOperation, + WeiElementwiseOperation, + InElementwiseOperation, + remove_reference_t, + true>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_container_[i]); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + OutElementwiseOperation, + WeiElementwiseOperation, + InElementwiseOperation, + remove_reference_t, + false>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_container_[i]); + } + } + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + return false; + } + } + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(void* p_in_grid, + const void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000000..85063443c17 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,973 @@ +#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_NHWC_KYXC_NHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_fwd_bias_activation_add.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v3r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K] +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwdBiasActivationAdd +{ + using DeviceOp = + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // TODO make it support any # of spatial dimensions + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + } + + using GridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + using C1GridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + C1GridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + const OutDataType* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + p_c0_grid_{p_bias_grid}, + p_c1_grid_{p_resi_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c1_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{ + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + c1_grid_desc_m_n_ = descs[I4]; + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c1_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + const CDataType* p_c1_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << DeviceOp{}.GetTypeString() << std::endl; + std::cout << "N " << arg.Conv_N_ << ", " + << "K " << arg.Conv_K_ << ", " + << "C " << arg.Conv_C_ << ", " << std::endl; + std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", " + << arg.filter_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", " + << arg.input_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", " + << arg.output_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Strides " << arg.conv_filter_strides_[0] << ", " + << arg.conv_filter_strides_[1] << ", " << std::endl; + std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", " + << arg.conv_filter_dilations_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", " + << arg.input_left_pads_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " + << arg.input_right_pads_[1] << ", " << std::endl; + } + + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + const OutDataType* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + p_bias_grid, + p_resi_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + const void* p_bias_grid, + const void* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + static_cast(p_bias_grid), + static_cast(p_resi_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000000..a397b5e2b13 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,919 @@ +#pragma once +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_fwd_bias_activation.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v3r2.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + InMemoryDataOperationEnum OutGlobalMemoryDataOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwdBiasActivation +{ + using DeviceOp = + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // TODO make it support any # of spatial dimensions + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + OutGlobalMemoryDataOperation, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + p_c0_grid_{p_bias_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << DeviceOp{}.GetTypeString() << std::endl; + std::cout << "N " << arg.Conv_N_ << ", " + << "K " << arg.Conv_K_ << ", " + << "C " << arg.Conv_C_ << ", " << std::endl; + std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", " + << arg.filter_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", " + << arg.input_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", " + << arg.output_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Strides " << arg.conv_filter_strides_[0] << ", " + << arg.conv_filter_strides_[1] << ", " << std::endl; + std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", " + << arg.conv_filter_dilations_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", " + << arg.input_left_pads_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " + << arg.input_right_pads_[1] << ", " << std::endl; + } + + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + p_bias_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + const void* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + static_cast(p_bias_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000000..f29e59039ed --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,891 @@ +#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_fwd.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v3r1.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXdl, + ck::index_t NPerXdl, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd +{ + using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, // TODO: Add ShuffleType for DeviceConv2d + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock * K1, + K1, // AK1 + K1, // BK1 + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + c_grid_desc_m_n_ = descs[I2]; + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << DeviceOp{}.GetTypeString() << std::endl; + std::cout << "N " << arg.Conv_N_ << ", " + << "K " << arg.Conv_K_ << ", " + << "C " << arg.Conv_C_ << ", " << std::endl; + std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", " + << arg.filter_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", " + << arg.input_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", " + << arg.output_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Strides " << arg.conv_filter_strides_[0] << ", " + << arg.conv_filter_strides_[1] << ", " << std::endl; + std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", " + << arg.conv_filter_dilations_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", " + << arg.input_left_pads_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " + << arg.input_right_pads_[1] << ", " << std::endl; + } + + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout + << "arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_" + "nwavenperxdl_{ " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I0) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I1) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I2) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I3) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I4) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I5) + << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvFwdSpecializationStr(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000000..ece18459a0c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,722 @@ +#ifndef DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_fwd.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd +{ + using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + const index_t GemmK = Y * X * C; + + const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw; + + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else + { + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvFwdSpecializationStr(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp new file mode 100644 index 00000000000..b1eea0b33f3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,265 @@ +#ifndef DEVICE_CONV3D_FWD_NAIVE_HPP +#define DEVICE_CONV3D_FWD_NAIVE_HPP + +#include +#include +#include +#include "conv_util.hpp" +#include "device.hpp" +#include "device_conv_fwd.hpp" +#include "common_header.hpp" +#include "naive_conv_fwd.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +template +struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K + : public DeviceConvFwd + +{ + using DeviceOp = DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + // TODO make A/B datatype different + using ABDataType = InDataType; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : params_{3, + N, + K, + C, + filter_spatial_lengths, + input_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + out_spatial_lengths_{output_spatial_lengths}, + p_in_{p_in}, + p_wei_{p_wei}, + p_out_{p_out}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + + { + } + + // private: + utils::conv::ConvParams params_; + std::vector out_spatial_lengths_; + + const InDataType* p_in_; + const WeiDataType* p_wei_; + OutDataType* p_out_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto naive_conv3d_fwd = + ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk; + + float ave_time = launch_and_time_kernel(stream_config, + naive_conv3d_fwd, + dim3(256), + dim3(256), + 0, + arg.p_in_, + arg.p_wei_, + arg.p_out_, + arg.N_, + arg.K_, + arg.C_, + arg.in_spatial_lengths_[0], + arg.in_spatial_lengths_[1], + arg.in_spatial_lengths_[2], + arg.filter_spatial_lengths_[0], + arg.filter_spatial_lengths_[1], + arg.filter_spatial_lengths_[2], + arg.out_spatial_lengths_[0], + arg.out_spatial_lengths_[1], + arg.out_spatial_lengths_[2], + arg.conv_filter_strides_[0], + arg.conv_filter_strides_[1], + arg.conv_filter_strides_[2], + arg.conv_filter_dilations_[0], + arg.conv_filter_dilations_[1], + arg.conv_filter_dilations_[2], + arg.in_left_pads_[0], + arg.in_left_pads_[1], + arg.in_left_pads_[2]); + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + std::vector out_spatial_lengths = arg.params_.GetOutputSpatialLengths(); + + bool out_lengths_are_consistent = out_spatial_lengths[0] == arg.out_spatial_lengths_[0] && + out_spatial_lengths[1] == arg.out_spatial_lengths_[1] && + out_spatial_lengths[2] == arg.out_spatial_lengths_[2]; + return out_lengths_are_consistent; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in, + p_wei, + p_out, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + + { + return std::make_unique(static_cast(p_in), + static_cast(p_wei), + static_cast(p_out), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K<>"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp new file mode 100644 index 00000000000..256d0f81e96 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,639 @@ +#ifndef DEVICE_CONV3D_FWD_XDL_HPP +#define DEVICE_CONV3D_FWD_XDL_HPP + +#include +#include +#include +#include "device.hpp" +#include "device_conv_fwd.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "convolution_forward_specialization.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink. + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r3_for_conv3d( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t num_batches, + const index_t a_batch_stride, + const index_t b_batch_stride, + const index_t c_batch_stride, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(a_batch_stride) * g_idx); + const long_index_t b_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(b_batch_stride) * g_idx); + const long_index_t c_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(c_batch_stride) * g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = num_batches; + ignore = a_batch_stride; + ignore = b_batch_stride; + ignore = c_batch_stride; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +template +struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K + : public DeviceConvFwd + +{ + using DeviceOp = DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + /* + * \brief Split the number of batches, \p N, into N = B * N1, such that the memory + * space of input and output tensors stays with the value range of index_t, and each subbatch + * can be dealed with GridwiseGemm. + */ + static index_t GetMaxAllowableSubBatchSize(const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector output_spatial_lengths) + { + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + // N1 should satisfy that + // 1) N % N1 = 0; + // 2) N1 * (Do * Ho * Wo * K) < (2^31 - 1) + // 3) N1 * (Di * Hi * Wi * C) < (2^31 - 1) + // + // Do NOT confuse (B, N1) in this function with (B, N1) in gridewise GEMM. + auto N1 = N + 1; + + const auto stride = + math::max(long_index_t(Do) * Ho * Wo * K, long_index_t(Di) * Hi * Wi * C); + const index_t max_stride = NumericLimits::Max(); + + for(index_t n0 = 1; n0 <= N; ++n0) + { + index_t n1 = N / n0; + if(n0 * n1 == N && long_index_t(n1) * long_index_t(stride) < max_stride) + { + N1 = n1; + break; + } + } + + const auto B = N / N1; + if(B * N1 != N) + { + throw std::runtime_error(__func__ + + std::string(": failed to find num_subbatches for conv3d.\n")); + } + + return N1; + } + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + assert(input_spatial_lengths.size() > 2); + assert(filter_spatial_lengths.size() > 2); + assert(conv_filter_strides.size() > 2); + assert(conv_filter_dilations.size() > 2); + assert(input_left_pads.size() > 2); + assert(input_right_pads.size() > 2); + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization::Default, + "Wrong! This specialization not implemented!"); + + const auto in_desc_n_di_hi_wi_c = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + const auto wei_desc_k_z_y_x_c = + make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + const auto out_desc_n_do_ho_wo_k = + make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + + const auto descs = transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad( + in_desc_n_di_hi_wi_c, + wei_desc_k_z_y_x_c, + out_desc_n_do_ho_wo_k, + make_tuple(conv_filter_strides[0], conv_filter_strides[1], conv_filter_strides[2]), + make_tuple( + conv_filter_dilations[0], conv_filter_dilations[1], conv_filter_dilations[2]), + make_tuple(input_left_pads[0], input_left_pads[1], input_left_pads[2]), + make_tuple(input_right_pads[0], input_right_pads[1], input_right_pads[2]), + Number{}); + + return descs; + } + + using ABCGridDescs = remove_cvref_t; + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + InDataType, + AccDataType, + OutDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + CThreadTransferDstScalarPerVector>; + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + index_t M01, + index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in}, + p_b_grid_{p_wei}, + p_c_grid_{p_out}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + const index_t subbatch_size = + GetMaxAllowableSubBatchSize(N, K, C, input_spatial_lengths, output_spatial_lengths); + num_subbatches_ = N / subbatch_size; + + const auto descs = + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(subbatch_size, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + a_batch_stride_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + b_batch_stride_ = 0; + c_batch_stride_ = c_grid_desc_m_n_.GetElementSpaceSize(); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const InDataType* p_a_grid_; + const WeiDataType* p_b_grid_; + OutDataType* p_c_grid_; + index_t num_subbatches_; + index_t a_batch_stride_; + index_t b_batch_stride_; + index_t c_batch_stride_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl; + std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "b_grid_desc_k0_n_k1{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * + arg.num_subbatches_; + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v2r3_for_conv3d< + GridwiseGemm, + InDataType, + OutDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.num_subbatches_, + arg.a_batch_stride_, + arg.b_batch_stride_, + arg.c_batch_stride_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3_for_conv3d< + GridwiseGemm, + InDataType, + OutDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.num_subbatches_, + arg.a_batch_stride_, + arg.b_batch_stride_, + arg.c_batch_stride_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in, + p_wei, + p_out, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + + { + return std::make_unique(static_cast(p_in), + static_cast(p_wei), + static_cast(p_out), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp b/include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp new file mode 100644 index 00000000000..549cfb26f3d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp @@ -0,0 +1,47 @@ +#ifndef DEVICE_CONV_WRW_HPP +#define DEVICE_CONV_WRW_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvBwdWeight : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + void* p_wei, + const void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvBwdWeightPtr = std::unique_ptr< + DeviceConvBwdWeight>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp b/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp new file mode 100644 index 00000000000..1d08af1a05e --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp @@ -0,0 +1,47 @@ +#ifndef DEVICE_CONV_BWD_DATA_HPP +#define DEVICE_CONV_BWD_DATA_HPP + +#include +#include "device_base.hpp" +#include "element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvBwdData : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(void* p_in, + const void* p_wei, + const void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvBwdDataPtr = std::unique_ptr< + DeviceConvBwdData>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp new file mode 100644 index 00000000000..d53e56f18ba --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp @@ -0,0 +1,46 @@ +#ifndef DEVICE_CONV_FWD_HPP +#define DEVICE_CONV_FWD_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvFwdPtr = std::unique_ptr< + DeviceConvFwd>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp new file mode 100644 index 00000000000..77d4b7fb95a --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp @@ -0,0 +1,49 @@ +#ifndef DEVICE_CONV_FWD_BIAS_ACTIVATION_HPP +#define DEVICE_CONV_FWD_BIAS_ACTIVATION_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwdBiasActivation : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const void* p_bias, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvFwdBiasActivationPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp new file mode 100644 index 00000000000..2f8e780b78d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp @@ -0,0 +1,50 @@ +#ifndef DEVICE_CONV_FWD_BIAS_ACTIVATION_ADD_HPP +#define DEVICE_CONV_FWD_BIAS_ACTIVATION_ADD_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwdBiasActivationAdd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const void* p_bias, + const void* p_resi, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvFwdBiasActivationAddPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000000..dde9e0f8739 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,1233 @@ +#pragma once + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_backward_weight.hpp" +#include "convolution_backward_weight_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_bwd_weight.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvBwdWeight +{ + using DeviceOp = + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + const index_t GemmKTotal = N * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + } + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + } + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[2]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * Z * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + } // function end + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, + 1, + 1, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + output_spatial_lengths_{output_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation a_element_op_; + OutElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector output_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + ShowInfo(arg); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if constexpr(std::is_same::value) + { + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + else + { + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } + + template ::type = false> + static size_t GetWorkSpaceSize(const Argument& arg) + { + size_t WorkSpaceSize = 0; + if(arg.k_batch_ > 1) + { + if constexpr(std::is_same::value) + { + WorkSpaceSize = + arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * sizeof(float); + } + } + return WorkSpaceSize; + } + + template ::type = false> + static size_t GetWorkSpaceSize(const Argument& arg) + { + size_t WorkSpaceSize = 0; + if(arg.k_batch_ > 1) + { + if constexpr(std::is_same::value) + { + WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * + arg.filter_spatial_lengths_[1] * sizeof(float); + } + } + return WorkSpaceSize; + } + + template ::type = false> + static size_t GetWorkSpaceSize(const Argument& arg) + { + size_t WorkSpaceSize = 0; + if(arg.k_batch_ > 1) + { + if constexpr(std::is_same::value) + { + WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * + arg.filter_spatial_lengths_[1] * arg.filter_spatial_lengths_[2] * + sizeof(float); + } + } + return WorkSpaceSize; + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override final + { + return GetWorkSpaceSize(*dynamic_cast(p_arg)); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp new file mode 100644 index 00000000000..0517db44154 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,1549 @@ +#ifndef DEVICE_CONVND_BWD_DATA_XDL_NDHWC_KZYXC_NDHWK_HPP +#define DEVICE_CONVND_BWD_DATA_XDL_NDHWC_KZYXC_NDHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_bwd_data.hpp" +#include "convolution_backward_data_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K + : public DeviceConvBwdData +{ + using DeviceOp = DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) % + ABlockTransferSrcScalarPerVector == + 0); + static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) % + BBlockTransferSrcScalarPerVector == + 0); + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + index_t i_xtilde = tildes[0]; + + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + + const auto K0 = K / K1; + + const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)), + make_tuple(make_pass_through_transform(N * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto out_n_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wo, K)); + const auto wei_k_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X, C)); + + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_wop_k_grid_desc = transform_tensor_descriptor( + out_n_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor( + out_n_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), + make_merge_transform(make_tuple(N, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto K0 = K / K1; + + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + const auto wei_k_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + const index_t i_ztilde = tildes[0]; + const index_t i_ytilde = tildes[1]; + const index_t i_xtilde = tildes[2]; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const auto K0 = K / K1; + + const auto out_n_do_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + const auto wei_k_z_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Do * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<0, 2, 4, 6>{}, + Sequence<7>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto DTilde = + Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IDTildeSliceEnd = math::min( + DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_do_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Do, I0, I0), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_dop_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(ZDot, DTilde), + make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7, 8>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_k_z_y_x_c_grid_desc, + make_tuple( + make_pass_through_transform(K), + make_embed_transform(make_tuple(ZDot, ZTilde), + make_tuple(ConvStrideD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ztilde), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<5>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(ZTilde, DTilde), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ztilde), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, + 1, + 1, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {0, 0, 0}); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_in_grid}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{wei_element_op}, + c_element_op_{in_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + CreateABCDesc(); + } + + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideW = conv_filter_strides_[0]; + const index_t ConvDilationW = conv_filter_dilations_[0]; + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t X = filter_spatial_lengths_[0]; + + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); + + if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + } + } + } + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideH = conv_filter_strides_[0]; + const index_t ConvStrideW = conv_filter_strides_[1]; + + const index_t ConvDilationH = conv_filter_dilations_[0]; + const index_t ConvDilationW = conv_filter_dilations_[1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t Y = filter_spatial_lengths_[0]; + const index_t X = filter_spatial_lengths_[1]; + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_ytilde, i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); + + if(GridwiseGemm::CheckValidity( + descs[I0], descs[I1], descs[I2], block_2_ctile_map)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + } + } + } + } + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideD = conv_filter_strides_[0]; + const index_t ConvStrideH = conv_filter_strides_[1]; + const index_t ConvStrideW = conv_filter_strides_[2]; + + const index_t ConvDilationD = conv_filter_dilations_[0]; + const index_t ConvDilationH = conv_filter_dilations_[1]; + const index_t ConvDilationW = conv_filter_dilations_[2]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t Z = filter_spatial_lengths_[0]; + const index_t Y = filter_spatial_lengths_[1]; + const index_t X = filter_spatial_lengths_[2]; + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(ZDotSlice * YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + NumDimSpatial>(Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_ztilde, i_ytilde, i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); + + if(GridwiseGemm::CheckValidity( + descs[I0], descs[I1], descs[I2], block_2_ctile_map)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2( + descs[I2])); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + } + } + } + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + std::vector a_grid_desc_k0_m_k1_container_; + std::vector b_grid_desc_k0_n_k1_container_; + std::vector c_grid_desc_m_n_container_; + std::vector + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; + std::vector block_2_ctile_map_container_; + index_t M01_; + index_t N01_; + OutElementwiseOperation a_element_op_; + WeiElementwiseOperation b_element_op_; + InElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m_n_container_{ " + << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " + << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I6) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I7) + << " ) " << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.c_grid_desc_m_n_container_[i]); + + const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * + arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + OutElementwiseOperation, + WeiElementwiseOperation, + InElementwiseOperation, + remove_reference_t, + true>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_container_[i]); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + OutElementwiseOperation, + WeiElementwiseOperation, + InElementwiseOperation, + remove_reference_t, + false>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_container_[i]); + } + } + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NumDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + return false; + } + } + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(void* p_in_grid, + const void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){ + + str<< " Filter1x1Stride1Pad0"; + } + + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000000..f0be2498e7a --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,1033 @@ +#ifndef DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP + +#include +#include +#include +#include +#include + +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_fwd.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Inputs with up to 3 spatial dimentions +// @li Input tensor in NHWC data format +// @li Weight tensor in KYXC data format +// @li Output tensor in NHWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template +struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd +{ + using DeviceOp = DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = NumDimSpatial; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto GetWeightTensorDescriptor(ck::index_t gemm_n, ck::index_t gemm_k) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k)); + + // wei_gemmk0_gemmn_gemmk1_grid_desc + return transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_n)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto + GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n, ck::index_t gemm_m_pad) + { + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n)); + + // out_gemmm_gemmn_grid_desc + return transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(gemm_n)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + static auto GetInputTensorDescriptor(ck::index_t N, + ck::index_t C, + ck::index_t gemm_m, + ck::index_t gemm_k, + ck::index_t gemm_m_pad, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_right_pad_transform(gemm_m, gemm_m_pad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + const auto in_n_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<2>{}, Sequence<0, 1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + const index_t X = filter_spatial_lengths[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_m)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + template ::type = false> + static auto GetInputTensorDescriptor(ck::index_t N, + ck::index_t C, + ck::index_t gemm_m, + ck::index_t gemm_k, + ck::index_t gemm_m_pad, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_right_pad_transform(gemm_m, gemm_m_pad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_m)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + template ::type = false> + static auto GetInputTensorDescriptor(ck::index_t N, + ck::index_t C, + ck::index_t gemm_m, + ck::index_t gemm_k, + ck::index_t gemm_m_pad, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_right_pad_transform(gemm_m, gemm_m_pad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_do_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_m)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + static index_t GetGemmMRaw(ck::index_t N, + const std::vector& output_spatial_lengths) + { + return N * std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + 1, + std::multiplies()); + } + + static index_t GetGemmK(ck::index_t C, const std::vector& filter_spatial_lengths) + { + return C * std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + 1, + std::multiplies()); + } + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t GemmMRaw = GetGemmMRaw(N, output_spatial_lengths); + const index_t GemmN = K; + const index_t GemmK = GetGemmK(C, filter_spatial_lengths); + + const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw; + + assert(GemmK % GemmK1Number == 0); + + // C = A^T*B + // A: + const auto in_gemmk0_gemmm_gemmk1_grid_desc = + GetInputTensorDescriptor(N, + C, + GemmMRaw, + GemmK, + GemmMPad, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + // B: + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = GetWeightTensorDescriptor(GemmN, GemmK); + // C: + const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmN, GemmMPad); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // Input tensors can't be bigger than 2GB each. + constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31); + + if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 || + arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 || + arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) > GB2) + { + return false; + } + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(ck::index_t i = 0; i < NumDimSpatial; ++i) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(ck::index_t i = 0; i < NumDimSpatial; ++i) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.input_left_pads_[i] == 0 && + arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv" << std::to_string(NumDimSpatial) + << "DFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvFwdSpecializationStr(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_gemm.hpp new file mode 100644 index 00000000000..4576aaa7e03 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm.hpp @@ -0,0 +1,70 @@ +#pragma once +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct GemmShape +{ + ck::index_t M, N, K; + ck::index_t StrideA, StrideB, StrideC; +}; + +template +struct DeviceGemm : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmPtr = std::unique_ptr< + DeviceGemm>; + +template +struct DeviceGroupedGemm : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(std::vector& p_a, + std::vector& p_b, + std::vector& p_c, + std::vector& gemm_shapes, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGroupedGemmPtr = std::unique_ptr< + DeviceGroupedGemm>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp new file mode 100644 index 00000000000..9f5d16a1f9b --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp @@ -0,0 +1,40 @@ +#pragma once +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmBias : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmBiasPtr = std::unique_ptr< + DeviceGemmBias>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_activation.hpp new file mode 100644 index 00000000000..95736b18870 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_activation.hpp @@ -0,0 +1,43 @@ +#ifndef DEVICE_GEMM_BIAS_ACTIVATION_HPP +#define DEVICE_GEMM_BIAS_ACTIVATION_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmBiasActivation : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmBiasActivationPtr = std::unique_ptr< + DeviceGemmBiasActivation>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_activation_add.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_activation_add.hpp new file mode 100644 index 00000000000..d304abaa384 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_activation_add.hpp @@ -0,0 +1,47 @@ +#ifndef DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP +#define DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmBiasActivationAdd : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0, + const void* p_c1, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmBiasActivationAddPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp new file mode 100644 index 00000000000..a6a059df77c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp @@ -0,0 +1,586 @@ +#pragma once + +#include +#include + +#include "device.hpp" +#include "device_base.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gemm_specialization.hpp" +#include "element_wise_operation.hpp" +#include "gridwise_gemm_dl_v1r3.hpp" +#include "device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + GemmSpecialization GemmSpec, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t K1, + index_t M1PerThread, + index_t N1PerThread, + index_t KPerThread, + typename M1N1ThreadClusterM1Xs, + typename M1N1ThreadClusterN1Xs, + typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + typename ABlockTransferSrcVectorTensorContiguousDimOrder, + typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + typename BBlockTransferSrcVectorTensorContiguousDimOrder, + typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + typename CThreadTransferSrcDstAccessOrder, + index_t CThreadTransferSrcDstVectorDim, + index_t CThreadTransferDstScalarPerVector, + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceGemmDl + : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDl_km_kn_mn_v1r3; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + c_grid_desc_m0_m10_m11_n0_n10_n11_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_)) + { + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); + c_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_; + + DefaultBlock2CTileMap block_2_ctile_map_; + + // TODO: unused, but may be useful in future. + index_t M01_; + index_t N01_; + + // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being. + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmDl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" + << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{" + << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize( + arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1)); + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030") + { + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + } + else + { + return false; + } + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmDl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << M1PerThread << ", " + << N1PerThread << ", " + << KPerThread + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp new file mode 100644 index 00000000000..66c966c7f9d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -0,0 +1,53 @@ +#pragma once +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmReduce : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + DPtrsGlobal p_dxs, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsOutElementwiseOperation dxs_out_element_op, + ck::index_t BatchCount = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmReducePtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp new file mode 100644 index 00000000000..3bd29c13c63 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -0,0 +1,757 @@ +#pragma once +#include +#include +#include "device.hpp" +#include "device_gemm_reduce.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" +#include "gemm_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce +{ + using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume D is packed tensor + static auto MakeDGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + ReduceAccDataType, + DPtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DxsReduceOperation, + DxsInElementwiseOperation, + DxsOutElementwiseOperation, + InMemoryDataOperationEnum::Set, + DGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + DGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + DPtrsGlobal p_ds_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsOutElementwiseOperation dxs_out_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_ds_grid_{p_ds_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + d_grid_desc_mblock_mperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + dxs_in_element_op_{dxs_in_element_op}, + dxs_out_element_op_{dxs_out_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + d_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + DPtrsGlobal p_ds_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + DGridDesc_M d_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + DxsInElementwiseOperation dxs_in_element_op_; + DxsOutElementwiseOperation dxs_out_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}" + << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float elapsed_time = 0.0f; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + DPtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DxsInElementwiseOperation, + DxsOutElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_ds_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.dxs_in_element_op_, + arg.dxs_out_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + DPtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DxsInElementwiseOperation, + DxsOutElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_ds_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.dxs_in_element_op_, + arg.dxs_out_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + DPtrsGlobal p_dxs, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsOutElementwiseOperation dxs_out_element_op) + { + return Argument{p_a, + p_b, + p_c, + p_dxs, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + dxs_in_element_op, + dxs_out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + DPtrsGlobal p_dxs, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsOutElementwiseOperation dxs_out_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + p_dxs, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + dxs_in_element_op, + dxs_out_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp new file mode 100644 index 00000000000..31f354358f5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp @@ -0,0 +1,522 @@ +#pragma once + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" +#include "gemm_specialization.hpp" +#include "device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdl + : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp new file mode 100644 index 00000000000..1db69dd4620 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp @@ -0,0 +1,506 @@ +#pragma once +#include +#include +#include "device.hpp" +#include "device_gemm_bias.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v3r2.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceGemmXdl_C_Shuffle_Bias_2d + : public DeviceGemmBias +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_k0_m_k1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_k0_m_k1; + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_k0_n_k1 = + transform_tensor_descriptor(b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_k0_n_k1; + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + const CDataType* p_bias_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c0_grid_{p_bias_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c0_grid_desc_m_n_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = + DeviceGemmXdl_C_Shuffle_Bias_2d::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = + DeviceGemmXdl_C_Shuffle_Bias_2d::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c0_grid_desc_m_n_ = + DeviceGemmXdl_C_Shuffle_Bias_2d::MakeCGridDescriptor_M_N(M, N, StrideC); + c_grid_desc_m_n_ = + DeviceGemmXdl_C_Shuffle_Bias_2d::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const CDataType* p_c0_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + C0GridDesc_M_N c0_grid_desc_m_n_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdl_C_Shuffle_Bias_2d::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const CDataType* p_bias, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_bias, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_bias), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdl_C_Shuffle_Bias_2d" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp new file mode 100644 index 00000000000..b465f8e4aee --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp @@ -0,0 +1,516 @@ +#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_ACTIVATION_HPP +#define DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_ACTIVATION_HPP + +#include +#include +#include "device.hpp" +#include "device_gemm_bias_activation.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v3r2.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// C[M, N] = activate(A[M, K] * B[K, N] + C0[N]) +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceGemmXdl_C_Shuffle_Bias_Activation + : public DeviceGemmBiasActivation +{ + using DeviceOp = DeviceGemmXdl_C_Shuffle_Bias_Activation; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N( + index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + // A[K0, M, K1] + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_k0_m_k1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B[K0, N, K1] + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_k0_n_k1 = + transform_tensor_descriptor(b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C[M, N] + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // C0[N]: assume a contiguous vector + const auto c0_grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1)); + + return make_tuple( + a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n, c0_grid_desc_m_n); + } + + using GridDescs = + decltype(MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N(1, 1, 1, 1, 1, 1)); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const CDataType* p_c0_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_c0_grid_{p_c0_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + const auto descs = DeviceOp::MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N( + M, N, K, StrideA, StrideB, StrideC); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const CDataType* p_c0, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + p_c0, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_c0), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdl_C_Shuffle_Bias_Activation" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp new file mode 100644 index 00000000000..7a2e1886d35 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp @@ -0,0 +1,576 @@ +#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_HPP +#define DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_HPP + +#include +#include +#include "device.hpp" +#include "device_gemm_bias_activation_add.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v3r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// C[M, N] = activate(A[M, K] * B[K, N] + C0[N]) + C1[M, N] +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add + : public DeviceGemmBiasActivationAdd +{ + using DeviceOp = DeviceGemmXdl_C_Shuffle_Bias_Activation_Add; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N_C1_M_N(index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + // A[K0, M, K1] + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_k0_m_k1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B[K0, N, K1] + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_k0_n_k1 = + transform_tensor_descriptor(b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C[M, N] + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // C0[N]: assume a contiguous vector + const auto c0_grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1)); + + // C1[M, N]: residual tensor: assume same layout as C + const auto c1_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC1, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC1)); + } + }(); + + return make_tuple(a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m_n, + c0_grid_desc_m_n, + c1_grid_desc_m_n); + } + + using GridDescs = + decltype(MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N_C1_M_N(1, 1, 1, 1, 1, 1, 1)); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + using C1GridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + C1GridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const CDataType* p_c0_grid, + const CDataType* p_c1_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_c0_grid_{p_c0_grid}, + p_c1_grid_{p_c1_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c1_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + const auto descs = DeviceOp::MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N_C1_M_N( + M, N, K, StrideA, StrideB, StrideC, StrideC1); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + c1_grid_desc_m_n_ = descs[I4]; + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c1_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + const CDataType* p_c1_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const CDataType* p_c0, + const CDataType* p_c1, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + p_c0, + p_c1, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideC1, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0, + const void* p_c1, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_c0), + static_cast(p_c1), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideC1, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdl_C_Shuffle_Bias_Activation_Add" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp new file mode 100644 index 00000000000..a74ee816799 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp @@ -0,0 +1,666 @@ +#pragma once +#include +#include +#include "device.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdl_cshuffle_v1.hpp" +#include "tensor_operation/gpu/device/gemm_specialization.hpp" +#include "device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemm_Xdl_CShuffle + : public DeviceGemm +{ + using DeviceOp = DeviceGemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp new file mode 100644 index 00000000000..d9fc8f7a8a7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp @@ -0,0 +1,642 @@ +#ifndef DEVICE_GEMM_SPLITK_XDL_HPP +#define DEVICE_GEMM_SPLITK_XDL_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r4.hpp" +#include "gemm_specialization.hpp" +#include "device_prop.hpp" + +#ifndef CK_RUN_KERNEL_AND_TIME +#define CK_RUN_KERNEL_AND_TIME 1 +#endif + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdlSplitK + : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + + static auto + MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad) + { + assert(KPad % (K1 * KBatch) == 0); + + const index_t K0 = KPad / (K1 * KBatch); + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(M)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + static auto + MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad) + { + assert(KPad % (K1 * KBatch) == 0); + + const index_t K0 = KPad / (K1 * KBatch); + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0 * K1; + return KPad; + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // GridwiseGemm + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t k_batch) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + k_batch_{k_batch} + { + int KPad = DeviceGemmXdlSplitK::GetKPad(K, k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = DeviceGemmXdlSplitK::MakeAGridDescriptor_KBatch_K0_M_K1( + M, K, StrideA, k_batch_, KPad); + b_grid_desc_kbatch_k0_n_k1_ = DeviceGemmXdlSplitK::MakeBGridDescriptor_KBatch_K0_N_K1( + K, N, StrideB, k_batch_, KPad); + c_grid_desc_m_n_ = DeviceGemmXdlSplitK::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdlSplitK::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + ShowInfo(arg); + + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + // FIXME: this should be moved outside of DeviceOp + hipGetErrorString( + hipMemset(arg.p_c_grid_, + 0, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() * + sizeof(CDataType))); + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdlSplitK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp new file mode 100644 index 00000000000..ad424d91d97 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp @@ -0,0 +1,644 @@ +#ifndef DEVICE_GEMM_XDL_SPLITK_C_SHUFFLE_HPP +#define DEVICE_GEMM_XDL_SPLITK_C_SHUFFLE_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r4r2.hpp" +#include "gemm_specialization.hpp" + +#ifndef CK_RUN_KERNEL_AND_TIME +#define CK_RUN_KERNEL_AND_TIME 1 +#endif + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdlSplitKCShuffle + : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + + static auto + MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad) + { + assert(KPad % (K1 * KBatch) == 0); + + const index_t K0 = KPad / (K1 * KBatch); + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(M)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + static auto + MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad) + { + assert(KPad % (K1 * KBatch) == 0); + + const index_t K0 = KPad / (K1 * KBatch); + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0 * K1; + return KPad; + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + + // GridwiseGemm + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t k_batch) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + k_batch_{k_batch} + { + int KPad = DeviceGemmXdlSplitKCShuffle::GetKPad(K, k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = + DeviceGemmXdlSplitKCShuffle::MakeAGridDescriptor_KBatch_K0_M_K1( + M, K, StrideA, k_batch_, KPad); + b_grid_desc_kbatch_k0_n_k1_ = + DeviceGemmXdlSplitKCShuffle::MakeBGridDescriptor_KBatch_K0_N_K1( + K, N, StrideB, k_batch_, KPad); + c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdlSplitKCShuffle::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + ShowInfo(arg); + + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); + + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdlSplitKCShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp new file mode 100644 index 00000000000..08a70823be3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp @@ -0,0 +1,660 @@ +#ifndef DEVICE_GROUPED_GEMM_XDL_HPP +#define DEVICE_GROUPED_GEMM_XDL_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" +#include "gemm_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdlops_v2r3( + const StaticallyIndexedArray gemm_descs, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + +#if 1 + static_for<0, MaxGroupCount, 1>{}([&](auto i) { + if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ && + i < group_count) + { + auto group_id = i; + + GridwiseGemm::template Run( + gemm_descs[group_id].a_ptr, + gemm_descs[group_id].b_ptr, + gemm_descs[group_id].c_ptr, + p_shared, + gemm_descs[group_id].a_grid_desc_k0_m_k1_, + gemm_descs[group_id].b_grid_desc_k0_n_k1_, + gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + a_element_op, + b_element_op, + c_element_op, + gemm_descs[group_id].grouped_gemm_block_2_ctile_map_); + } + }); +#else + const auto gemm_desc_ptr = reinterpret_cast(&gemm_descs); + + index_t group_id = 0; + static_for<0, MaxGroupCount, 1>{}([&](auto i) { + group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd && + i < group_count) + ? i + : group_id; + }); + + const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart; + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].a_ptr, + gemm_desc_ptr[group_id].b_ptr, + gemm_desc_ptr[group_id].c_ptr, + p_shared, + gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_, + gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_, + gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_ptr[group_id].block_2_ctile_map_, + block_id_grp); +#endif +#else + ignore = gemm_descs; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct DeviceGroupedGemmXdl + : public DeviceGroupedGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch>; + + struct GroupedGemmBlock2CTileMap + { + using UnderlyingBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + static_assert( + std::is_same::value, + "Wrong! Should be the same type name"); + GroupedGemmBlock2CTileMap() + { + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1); + BlockStart_ = -1; + } + + GroupedGemmBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01, + ck::index_t BlockStart) + { + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01); + BlockStart_ = BlockStart; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_2_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[I0] - BlockStart_)); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_2_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_2_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + ck::index_t BlockStart_; + }; + + struct GemmDescKernelArg + { + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + + GroupedGemmBlock2CTileMap grouped_gemm_block_2_ctile_map_; + + const ADataType* a_ptr; + const BDataType* b_ptr; + CDataType* c_ptr; + + ck::index_t BlockStart_, BlockEnd_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector& p_a, + std::vector& p_b, + std::vector& p_c, + std::vector& gemm_shapes, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + grid_size_ = 0; + + group_count_ = ck::type_convert(gemm_shapes.size()); + + if(!(group_count_ == ck::type_convert(p_a.size()) && + group_count_ == ck::type_convert(p_b.size()) && + group_count_ == ck::type_convert(p_c.size()))) + { + throw std::runtime_error("wrong! group_count_ != P_a/b/c.size"); + } + + gemm_desc_kernel_arg_.reserve(group_count_); + + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + const index_t M = gemm_shapes[i].M; + const index_t N = gemm_shapes[i].N; + const index_t K = gemm_shapes[i].K; + + const index_t StrideA = gemm_shapes[i].StrideA; + const index_t StrideB = gemm_shapes[i].StrideB; + const index_t StrideC = gemm_shapes[i].StrideC; + + const auto a_grid_desc_k0_m_k1_ = + DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + const auto b_grid_desc_k0_n_k1_ = + DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + const auto c_grid_desc_m_n_ = + DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); + + const index_t grid_size_grp = + GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, 0) + .block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + const auto grouped_gemm_block_2_ctile_map_ = + GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + grouped_gemm_block_2_ctile_map_)) + { + const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + + gemm_desc_kernel_arg_.push_back( + GemmDescKernelArg{a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + grouped_gemm_block_2_ctile_map_, + static_cast(p_a[i]), + static_cast(p_b[i]), + static_cast(p_c[i]), + BlockStart, + BlockEnd}); + } + } + } + + // private: + index_t M01_; + index_t N01_; + index_t group_count_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + + index_t grid_size_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGroupedGemmXdl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + StaticallyIndexedArray gemm_desc_kernel_args; + + bool has_main_k_block_loop = true; + + static_for<0, MaxGroupCount, 1>{}([&](auto i) { + if(i < arg.gemm_desc_kernel_arg_.size()) + { + gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i]; + + std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" + << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"; + + std::cout << ", arg.b_grid_desc_k0_n_k1_{" + << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"; + + std::cout << ", arg.c_grid_desc_m_n_{ " + << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", " + << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}" + << std::endl; + + if(!GridwiseGemm::CheckValidity( + gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_, + gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_, + gemm_desc_kernel_args[i].c_grid_desc_m_n_, + gemm_desc_kernel_args[i].grouped_gemm_block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) * + gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + }); + + float ave_time = 0; + + if(has_main_k_block_loop) + { + const auto kernel = + kernel_grouped_gemm_xdlops_v2r3, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + true, + MaxGroupCount>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + gemm_desc_kernel_args, + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdlops_v2r3, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + false, + MaxGroupCount>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + gemm_desc_kernel_args, + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + return false; + else + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_a, + std::vector& p_b, + std::vector& p_c, + std::vector gemm_shapes, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(std::vector& p_a, + std::vector& p_b, + std::vector& p_c, + std::vector& gemm_shapes, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique( + p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp new file mode 100644 index 00000000000..d049f6e9791 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp @@ -0,0 +1,38 @@ +#ifndef DEVICE_POOL2D_FWD_HPP +#define DEVICE_POOL2D_FWD_HPP + +#include +#include +#include "device_base.hpp" +#include "reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePool2dFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* in_dev, + void* out_dev, + void* out_indices_dev, + ck::index_t N, + ck::index_t C, + std::array input_spatial_lengths, + std::array window_spatial_lengths, + std::array output_spatial_lengths, + std::array window_strides, + std::array input_left_pads, + std::array input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DevicePool2dFwdPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp new file mode 100644 index 00000000000..c7e18d98dcd --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp @@ -0,0 +1,327 @@ +#ifndef DEVICE_POOL2D_FWD_NHWC_NHWC_HPP +#define DEVICE_POOL2D_FWD_NHWC_NHWC_HPP + +#include +#include +#include "device_pool2d_fwd.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "reduction_operator_mapping.hpp" +#include "gridwise_2d_reduction_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using IndexDataType = int32_t; + + using ReduceOperation = typename reduce_binary_operator::opType; + + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + + using AccElementwiseOperation = + typename reduce_unary_operator:: + AccElementwiseOperation; + + static constexpr index_t InSrcOutDstVectorDim = + 0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is + // not reduced. + + static constexpr ck::index_t ReduceM_BlockTileSize = + ReduceMThreadClusterSize * ReduceMThreadSliceSize; + static constexpr ck::index_t ReduceK_BlockTileSize = + ReduceKThreadClusterSize * ReduceKThreadSliceSize; + + static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N, + ck::index_t C, + std::array input_spatial_lengths, + std::array window_spatial_lengths, + std::array output_spatial_lengths, + std::array window_strides, + std::array input_left_pads, + std::array input_right_pads) + { + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = window_spatial_lengths[0]; + const index_t X = window_spatial_lengths[1]; + + const index_t ConvStrideH = window_strides[0]; + const index_t ConvStrideW = window_strides[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ReduceMRaw = N * Ho * Wo * C; + const index_t ReduceMPad = + math::integer_least_multiple(ReduceMRaw, ReduceM_BlockTileSize) - ReduceMRaw; + + const index_t ReduceKRaw = Y * X; + const index_t ReduceKPad = + math::integer_least_multiple(ReduceKRaw, ReduceK_BlockTileSize) - ReduceKRaw; + + // A[ReduceM, ReduceK] + const auto in_grid_desc_n_hi_wi_c = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor( + in_grid_desc_n_hi_wi_c, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor( + in_grid_desc_n_hip_wip_c, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_grid_desc_reducemraw_reducekraw = + transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)), + make_merge_transform(make_tuple(Y, X))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor( + in_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad), + make_right_pad_transform(ReduceKRaw, ReduceKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // B[ReduceM] + const auto out_grid_desc_reducemraw = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo * C)); + + const auto out_grid_desc_reducem = transform_tensor_descriptor( + out_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); + } + + using ABGridDescs = decltype( + MakeABGridDescriptor_A_M_K_B_M(1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_M_K = remove_cvref_t; + using BGridDesc_M = remove_cvref_t; + + // TODO + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_dev, + OutDataType* p_out_dev, + int* p_out_indices_dev, + ck::index_t N, + ck::index_t C, + std::array& input_spatial_lengths, + std::array& window_spatial_lengths, + std::array& output_spatial_lengths, + std::array& window_strides, + std::array& input_left_pads, + std::array& input_right_pads) + : p_in_dev_{p_in_dev}, + p_out_dev_{p_out_dev}, + p_out_indices_dev_{p_out_indices_dev}, + a_grid_desc_m_k_{}, + b_grid_desc_m_{} + { + const auto descs = MakeABGridDescriptor_A_M_K_B_M(N, + C, + input_spatial_lengths, + window_spatial_lengths, + output_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + + a_grid_desc_m_k_ = descs[I0]; + b_grid_desc_m_ = descs[I1]; + + invariant_lowest_length_ = C; + reduce_lowest_length_ = window_spatial_lengths[1]; + + // TODO: is this correct? + if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG) + { + ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1]; + in_element_op_ = InElementwiseOperation{divider}; + acc_element_op_ = AccElementwiseOperation{divider}; + } + } + + const InDataType* p_in_dev_; + OutDataType* p_out_dev_; + int* p_out_indices_dev_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_M b_grid_desc_m_; + InElementwiseOperation in_element_op_; + AccElementwiseOperation acc_element_op_; + + // for checking vector load/store + ck::index_t invariant_lowest_length_; + ck::index_t reduce_lowest_length_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + using gridwise_reduce = + GridwiseReduction_mk_to_m_threadwise; + + const auto kernel = kernel_reduce_threadwise; + + ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0); + + const index_t grid_size = (ReduceM / ReduceM_BlockTileSize); + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_m_, + arg.in_element_op_, + arg.acc_element_op_, + float(1), + arg.p_in_dev_, + nullptr, + float(0), + arg.p_out_dev_, + arg.p_out_indices_dev_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0) + { + return (false); + } + + return (true); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_dev, + void* p_out_dev, + void* p_out_indices_dev, + ck::index_t N, + ck::index_t C, + std::array input_spatial_lengths, + std::array window_spatial_lengths, + std::array output_spatial_lengths, + std::array window_strides, + std::array input_left_pads, + std::array input_right_pads) override + { + return std::make_unique(static_cast(p_in_dev), + static_cast(p_out_dev), + static_cast(p_out_indices_dev), + N, + C, + input_spatial_lengths, + window_spatial_lengths, + output_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<" << BlockSize << ","; + str << "M_C" << ReduceMThreadClusterSize << "_S" << ReduceMThreadSliceSize << ","; + str << "K_C" << ReduceKThreadClusterSize << "_S" << ReduceKThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_reduce.hpp new file mode 100644 index 00000000000..6f367a8747c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_reduce.hpp @@ -0,0 +1,44 @@ +#ifndef DEVICE_REDUCE_HPP +#define DEVICE_REDUCE_HPP + +#include +#include +#include + +#include "common_header.hpp" +#include "device_base.hpp" +#include "reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduce : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, + const std::vector reduceDims, + float alpha, + float beta, + const void* in_dev, + const void* in_index_dev, + void* out_dev, + void* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceReducePtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp new file mode 100644 index 00000000000..f68a3928217 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp @@ -0,0 +1,90 @@ +#ifndef DEVICE_REDUCE_COMMON_HPP +#define DEVICE_REDUCE_COMMON_HPP + +#include +#include + +#include "common_header.hpp" +#include "reduction_enums.hpp" +#include "reduction_operator.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// here, inLengths[] is already shuffled so that lengths of invariant dims are included before those +// of reduce dims +template +std::pair get_2d_lengths(const std::vector& inLengths) +{ + static_assert(Rank <= 6, "bigger Rank size not supported!"); + + long_index_t invariant_total_length = 1; + long_index_t reduce_total_length = 1; + + constexpr int NumInvariantDim = Rank - NumReduceDim; + + for(int i = NumInvariantDim; i < Rank; i++) + reduce_total_length *= inLengths[i]; + + for(int i = 0; i < NumInvariantDim; i++) + invariant_total_length *= inLengths[i]; + + return std::make_pair(invariant_total_length, reduce_total_length); +}; + +// helper functions using variadic template arguments +template +auto make_tuple_from_array_and_index_seq(const std::vector& lengths, Sequence) +{ + return make_tuple(static_cast(lengths[Ns])...); +}; + +template +auto make_tuple_from_array(const std::vector& lengths, Number) +{ + static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); + + constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; + + return make_tuple_from_array_and_index_seq(lengths, index_seq); +}; + +template +std::vector shuffle_tensor_dimensions(const std::vector& origLengthsStrides, + const std::vector& reduceDims) +{ + std::vector newLengthsStrides; + + assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size()); + + int reduceFlag = 0; + + // flag the bits for the reduceDims + for(int i = 0; i < NumReduceDim; i++) + { + reduceFlag |= 1 << reduceDims[i]; + }; + + // collect invariant dimensions + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) == 0) + { + newLengthsStrides.push_back(origLengthsStrides[i]); + }; + + // collect reduce dimensions + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) > 0) + { + newLengthsStrides.push_back(origLengthsStrides[i]); + }; + + return newLengthsStrides; +}; + +} // namespace device +} // namespace tensor_operation + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp new file mode 100644 index 00000000000..2f447c0979b --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp @@ -0,0 +1,508 @@ +#ifndef DEVICE_REDUCE_MULTIBLOCK_HPP +#define DEVICE_REDUCE_MULTIBLOCK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_reduce.hpp" +#include "device_reduce_common.hpp" +#include "gridwise_2d_reduction_multiblock.hpp" +#include "gridwise_set_buffer_value.hpp" +#include "reduction_operator.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduceMultiBlock : public DeviceReduce +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using IndexDataType = int32_t; + + static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t numSrcDim = Rank; + static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added + // later + static constexpr bool use_multiblock = + (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd); + + static constexpr bool out_type_compatible_with_atomic_op = + std::is_same::value || std::is_same::value; + + static_assert( + !use_multiblock || (use_multiblock && out_type_compatible_with_atomic_op), + "The OutDataType must support the atomic operation for using MultiBlock reduction"); + + static_assert(!use_multiblock || (use_multiblock && !OutputIndex), + "MultiBlock reduction can only be used when outputing index is not required"); + + static_assert( + ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation), + "The reduction accumulation operation must be compatible with the OutMemoryDataOperation!"); + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::vector& outLengths, + const std::vector& outStrides) + { + const auto tupleDstLengths = make_tuple_from_array(outLengths, Number{}); + const auto tupleDstStrides = make_tuple_from_array(outStrides, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + static auto MakeDst1dDescriptorForBufferSet(const std::vector& outLengths, + const std::vector& outStrides) + { + const auto tupleDstLengths = make_tuple_from_array(outLengths, Number{}); + const auto tupleDstStrides = make_tuple_from_array(outStrides, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto length = out_grid_desc_m.GetLength(Number<0>{}); + + const auto pad = math::integer_least_multiple(length, BlockSize) - length; + + auto out_grid_desc_m_padded = + transform_tensor_descriptor(out_grid_desc_m, + make_tuple(make_right_pad_transform(length, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + struct Argument : public BaseArgument + { + Argument(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, + const std::vector reduceDims, + float alpha, + float beta, + const InDataType* in_dev, + const IndexDataType* in_index_dev, + OutDataType* out_dev, + IndexDataType* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) + : outLengths_{outLengths}, + outStrides_{outStrides}, + in_dev_{in_dev}, + in_index_dev_{in_index_dev}, + out_dev_{out_dev}, + out_index_dev_{out_index_dev}, + in_elementwise_op_{in_elementwise_op}, + acc_elementwise_op_{acc_elementwise_op} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + alpha_ = type_convert(alpha); + beta_ = type_convert(beta); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length = 1; + else + invariant_lowest_length = inLengths_[NumInvariantDim - 1]; + + reduce_lowest_length = inLengths_[Rank - 1]; + + if constexpr(use_multiblock) + { + + int iterations = 1; + while(true) + { + int testBlkGroupSize = + (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 128 + if(testBlkGroupSize <= 128) + break; + + iterations++; + }; + + blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration = iterations; + } + else + { + blkGroupSize = 1; + numBlockTileIteration = + (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize * blkGroupSize; + + gridSize_pre = + math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize; + } + + std::vector inLengths_; + std::vector inStrides_; + std::vector outLengths_; + std::vector outStrides_; + + AccDataType alpha_; + AccDataType beta_; + + const InDataType* in_dev_; + const IndexDataType* in_index_dev_; + OutDataType* out_dev_; + IndexDataType* out_index_dev_; + + InElementwiseOperation in_elementwise_op_; + AccElementwiseOperation acc_elementwise_op_; + + index_t invariant_lowest_length; + index_t reduce_lowest_length; + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + int blkGroupSize; + int numBlockTileIteration; + size_t gridSize; + + size_t gridSize_pre; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = DeviceReduceMultiBlock::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + const auto out_grid_desc_m = + DeviceReduceMultiBlock::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); + const auto out_grid_desc_m_2 = DeviceReduceMultiBlock::MakeDst1dDescriptorForBufferSet( + arg.outLengths_, arg.outStrides_); + + using InGridDesc_M_K = decltype(in_grid_desc_m_k); + using OutGridDesc_M = decltype(out_grid_desc_m); + using OutGridDesc_M_2 = decltype(out_grid_desc_m_2); + + using GridwiseReduce = GridwiseReduction_mk_to_m_multiblock; + + const auto kernel_main = kernel_reduce_multiblock; + + float avg_time = 0; + + if constexpr(use_multiblock) + { + const auto zeroVal = + ck::reduce::GetReductionZeroValueForInMemoryDataOperation( + OutMemoryDataOperation); + + const auto kernel_pre = + kernel_buffer_set_value; + + avg_time += launch_and_time_kernel(stream_config, + kernel_pre, + dim3(arg.gridSize_pre), + dim3(BlockSize), + 0, + out_grid_desc_m_2, + arg.out_dev_, + zeroVal); + }; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m, + arg.in_elementwise_op_, + arg.acc_elementwise_op_, + arg.blkGroupSize, + arg.numBlockTileIteration, + arg.alpha_, + arg.in_dev_, + arg.in_index_dev_, + arg.beta_, + arg.out_dev_, + arg.out_index_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if constexpr(use_multiblock) + { + if(static_cast(pArg->beta_) != 0.0f) + return (false); + }; + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(pArg->invariant_lowest_length % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1) + return (false); + + if(pArg->reduce_lowest_length % InSrcVectorSize != 0) + return (false); + }; + + // To improve + if(pArg->invariant_lowest_length % OutDstVectorSize != 0) + return (false); + + if constexpr(use_multiblock) + { + // blkGroupSize of 1 should be handled by Blockwise path using + // InMemoryDataOperationEnum::Set + if(pArg->blkGroupSize == 1) + return (false); + + // This is very strong restriction, but needed to avoid some failure + if(pArg->invariant_lowest_length % M_BlockTileSize != 0) + return (false); + } + else + { + // cases with very small reduce_total_length should be handled by ThreadWise kernel + if(pArg->reduce_total_length / KThreadSliceSize < 2) + return (false); + }; + + return (true); + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, + const std::vector reduceDims, + float alpha, + float beta, + const void* in_dev, + const void* in_index_dev, + void* out_dev, + void* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) override + { + return std::make_unique(inLengths, + inStrides, + outLengths, + outStrides, + reduceDims, + alpha, + beta, + static_cast(in_dev), + static_cast(in_index_dev), + static_cast(out_dev), + static_cast(out_index_dev), + in_elementwise_op, + acc_elementwise_op); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceReduceMultiBlockAtomicAdd<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp new file mode 100644 index 00000000000..9549bf65d24 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp @@ -0,0 +1,373 @@ +#ifndef DEVICE_REDUCE_THREADWISE_HPP +#define DEVICE_REDUCE_THREADWISE_HPP + +#include +#include +#include "device.hpp" +#include "device_reduce.hpp" +#include "device_reduce_common.hpp" +#include "gridwise_2d_reduction_multiblock.hpp" +#include "gridwise_2d_reduction_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduceThreadWise : public DeviceReduce +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using IndexDataType = int32_t; + + static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t numSrcDim = Rank; + static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides) + { + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = + math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::vector& outLengths, + const std::vector& outStrides) + { + const auto tupleDstLengths = make_tuple_from_array(outLengths, Number{}); + const auto tupleDstStrides = make_tuple_from_array(outStrides, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + struct Argument : public BaseArgument + { + Argument(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, + const std::vector reduceDims, + float alpha, + float beta, + const InDataType* in_dev, + OutDataType* out_dev, + IndexDataType* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) + : outLengths_{outLengths}, + outStrides_{outStrides}, + in_dev_{in_dev}, + out_dev_{out_dev}, + out_index_dev_{out_index_dev}, + in_elementwise_op_{in_elementwise_op}, + acc_elementwise_op_{acc_elementwise_op} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + alpha_ = type_convert(alpha); + beta_ = type_convert(beta); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length = 1; + else + invariant_lowest_length = inLengths_[NumInvariantDim - 1]; + + reduce_lowest_length = inLengths_[Rank - 1]; + + numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize; + } + + std::vector inLengths_; + std::vector inStrides_; + std::vector outLengths_; + std::vector outStrides_; + + AccDataType alpha_; + AccDataType beta_; + + const InDataType* in_dev_; + OutDataType* out_dev_; + IndexDataType* out_index_dev_; + + InElementwiseOperation in_elementwise_op_; + AccElementwiseOperation acc_elementwise_op_; + + index_t invariant_lowest_length; + index_t reduce_lowest_length; + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + int numBlockTileIteration; + size_t gridSize; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = + DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); + const auto out_grid_desc_m = + DeviceReduceThreadWise::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); + using InGridDesc_M_K = decltype(in_grid_desc_m_k); + using OutGridDesc_M = decltype(out_grid_desc_m); + + float avg_time = 0; + + using GridwiseReduce = + GridwiseReduction_mk_to_m_threadwise; + + const auto kernel = kernel_reduce_threadwise; + + avg_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m, + arg.in_elementwise_op_, + arg.acc_elementwise_op_, + arg.alpha_, + arg.in_dev_, + nullptr, + arg.beta_, + arg.out_dev_, + arg.out_index_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(pArg->invariant_lowest_length % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1) + return (false); + + if(pArg->reduce_lowest_length % InSrcVectorSize != 0) + return (false); + }; + + // To improve + if(pArg->invariant_lowest_length % OutDstVectorSize != 0) + return (false); + + // cases with big reduce_total_length should be handled by Blockwise kernel + if(pArg->reduce_total_length / KThreadSliceSize >= 32) + return (false); + + return (true); + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, + const std::vector reduceDims, + float alpha, + float beta, + const void* in_dev, + const void* in_index_dev, + void* out_dev, + void* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) override + { + (void)in_index_dev; + + return std::make_unique(inLengths, + inStrides, + outLengths, + outStrides, + reduceDims, + alpha, + beta, + static_cast(in_dev), + static_cast(out_dev), + static_cast(out_index_dev), + in_elementwise_op, + acc_elementwise_op); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceReduceThreadWise<" << BlockSize << ","; + str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << 1 << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp new file mode 100644 index 00000000000..d4ef61a133a --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -0,0 +1,23 @@ +#ifndef GEMM_SPECIALIZATION +#define GEMM_SPECIALIZATION + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct GemmSpecialization +{ + Default, + MPadding, + NPadding, + KPadding, + MNPadding, + MKPadding, + NKPadding, + MNKPadding, +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp new file mode 100644 index 00000000000..634e9212ea8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp @@ -0,0 +1,169 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_REDUCTION_OPERATOR_MAPPING_HPP +#define CK_REDUCTION_OPERATOR_MAPPING_HPP + +#include "reduction_operator.hpp" +#include "reduction_enums.hpp" +#include "element_wise_operation.hpp" + +namespace ck { + +// The templated struct reduce_binary_operator maps the enum Ids of binary operators to their +// respective functor classes. +// The boolean member "indexable" are also provided in reduce_binary_operactor for +// easier checking by the upper-layer codes in the kernels. + +template +struct reduce_binary_operator; + +template +struct reduce_binary_operator +{ + using opType = reduce::Add; + using dataType = T; + + static constexpr bool indexable = false; +}; + +template +struct reduce_binary_operator +{ + using opType = reduce::Mul; + using dataType = T; + + static constexpr bool indexable = false; +}; + +template +struct reduce_binary_operator +{ + using opType = reduce::Min; + using dataType = T; + + static constexpr bool indexable = true; +}; + +template +struct reduce_binary_operator +{ + using opType = reduce::Max; + using dataType = T; + + static constexpr bool indexable = true; +}; + +template +struct reduce_binary_operator +{ + using opType = reduce::AMax; + using dataType = T; + + static constexpr bool indexable = true; +}; + +template +struct reduce_binary_operator +{ + using opType = reduce::Add; + using dataType = T; + + static constexpr bool indexable = false; +}; + +template +struct reduce_binary_operator +{ + using opType = reduce::Add; + using dataType = T; + + static constexpr bool indexable = false; +}; + +template +struct reduce_binary_operator +{ + using opType = reduce::Add; + using dataType = T; + + static constexpr bool indexable = false; +}; + +// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary +// functor classes. +// The two unary functors are called before and afer the Reduction is executed respectively +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; +}; + +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; +}; + +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; + using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; +}; + +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; + using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; +}; + +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; + using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; +}; + +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; + using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; +}; + +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; +}; + +} // end of namespace ck + +#endif diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp new file mode 100644 index 00000000000..2409071b482 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -0,0 +1,129 @@ +#pragma once + +namespace ck { +namespace tensor_layout { + +struct BaseTensorLayout +{ +}; + +namespace gemm { + +struct RowMajor : public BaseTensorLayout +{ + static constexpr const char* name = "RowMajor"; +}; + +struct ColumnMajor : public BaseTensorLayout +{ + static constexpr const char* name = "ColumnMajor"; +}; +} // namespace gemm + +namespace convolution { + +// 1D Conv +struct NWC : public BaseTensorLayout +{ + static constexpr const char* name = "NWC"; +}; + +struct KXC : public BaseTensorLayout +{ + static constexpr const char* name = "KXC"; +}; + +struct NWK : public BaseTensorLayout +{ + static constexpr const char* name = "NWK"; +}; + +struct NCW : public BaseTensorLayout +{ + static constexpr const char* name = "NCW"; +}; + +struct KCX : public BaseTensorLayout +{ + static constexpr const char* name = "KCX"; +}; + +struct NKW : public BaseTensorLayout +{ + static constexpr const char* name = "NKW"; +}; + +// 2D Conv +struct NHWC : public BaseTensorLayout +{ + static constexpr const char* name = "NHWC"; +}; + +struct KYXC : public BaseTensorLayout +{ + static constexpr const char* name = "KYXC"; +}; + +struct NHWK : public BaseTensorLayout +{ + static constexpr const char* name = "NHWK"; +}; + +struct NCHW : public BaseTensorLayout +{ + static constexpr const char* name = "NCHW"; +}; + +struct KCYX : public BaseTensorLayout +{ + static constexpr const char* name = "KCYX"; +}; + +struct NKHW : public BaseTensorLayout +{ + static constexpr const char* name = "NKHW"; +}; + +// 3D Conv +struct NDHWC : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWC"; +}; + +struct KZYXC : public BaseTensorLayout +{ + static constexpr const char* name = "KZYXC"; +}; + +struct NDHWK : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWK"; +}; +struct NCDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NCDHW"; +}; + +struct KCZYX : public BaseTensorLayout +{ + static constexpr const char* name = "KCZYX"; +}; + +struct NKDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NKDHW"; +}; + +} // namespace convolution + +template < + typename Layout, + typename std::enable_if::value, bool>::type = false> +std::ostream& operator<<(std::ostream& os, const Layout&) +{ + os << Layout::name; + return os; +} + +} // namespace tensor_layout +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp new file mode 100644 index 00000000000..d2c7e1c1b55 --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -0,0 +1,25 @@ +#pragma once +#include "data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace binary_element_wise { + +struct Add +{ + __host__ __device__ constexpr void + operator()(double& dst, const double& src1, const double& src2) const + { + dst = src1 + src2; + } + + __host__ __device__ constexpr void + operator()(float& dst, const float& src1, const float& src2) const + { + dst = src1 + src2; + } +}; + +} // namespace binary_element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp new file mode 100644 index 00000000000..ab1cbfed454 --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -0,0 +1,334 @@ +#pragma once +#include "data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +struct PassThrough +{ + __host__ __device__ void operator()(float& y, const float& x) const { y = x; } + + __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; } + + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { y = x; } + + __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; } + + __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; } + + __host__ __device__ void operator()(double& y, const double& x) const { y = x; } +}; + +struct Add +{ + __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + } + + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + // FIXME - Use float (acc type) bias in the future. + y = x0 + x1; + } +}; + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {} + + __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const + { + y = alpha_ * x0 + beta_ * x1; + } + + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + // FIXME - Let x0 be acc type + y = static_cast(alpha_ * static_cast(x0) + beta_ * static_cast(x1)); + } + + float alpha_; + float beta_; +}; + +struct AddRelu +{ + __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const + { + const float a = x0 + x1; + y = a > 0 ? a : 0; + } + + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + const half_t a = x0 + x1; + y = a > 0 ? a : 0; + } +}; + +struct AddHardswish +{ + __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; + y = c; + } + + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; + y = c; + } +}; + +struct AddReluAdd +{ + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const + { + half_t a = x0 + x1; + half_t b = a > 0 ? a : 0; + y = b + x2; + } + + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1, const float& x2) const + { + float a = x0 + x1; + float b = a > 0 ? a : 0; + float c = b + x2; + y = c; + } + + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1, const half_t& x2) const + { + float a = x0 + x1; + float b = a > 0 ? a : 0; + float c = b + x2; + y = c; + } +}; + +struct AddHardswishAdd +{ + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1, const float& x2) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; + float d = c + x2; + y = d; + } + + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; + float d = c + x2; + y = d; + } +}; + +// Unary operators are usually called element-wisely before/after the reduction is executed on the +// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2 + +template +struct UnaryIdentic; + +template <> +struct UnaryIdentic +{ + __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(float& y, const float& x) const { y = x; }; +}; + +template <> +struct UnaryIdentic +{ + __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; }; + + __host__ __device__ void operator()(float& y, const float& x) const + { + y = x / type_convert(divider_); + }; + + int32_t divider_ = 1; +}; + +template <> +struct UnaryIdentic +{ + __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; }; +}; + +template <> +struct UnaryIdentic +{ + __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(double& y, const double& x) const { y = x; }; +}; + +template <> +struct UnaryIdentic +{ + __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; }; + + __host__ __device__ void operator()(double& y, const double& x) const + { + y = x / type_convert(divider_); + }; + + int32_t divider_ = 1; +}; + +template <> +struct UnaryIdentic +{ + __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }; +}; + +template <> +struct UnaryIdentic +{ + __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; }; + + __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x / divider_; }; + + int32_t divider_ = 1; +}; + +template <> +struct UnaryIdentic +{ + __host__ __device__ UnaryIdentic(const int8_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; }; +}; + +template +struct UnarySquare; + +template <> +struct UnarySquare +{ + __host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(float& y, const float& x) const { y = x * x; }; +}; + +template <> +struct UnarySquare +{ + __host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; }; + + __host__ __device__ void operator()(float& y, const float& x) const + { + y = x * x / type_convert(divider_); + }; + + int32_t divider_ = 1; +}; + +template <> +struct UnarySquare +{ + __host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(double& y, const double& x) const { y = x * x; }; +}; + +template <> +struct UnarySquare +{ + __host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; }; + + __host__ __device__ void operator()(double& y, const double& x) const + { + y = x * x / type_convert(divider_); + }; + + int32_t divider_ = 1; +}; + +template +struct UnaryAbs; + +template <> +struct UnaryAbs +{ + __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); }; +}; + +template <> +struct UnaryAbs +{ + __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); }; +}; + +template <> +struct UnaryAbs +{ + __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); }; +}; + +template <> +struct UnaryAbs +{ + __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(int8_t& y, const int8_t& x) const + { + int8_t sgn = x >> (8 - 1); + + y = (x ^ sgn) - sgn; + }; +}; + +template +struct UnarySqrt; + +template <> +struct UnarySqrt +{ + __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); }; +}; + +template <> +struct UnarySqrt +{ + __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; }; + + __host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); }; +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp new file mode 100644 index 00000000000..038e36f564d --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp @@ -0,0 +1,10 @@ +#pragma once +#include "data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp new file mode 100644 index 00000000000..792060ca862 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -0,0 +1,489 @@ +#ifndef UTILITY_BLOCK_TO_CTILE_MAP +#define UTILITY_BLOCK_TO_CTILE_MAP + +#include "utility/math.hpp" +#include "utility/number.hpp" +#include "tensor_description/tensor_adaptor.hpp" +#include "tensor_description/multi_index_transform_helper.hpp" + +namespace ck { + +// Rows of column-vectors +template +struct BlockToCTileMap_M00_N0_M01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N0_M01() = default; + + __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1) + : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + + const index_t grid_size = M00 * M01_ * N0; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + if(M0 % M01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __host__ __device__ static constexpr auto + GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + + const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_insert_transform(1), + make_unmerge_transform(make_tuple(M00, M01)), + make_pass_through_transform(make_tuple(N0))), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + + const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_n0_m01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + index_t M01_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1)); + UnderlyingMap underlying_map_; +}; + +// Rows of column-vectors +// This C-tile map dynamically adjusts M01 when C-tile index is out of range +template +struct BlockToCTileMap_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8) + : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const index_t grid_size = M0 * N0; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + + private: + index_t M01_; + CGridDesc_M_N c_grid_desc_m_n_; +}; + +// 2D slices of column-vectors in 3D space +// This C-tile map dynamically adjusts M01 when C-tile index is out of range +template +struct BlockToCTileMap_KSplit_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8, + index_t KSplit = 1) + : M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const index_t grid_size = M0 * N0 * KSplit_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + + private: + index_t M01_; + index_t KSplit_; + CGridDesc_M_N c_grid_desc_m_n_; +}; + +// Blocks of row-vectors +template +struct BlockToCTileMap_M00_N00_M01_N01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N00_M01_N01() = default; + + __host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1, + index_t N01 = 1) + : M01_(M01), N01_(N01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + const auto N00 = math::integer_divide_ceil(N0, N01_); + + const index_t grid_size = M00 * M01_ * N00 * N01_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + if(M0 % M01_ == 0 && N0 % N01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __host__ __device__ static constexpr auto + GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + const auto N00 = math::integer_divide_ceil(N0, N01); + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + index_t M01_, N01_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1)); + UnderlyingMap underlying_map_; +}; + +// 2D slices of row-vectors in 3D space +template +struct BlockToCTileMap_KSplit_M00_N00_M01_N01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01() = default; + + __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1, + index_t N01 = 1, + index_t KSplit = 1) + : M01_(M01), + N01_(N01), + KSplit_(KSplit), + underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + const auto N00 = math::integer_divide_ceil(N0, N01_); + + const index_t grid_size = M00 * M01_ * N00 * N01_ * KSplit_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + if(M0 % M01_ == 0 && N0 % N01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01, + index_t KSplit) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + const auto N00 = math::integer_divide_ceil(N0, N01); + + const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(KSplit), + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor); + + return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor; + } + + index_t M01_, N01_, KSplit_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1)); + UnderlyingMap underlying_map_; +}; + +template +__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) +{ + bool is_valid = false; + + const index_t m_block = c_tile_dim[Number<0>{}]; + const index_t n_block = c_tile_dim[Number<1>{}]; + + if constexpr(CTileIdx::Size() == 2) + { + const index_t m_block_idx = c_tile_idx[Number<0>{}]; + const index_t n_block_idx = c_tile_idx[Number<1>{}]; + if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block) + { + is_valid = true; + } + } + else if constexpr(CTileIdx::Size() == 3) + { + const index_t ksplit_idx = c_tile_idx[Number<0>{}]; + const index_t m_block_idx = c_tile_idx[Number<1>{}]; + const index_t n_block_idx = c_tile_idx[Number<2>{}]; + if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block) + { + is_valid = true; + } + ignore = ksplit_idx; + } + + return is_valid; +} + +} // namespace ck + +#endif // UTILITY_BLOCK_TO_CTILE_MAP diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp new file mode 100644 index 00000000000..f3e9836d4f0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -0,0 +1,638 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP +#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP + +#include "reduction_common.hpp" +#include "reduction_operator.hpp" +#include "reduction_functions_accumulate.hpp" +#include "reduction_functions_blockwise.hpp" +#include "reduction_functions_threadwise.hpp" + +#include "threadwise_tensor_slice_transfer.hpp" +#include "element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) +{ + if constexpr(!OutputIndex) + { + (void)p_in_index_global; + (void)p_out_index_global; + + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + block_group_size, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + beta, + p_out_value_global); + } + else + { + GridwiseReduction::template RunWithIndex(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + p_in_index_global, + beta, + p_out_value_global, + p_out_index_global); + }; +}; + +template +struct GridwiseReduction_mk_to_m_multiblock +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using BlockwiseReduce = PartitionedBlockwiseReduction; + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + using Accumulation = detail::AccumulateWithNanCheck; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + const auto zeroVal = ReduceOperation::GetReductionZeroVal(); + + // LDS + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + const auto in_global_val_buf = + make_dynamic_buffer(p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + type_convert(zeroVal)); + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_thread_buf; + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; }); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / block_group_size; + const index_t block_local_id = block_global_id % block_group_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); + + index_t reducedTiles = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, MThreadSliceSize, 1>{}( + [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if(thread_k_cluster_id == 0) + { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + } + }); + + if(thread_k_cluster_id == 0) + { + if(block_group_size == 0 && !float_equal_zero{}(beta)) + { + StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSize, + 1, + false>( + out_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; + }); + }; + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + true>( + out_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf, + out_grid_desc_m, + out_global_val_buf); + } + }; + + template + __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) + { + using BlockwiseReduceWithIndex = + PartitionedBlockwiseReductionWithIndex, + ThreadClusterArrangeOrder, + ReduceOperation, + PropagateNan>; + + using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck; + + (void)in_elementwise_op; + + // LDS + __shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; + __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize]; + + const auto zeroVal = ReduceOperation::GetReductionZeroVal(); + + const auto in_global_val_buf = + make_dynamic_buffer(p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + type_convert(zeroVal)); + const auto in_global_idx_buf = make_dynamic_buffer( + p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + auto out_global_idx_buf = make_dynamic_buffer( + p_out_index_global, out_grid_desc_m.GetElementSpaceSize()); + + auto reduce_work_val_buf = + make_dynamic_buffer(p_reduce_work_val_buffer, BlockSize); + auto reduce_work_idx_buf = + make_dynamic_buffer(p_reduce_work_idx_buffer, BlockSize); + + StaticBuffer + in_thread_val_buf; + + StaticBuffer + in_thread_idx_buf; + + StaticBuffer accu_value_buf; + StaticBuffer accu_index_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_1d_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = zeroVal; + accu_index_buf(I) = 0; + }); + + constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); + + index_t reducedTiles = 0; + + if constexpr(HaveIndexInput) + { + auto threadwise_src_idx_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + do + { + // load the thread slice + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + threadwise_src_idx_load.Run(in_grid_desc_m_k, + in_global_idx_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_idx_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType tmpValue = zeroVal; + IndexDataType tmpIndex = 0; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + AccumulationWithIndex::Calculate(tmpValue, + in_thread_val_buf[Number{}], + tmpIndex, + in_thread_idx_buf[Number{}]); + }); + + BlockwiseReduceWithIndex::Reduce( + reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); + + AccumulationWithIndex::Calculate( + accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); + }); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + else + { + index_t indexOffset = 0; + + do + { + // load the thread slice + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + // initialize the indices for the per-thread to-reduce values + in_thread_idx_buf(Number{}) = + indexOffset + thread_k_cluster_id * KThreadSliceSize + iK(); + + // do element-wise pre-reduction operation + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + + AccDataType tmpValue = zeroVal; + IndexDataType tmpIndex = 0; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + AccumulationWithIndex::Calculate(tmpValue, + in_thread_val_buf[Number{}], + tmpIndex, + in_thread_idx_buf[Number{}]); + }); + + BlockwiseReduceWithIndex::Reduce( + reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); + + AccumulationWithIndex::Calculate( + accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); + }); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexOffset += K_BlockTileSize; + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + }; + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if(thread_k_cluster_id == 0) + { + // for indiced operation, acc_elementwise_op shoud do nothing + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + } + }); + + if(thread_k_cluster_id == 0) + { + if(!float_equal_zero{}(beta)) + { + StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSize, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; + }); + }; + + auto threadwise_dst_val_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_dst_idx_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_val_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf, + out_grid_desc_m, + out_global_val_buf); + threadwise_dst_idx_store.Run(reduced_data_desc, + make_tuple(I0), + accu_index_buf, + out_grid_desc_m, + out_global_idx_buf); + } + }; +}; + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp new file mode 100644 index 00000000000..ff01b881469 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp @@ -0,0 +1,498 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_GRIDWISE_2D_REDUCTION_THREADWISE_HPP +#define CK_GRIDWISE_2D_REDUCTION_THREADWISE_HPP + +#include "data_type.hpp" +#include "reduction_common.hpp" +#include "reduction_operator.hpp" +#include "reduction_functions_accumulate.hpp" +#include "reduction_functions_threadwise.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) +{ + if constexpr(!OutputIndex) + { + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + alpha, + p_in_value_global, + beta, + p_out_value_global); + } + else + { + GridwiseReduction::template RunWithIndex(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + alpha, + p_in_value_global, + p_in_index_global, + beta, + p_out_value_global, + p_out_index_global); + }; +}; + +template +struct GridwiseReduction_mk_to_m_threadwise +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + using ThreadwiseReduce = ThreadwiseReduction; + + const auto zeroVal = ReduceOperation::GetReductionZeroVal(); + + const auto in_global_val_buf = + make_dynamic_buffer(p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + type_convert(zeroVal)); + auto dst_global_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer + in_thread_buf; + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; }); + + const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); + + index_t reducedLength = 0; + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + }); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + if(!float_equal_zero{}(beta)) + { + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + 1, + 1, + true>( + out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); + + StaticBuffer + priorDstValue_buf; + + threadwise_dst_load.Run(out_grid_desc_m, + dst_global_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValue_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValue_buf[I]) * beta; + }); + }; + + auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run( + reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf); + }; + + template + __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) + { + using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex; + + (void)acc_elementwise_op; + + const auto zeroVal = ReduceOperation::GetReductionZeroVal(); + + const auto in_global_val_buf = + make_dynamic_buffer(p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + type_convert(zeroVal)); + const auto in_global_idx_buf = make_dynamic_buffer( + p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); + + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + auto out_global_idx_buf = make_dynamic_buffer( + p_out_index_global, out_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer + in_thread_val_buf; + + StaticBuffer + in_thread_idx_buf; + + StaticBuffer accu_value_buf; + StaticBuffer accu_index_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = zeroVal; + accu_index_buf(I) = 0; + }); + + const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); + + index_t indexStart = 0; + index_t reducedLength = 0; + if constexpr(HaveIndexInput) + { + auto threadwise_src_idx_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + threadwise_src_idx_load.Run(in_grid_desc_m_k, + in_global_idx_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_idx_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + }); + + ThreadwiseReduceWithIndex::Reduce( + in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexStart += KThreadSliceSize; + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + } + else + { + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + in_thread_idx_buf(Number{}) = indexStart + iK(); + + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + }); + + ThreadwiseReduceWithIndex::Reduce( + in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexStart += KThreadSliceSize; + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + }; + + // for indiced operation, acc_elementwise_op shoud do nothing + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + }); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + if(!float_equal_zero{}(beta)) + { + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + 1, + 1, + false>( + out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); + + StaticBuffer + priorDstValue_buf; + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValue_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValue_buf[I]) * beta; + }); + }; + + auto threadwise_dst_val_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_dst_idx_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_val_store.Run( + reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf); + + threadwise_dst_idx_store.Run( + reduced_data_desc, make_tuple(I0), accu_index_buf, out_grid_desc_m, out_global_idx_buf); + }; +}; + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp new file mode 100644 index 00000000000..c77d49ae94a --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp @@ -0,0 +1,150 @@ +#pragma once + +#include "cluster_descriptor.hpp" +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + CDataType* __restrict__ p_c_global, + const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0, + const GridDesc_M0 c_grid_desc_m0, + const ElementwiseFunctor functor) +{ + GridwiseBinEltwise::Run(p_a_global, + p_b_global, + p_c_global, + a_grid_desc_m0, + b_grid_desc_m0, + c_grid_desc_m0, + functor); +} + +template +struct GridwiseBinaryElementwise_1D +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto thread_desc_m0 = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static __device__ auto CalculateElementwiseIndex() + { + const index_t global_thread_id = get_thread_global_1d_id(); + return make_multi_index(global_thread_id * ScalarPerVector); + } + + __device__ static void Run(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + CDataType* __restrict__ p_c_global, + const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0, + const GridDesc_M0 c_grid_desc_m0, + const ElementwiseFunctor functor) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_grid_desc_m0.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_grid_desc_m0.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_grid_desc_m0.GetElementSpaceSize()); + + StaticBuffer a_thread_buf; + StaticBuffer b_thread_buf; + StaticBuffer c_thread_buf; + + const auto thread_store_global_offset = CalculateElementwiseIndex(); + + auto a_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + ScalarPerVector, + 1, // SrcScalarStrideInVector + false>{a_grid_desc_m0, thread_store_global_offset}; + + auto b_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + ScalarPerVector, + 1, // SrcScalarStrideInVector + false>{b_grid_desc_m0, thread_store_global_offset}; + + auto c_global_write = + ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // DstVectorDim + ScalarPerVector, + InMemoryDataOperationEnum::Set, + 1, // DstScalarStrideInVector + false>{ + c_grid_desc_m0, thread_store_global_offset, PassThrough{}}; + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto m0 = c_grid_desc_m0.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector; + const auto loop_step_index = make_multi_index(loop_step); + + index_t num_iter = m0 / (loop_step); + do + { + // read and process ScalarPerVector elements + a_global_load.Run( + a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf); + + b_global_load.Run( + b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf); + + static_for<0, ScalarPerVector, 1>{}([&](auto m) { + constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m)); + functor(c_thread_buf(Number{}), + a_thread_buf(Number{}), + b_thread_buf(Number{})); + }); + + c_global_write.Run(thread_desc_m0, + make_tuple(I0), // SrcSliceOriginIdx + c_thread_buf, + c_grid_desc_m0, + c_global_buf); + + a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index); + b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index); + c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp similarity index 98% rename from composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp index fe56d0d813f..a9b6d8dfa0d 100644 --- a/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp @@ -55,7 +55,7 @@ template , integral_constant) { - const auto a_global_buf = make_dynamic_buffer( + const auto a_global_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( + const auto b_global_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( + auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize()); const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0); @@ -381,9 +381,9 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN "wrong!"); // A matrix blockwise copy - auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, Sequence, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, @@ -405,9 +405,9 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN make_multi_index(0, 0, 0, 0, 0)); // B matrix blockwise copy - auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, Sequence, BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, @@ -467,7 +467,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; // register allocation for output - auto c_thread_buf = make_static_buffer( + auto c_thread_buf = make_static_buffer( c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); ThreadwiseTensorSliceSet_v1( + auto a_block_even_buf = make_dynamic_buffer( p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( + auto b_block_even_buf = make_dynamic_buffer( p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); - auto a_block_odd_buf = make_dynamic_buffer( + auto a_block_odd_buf = make_dynamic_buffer( p_a_block_double + a_block_aligned_space_size, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( + auto b_block_odd_buf = make_dynamic_buffer( p_b_block_double + b_block_aligned_space_size, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp new file mode 100644 index 00000000000..3b5daf6eadc --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -0,0 +1,572 @@ +#pragma once + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_dl_v2r3.hpp" +#include "blockwise_tensor_slice_transfer_v5r1.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" +#include "element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + c_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseGemmDl_km_kn_mn_v1r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2); + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && + K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2)) && + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) + { + const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) + { + const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1) + { + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const auto a_grid_desc_k0_m0_m1_k1 = + transform_tensor_descriptor(a_grid_desc_k0_m_k1, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(M0, M1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_grid_desc_k0_m0_m1_k1; + } + + __host__ __device__ static constexpr auto + MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) + { + const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + + const auto N1 = Number{}; + const auto N0 = N / N1; + + const auto b_grid_desc_k0_n0_n1_k1 = + transform_tensor_descriptor(b_grid_desc_k0_n_k1, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(N0, N1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_grid_desc_k0_n0_n1_k1; + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_m0_m10_m11_n0_n10_n11; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N00_M01_N01( + c_grid_desc_m_n); + } + + using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap& block_2_ctile_map, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto c_m0_n0_block_cluster_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); + const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + + if(!block_2_ctile_map.ValidCTileIndex( + make_tuple(im0, in0), + make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0), + c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3)))) + { + return; + } + + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // A matrix in LDS memory, for blockwise GEMM + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, for blockwise GEMM + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() == + a_k0_m_k1_block_desc.GetElementSpaceSize() && + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() == + b_k0_n_k1_block_desc.GetElementSpaceSize() && + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(a_block_desc_k0_m0_m1_k1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(a_grid_desc_k0_m0_m1_k1, + make_multi_index(0, im0, 0, 0), + a_block_desc_k0_m0_m1_k1, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(b_block_desc_k0_n0_n1_k1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(b_grid_desc_k0_n0_n1_k1, + make_multi_index(0, in0, 0, 0), + b_block_desc_k0_n0_n1_k1, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlocl, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; + + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); + + constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); + + // Initialize C + c_thread_buf.Clear(); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); + + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); + + k_block_data_begin += 2 * K0PerBlock; + } while(k_block_data_begin < K0 - 2 * K0PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); + + block_sync_lds(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), + decltype(c_grid_desc_m0_m10_m11_n0_n10_n11), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_grid_desc_m0_m10_m11_n0_n10_n11, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]), + ck::tensor_operation::element_wise::PassThrough{}} + .Run(c_thread_desc_m0_m10_m11_n0_n10_n11, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_m10_m11_n0_n10_n11, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp similarity index 88% rename from composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp index d91159b8849..a7ff81e2094 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp @@ -12,7 +12,6 @@ namespace ck { -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE template {}, integral_constant{}); } -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -// pass tensor descriptor by CONSTANT void pointer -// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to -// non-modifiable parameter address space, so compiler can enable corresponding optimization -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k_m0_m1_grid_desc, - const void CONSTANT* p_b_k_n0_n1_grid_desc, - const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - // first cast void CONSTANT void* to void* - // second cast void* to Desc* - // the copy constructor of tensor descriptor doesn't take address_space(4) - const auto a_k_m0_m1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc)); - const auto b_k_n0_n1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc)); - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc)); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor)); - - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} -#endif template {}), make_tuple(Sequence<0>{})); - return c_blockid_to_m0_n0_block_cluster_adaptor; + return cblockid_to_m0_n0_block_cluster_adaptor; } using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{})); @@ -321,22 +264,22 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 const AKM0M1GridDesc& a_k_m0_m1_grid_desc, const BKN0N1GridDesc& b_k_n0_n1_grid_desc, const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor, + const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor, integral_constant, integral_constant) { - const auto a_global_buf = make_dynamic_buffer( + const auto a_global_buf = make_dynamic_buffer( p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( + const auto b_global_buf = make_dynamic_buffer( p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( + auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); const auto K = a_k_m0_m1_grid_desc.GetLength(I0); // divide block work by [M, N] const auto c_m0_n0_block_cluster_idx = - c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( + cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( make_multi_index(get_block_1d_id())); // HACK: this force index data into SGPR @@ -372,7 +315,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 // A matrix blockwise copy auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4, ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadClusterLengths_K_M0_M1, @@ -398,7 +341,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 // B matrix blockwise copy auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4, BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadClusterLengths_K_N0_N1, @@ -460,7 +403,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; // register allocation for output - auto c_thread_buf = make_static_buffer( + auto c_thread_buf = make_static_buffer( c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); ThreadwiseTensorSliceSet_v1( + auto a_block_even_buf = make_dynamic_buffer( p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( + auto b_block_even_buf = make_dynamic_buffer( p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize()); - auto a_block_odd_buf = make_dynamic_buffer( + auto a_block_odd_buf = make_dynamic_buffer( p_a_block_double + a_block_aligned_space_size, a_k_m0_m1_block_desc.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( + auto b_block_odd_buf = make_dynamic_buffer( p_b_block_double + b_block_aligned_space_size, b_k_n0_n1_block_desc.GetElementSpaceSize()); diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp similarity index 98% rename from composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp index 84ee6f40ec0..607a05d1561 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp @@ -15,7 +15,7 @@ template {}; constexpr auto I3 = Number<3>{}; - const auto a_global_buf = make_dynamic_buffer( + const auto a_global_buf = make_dynamic_buffer( p_a_global, a_e_k_global_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( + const auto b_global_buf = make_dynamic_buffer( p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( + auto c_global_buf = make_dynamic_buffer( p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize()); constexpr auto E = EPerBlock * 3 * 3; @@ -181,7 +181,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 // A matrix blockwise copy auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4, ABlockTransferThreadSliceLengths_E_K, ABlockTransferThreadClusterLengths_E_K, @@ -221,11 +221,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3 b_e_n_ho_wo_global_desc, make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); - auto a_block_buf = make_dynamic_buffer( + auto a_block_buf = make_dynamic_buffer( p_shared_block, a_e_k_desc.GetElementSpaceSize()); // register allocation for output - StaticBuffer @@ -250,7 +250,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 BGlobalMoveSliceWindowStepHacks{}; // double regsiter buffer for b - StaticBuffer diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp new file mode 100644 index 00000000000..a36b5e53ce0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp @@ -0,0 +1,1594 @@ +#ifndef CK_GRIDWISE_GEMM_V3_HPP +#define CK_GRIDWISE_GEMM_V3_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" +#include "blockwise_gemm_dlops_v3.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatC* __restrict__ p_bias_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::ConvBiasActiv(p_a_grid, + p_b_grid, + p_bias_grid, + p_c_grid, + p_shared_block, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v3_resize_add( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatC* __restrict__ p_bias_grid, + FloatC* __restrict__ p_d_grid, + const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::ConvBiasActivResizeAdd(p_a_grid, + p_b_grid, + p_bias_grid, + p_d_grid, + p_shared_block, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v3_maxpool( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatC* __restrict__ p_bias_grid, + FloatC* __restrict__ p_c_grid, + FloatC* __restrict__ p_d_grid, + const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::ConvBiasActivMaxpool(p_a_grid, + p_b_grid, + p_bias_grid, + p_c_grid, + p_d_grid, + p_shared_block, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseGemmDlops_km_kn_mn_v3 +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto E1 = Number{}; + static constexpr auto E2 = Number{}; + static constexpr auto K2 = Number{}; + + static constexpr auto NPerBlock = I1; + + static constexpr FloatAcc alpha = 0.3; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = Number{}; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_e0_e1_k1_e2_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(I1, Number{}, Number{}, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = math::integer_least_multiple( + a_e0_e1_k1_e2_block_desc.GetElementSpaceSize(), max_lds_align); + + return a_block_space_size * sizeof(FloatAB); + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) + { + const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); + const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); + const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); + const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); + + const auto K0 = K / KPerBlock; + const auto N0 = N / NPerBlock; + const auto H0 = Ho / HoPerBlock; + const auto W0 = Wo / WoPerBlock; + + const index_t grid_size = K0 * N0 * H0 * W0; + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainE0BlockLoop(const index_t E0) + { + const bool has_main_e0_block_loop = E0 > 1; + + return has_main_e0_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasMainE1BlockLoop() + { + const bool has_main_e1_block_loop = ((E1 + E1PerBlock) / (2 * E1PerBlock)) > 1; + + return has_main_e1_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailE1BlockLoop() + { + const bool has_double_tail_e1_block_loop = (E1 / E1PerBlock) % 2 == 0; + + return has_double_tail_e1_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAE0E1K0K1E2GridDescriptor(const AGridDesc_E0_E1_K_E2& a_e0_e1_k_e2_grid_desc) + { + const auto E0 = a_e0_e1_k_e2_grid_desc.GetLength(I0); + const auto K = a_e0_e1_k_e2_grid_desc.GetLength(I2); + + const auto K1 = Number{}; + const auto K0 = K / K1; + + const auto a_e0_e1_k0_k1_e2_grid_desc = transform_tensor_descriptor( + a_e0_e1_k_e2_grid_desc, + make_tuple(make_pass_through_transform(E0), + make_pass_through_transform(E1), + make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + return a_e0_e1_k0_k1_e2_grid_desc; + } + + __host__ __device__ static constexpr auto MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor( + const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc) + { + const auto E0 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0); + // const auto E1 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I1); + const auto N = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I2); + const auto Ho = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I3); + const auto Wo = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4); + // const auto E2 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I5); + + const auto H2 = Number{}; + const auto H1 = Number{}; + const auto H0 = Ho / (H1 * H2); + + const auto W2 = Number{}; + const auto W1 = Number{}; + const auto W0 = Wo / (W1 * W2); + + const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = + transform_tensor_descriptor(b_e0_e1_n_ho_wo_e2_grid_desc, + make_tuple(make_pass_through_transform(E0), + make_pass_through_transform(E1), + make_pass_through_transform(N), + make_unmerge_transform(make_tuple(H0, H1, H2)), + make_unmerge_transform(make_tuple(W0, W1, W2)), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3, 4, 5>{}, + Sequence<6, 7, 8>{}, + Sequence<9>{})); + + return b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCK0K1NH0H1H2W0W1W2GridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) + { + const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); + const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); + const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); + const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); + + const auto K1 = Number{}; + const auto K0 = K / K1; + + const auto H2 = Number{}; + const auto H1 = Number{}; + const auto H0 = Ho / (H1 * H2); + + const auto W2 = Number{}; + const auto W1 = Number{}; + const auto W0 = Wo / (W1 * W2); + + const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = transform_tensor_descriptor( + c_k_n_ho_wo_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_unmerge_transform(make_tuple(H0, H1, H2)), + make_unmerge_transform(make_tuple(W0, W1, W2))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); + + return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc) + { + const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0); + const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1); + const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2); + const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3); + + const auto K1 = Number{}; + const auto K0 = K / K1; + +#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR + const auto H2 = Number{}; + const auto H1 = Number{}; + const auto H0 = Number{}; + + const auto W2 = Number{}; + const auto W1 = Number{}; + const auto W0 = Number{}; +#else + const auto H2 = HoPerThread / 2; + const auto H1 = HoPerBlock / HoPerThread; + const auto H0 = Hx / (H1 * H2); + + const auto W2 = WoPerThread / 2; + const auto W1 = WoPerBlock / WoPerThread; + const auto W0 = Wx / (W1 * W2); +#endif + + const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor( + d_k_n_hx_wx_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_unmerge_transform(make_tuple(H0, H1, H2)), + make_unmerge_transform(make_tuple(W0, W1, W2))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); + + return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc) + { + const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0); + const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1); + const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2); + const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3); + + const auto K1 = Number{}; + const auto K0 = K / K1; + + const auto H2 = Number{}; + const auto H1 = Number{}; + + const auto W2 = Number{}; + const auto W1 = Number{}; + +#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR + const auto H0 = Number{}; + const auto W0 = Number{}; +#else + const auto H0 = Hx / (H1 * H2); + const auto W0 = Wx / (W1 * W2); +#endif + + const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor( + d_k_n_hx_wx_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_unmerge_transform(make_tuple(H0, H1, H2)), + make_unmerge_transform(make_tuple(W0, W1, W2))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); + + return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCBlockIdToKNHoWoBlockClusterAdaptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) + { + const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); + const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); + const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); + const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); + +#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR + const auto K0 = Number{}; + const auto N0 = Number{}; + const auto H0 = Number{}; + const auto W0 = Number{}; +#else + const auto K0 = K / KPerBlock; + const auto N0 = N / NPerBlock; + const auto H0 = Ho / HoPerBlock; + const auto W0 = Wo / WoPerBlock; +#endif + + const auto cblockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + return cblockid_to_k_n_ho_wo_block_cluster_adaptor; + } + + // using AGridDesc_E0_E1_K0_K1_E2 = + // decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{})); + // using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = + // decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{})); + // using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = + // decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{})); + // using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx = + // decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{})); + + using CBlockIdToBlockClusterAdaptor_K_N_H_W = + decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{})); + + template + __host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor( + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc) + { + const auto K0 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I0); + const auto K1 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I1); + + return make_naive_tensor_descriptor_packed(make_tuple(K0, K1)); + } + + __host__ __device__ static constexpr auto MakeCK1NH2W2ThreadDescriptor() + { + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + return c_k1_n_h2_w2_thread_gemm_desc; + } + + // using CThreadDesc_K1_N_H2_W2 = decltype(MakeCK1NH2W2ThreadDescriptor()); + + __host__ __device__ static constexpr auto GetBlockWiseGemm() + { + constexpr auto max_lds_align = Number{}; + + constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, Number{}), max_lds_align); + + constexpr auto b_e1_n_h_w_e2_block_gemm_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{})); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; + + return blockwise_gemm; + } + + __device__ static constexpr auto GetCThreadIndex() + { + auto blockwise_gemm = GetBlockWiseGemm(); + auto c_thread_mtx_index = + blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()); + + return c_thread_mtx_index; + }; + + __device__ static constexpr auto GetCBlockIndex( + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor) + { + const auto c_k_n_h_w_block_cluster_idx = + cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + return c_k_n_h_w_block_cluster_idx; + } + + template + __device__ static void BiasOp(BiasGlobalBuff& bias_global_buf, + CThreadBuff& c_thread_buf, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const BiasGridDesc_K0_K1& bias_k0_k1_grid_desc, + const CThreadDesc_K1_N_H2_W2&) + + { + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + + const auto k_thread_id = c_thread_idx[I0]; + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + constexpr auto bias_k0_k1_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + StaticBuffer + bias_thread_buf; + + const index_t k_thread_data_on_global = k_thread_id * KPerThread; + + auto bias_threadwise_transfer = + ThreadwiseTensorSliceTransfer_v2{}>, + Sequence<0, 1>, + 1, + CThreadTransferDstScalarPerVector, + false, + true>( + bias_k0_k1_grid_desc, make_multi_index(k_block_work_id, k_thread_data_on_global)); + + constexpr auto bias_k0_k1_global_tensor_step_hacks = make_tuple( + make_tuple(Sequence<0>{}, Sequence<0>{}), make_tuple(Sequence<0>{}, Sequence<0>{})); + + bias_threadwise_transfer.Run(bias_k0_k1_grid_desc, + bias_global_buf, + bias_k0_k1_thread_desc, + make_tuple(I0, I0), + bias_thread_buf, + bias_k0_k1_global_tensor_step_hacks); + + static_for<0, KPerThread, 1>{}([&](auto ki) { + static_for<0, HoPerThread, 1>{}([&](auto hi) { + static_for<0, WoPerThread, 1>{}([&](auto wi) { + constexpr index_t c_offset = + c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(make_tuple(ki, 0, hi, wi)); + c_thread_buf(Number{}) = + c_thread_buf[Number{}] + bias_thread_buf[ki]; + }); + }); + }); + } + + template + __device__ static void Activation(CThreadBuff& c_thread_buf, + const CThreadDesc_K1_N_H2_W2&, + integral_constant) + { + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) { + if constexpr(activ_type_ == 1) + { + c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i]; + } + else if constexpr(activ_type_ == 2) + { + FloatAcc x = 1.0 + exp(-c_thread_buf[i]); + + asm volatile("\n \ + v_rcp_f32 %0, %1 \n" + : "=v"(x) + : "0"(x)); + + c_thread_buf(i) = x; + } + }); + } + + template + __device__ static void + WriteOut(const CThreadBuff& c_thread_buf, + CGlobalBuff& c_global_buf, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc) + { + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); + const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); + const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); + + const auto k_thread_id = c_thread_idx[I0]; + const auto ho_thread_id = c_thread_idx[I2]; + const auto wo_thread_id = c_thread_idx[I3]; + + // hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global + // tensor + constexpr auto c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = CGlobalStepHacks{}; + + constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + I1, + I1, + I1, + Number{}, + I1, + I1, + Number{})); + + const index_t k_thread_data_on_global = k_thread_id * KPerThread; + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc), + decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + make_multi_index(k_block_work_id, + k_thread_data_on_global, + n_block_work_id, + ho_block_work_id, + ho_thread_id, + 0, + wo_block_work_id, + wo_thread_id, + 0)) + .Run(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + c_global_buf, + c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks); + } + + template + __device__ static void + MaxPool(const CThreadBuff& c_thread_buf, + DGlobalBuff& d_global_buf, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const CThreadDesc_K1_N_H2_W2&, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc) + { + + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); + const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); + const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); + + const auto k_thread_id = c_thread_idx[I0]; + const auto ho_thread_id = c_thread_idx[I2]; + const auto wo_thread_id = c_thread_idx[I3]; + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + static_assert(HoPerThread % 2 == 0 && WoPerThread % 2 == 0, ""); + + constexpr auto HoPerThread_2 = HoPerThread / 2; + constexpr auto WoPerThread_2 = WoPerThread / 2; + + constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + I1, + I1, + I1, + Number{}, + I1, + I1, + Number{})); + + StaticBuffer + d_thread_buf; + + static_for<0, KPerThread, 1>{}([&](auto ki) { + static_for<0, HoPerThread_2, 1>{}([&](auto hi) { + static_for<0, WoPerThread_2, 1>{}([&](auto wi) { + constexpr index_t d_offset = + d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset( + make_tuple(0, ki, 0, 0, 0, hi, 0, 0, wi)); + + constexpr index_t c_offset_0 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( + make_tuple(ki, 0, hi * 2, wi * 2)); + constexpr index_t c_offset_1 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( + make_tuple(ki, 0, hi * 2, wi * 2 + 1)); + constexpr index_t c_offset_2 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( + make_tuple(ki, 0, hi * 2 + 1, wi * 2)); + constexpr index_t c_offset_3 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( + make_tuple(ki, 0, hi * 2 + 1, wi * 2 + 1)); + + d_thread_buf(Number{}) = c_thread_buf[Number{}]; + d_thread_buf(Number{}) = + fmaxf(c_thread_buf[Number{}], d_thread_buf(Number{})); + d_thread_buf(Number{}) = + fmaxf(c_thread_buf[Number{}], d_thread_buf(Number{})); + d_thread_buf(Number{}) = + fmax(c_thread_buf[Number{}], d_thread_buf(Number{})); + }); + }); + }); + + const index_t k_thread_data_on_global = k_thread_id * KPerThread; + + constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{}; + + ThreadwiseTensorSliceTransfer_v1r3< + FloatC, + FloatC, + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc), + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + InMemoryDataOperationEnum::Set, + 1, + true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + make_multi_index(k_block_work_id, + k_thread_data_on_global, + n_block_work_id, + ho_block_work_id, + ho_thread_id, + 0, + wo_block_work_id, + wo_thread_id, + 0)) + .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), + d_thread_buf, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + d_global_buf, + d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks); + } + + template + __device__ static void + ResizeAdd(const CThreadBuff& c_thread_buf, + DGlobalBuff& d_global_buf, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const CThreadDesc_K1_N_H2_W2&, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc) + { + + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); + const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); + const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); + + const auto k_thread_id = c_thread_idx[I0]; + const auto ho_thread_id = c_thread_idx[I2]; + const auto wo_thread_id = c_thread_idx[I3]; + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + constexpr auto HoPerThreadx2 = HoPerThread * 2; + constexpr auto WoPerThreadx2 = WoPerThread * 2; + + constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + I1, + I1, + I1, + Number{}, + I1, + I1, + Number{})); + + StaticBuffer + d_thread_buf; + + static_for<0, KPerThread, 1>{}([&](auto k_i) { + static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) { + static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) { + d_thread_buf(Number{}) = + c_thread_buf[Number{}]; + }); + }); + }); + + // hack to control index calculation when iterating over d_k_n_ho_wo_global tensor + constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{}; + + const index_t k_thread_data_on_global = k_thread_id * KPerThread; + + ThreadwiseTensorSliceTransfer_v1r3< + FloatC, + FloatC, + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc), + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + InMemoryDataOperationEnum::Add, + 1, + true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + make_multi_index(k_block_work_id, + k_thread_data_on_global, + n_block_work_id, + ho_block_work_id, + ho_thread_id, + 0, + wo_block_work_id, + wo_thread_id, + 0)) + .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), + d_thread_buf, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + d_global_buf, + d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks); + } + + template + __device__ static void + GemmOp(const AGlobalBuff& a_global_buf, + const BGlobalBuff& b_global_buf, + CThreadBuff& c_thread_buf, + FloatAB* __restrict__ p_shared_block, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CThreadDesc_K1_N_H2_W2&, + integral_constant) + { + constexpr auto HasMainE1BlockLoop = CalculateHasMainE1BlockLoop(); + constexpr auto HasDoubleTailE1BlockLoop = CalculateHasDoubleTailE1BlockLoop(); + + // const auto c_k_n_h_w_block_cluster_idx = + // GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + // cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex( + // make_multi_index(get_block_1d_id())); + + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); + const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); + const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); + + constexpr auto max_lds_align = Number{}; + + constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, Number{}), max_lds_align); + + constexpr auto b_e1_n_h_w_e2_block_gemm_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{})); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; + // blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()); + + const auto ho_thread_id = c_thread_idx[I2]; + const auto wo_thread_id = c_thread_idx[I3]; + + constexpr auto a_e0_e1_k0_k1_e2_block_copy_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, I1, Number{}, Number{}), + max_lds_align); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, + ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_e0_e1_k0_k1_e2_grid_desc), + decltype(a_e0_e1_k0_k1_e2_block_copy_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + ABlockTransferSrcVectorDim, + 4, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_E2, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + false>(a_e0_e1_k0_k1_e2_grid_desc, + make_multi_index(0, 0, k_block_work_id, 0, 0), + a_e0_e1_k0_k1_e2_block_copy_desc, + make_multi_index(0, 0, 0, 0, 0)); + + constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0, 0); + + constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + I1, + I1, + I1, + Number{}, + I1, + I1, + Number{}, + Number{})); + + auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2< + FloatAB, + FloatAB, + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc), + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc), + Sequence, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + make_multi_index(0, + 0, + n_block_work_id, + ho_block_work_id, + ho_thread_id, + 0, + wo_block_work_id, + wo_thread_id, + 0, + 0)); + + auto a_block_buf = make_dynamic_buffer( + p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize()); + + //// register allocation for output + // StaticBuffer + // c_thread_buf; + + // initialize output thread tensor + ThreadwiseTensorSliceSet_v1>{} + .Run(c_k1_n_h2_w2_thread_gemm_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto b_thread_slice_copy_step = + make_multi_index(0, E1PerBlock, 0, 0, 0, 0, 0, 0, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_e0_e1_k_e2_global_step_hacks = AGlobalStepHacks{}; + constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{}; + + // double regsiter buffer for b + StaticBuffer + b_thread_even_buf, b_thread_odd_buf; + + if constexpr(HasMainE0BlockLoop) + { + const auto E0 = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I0); + + index_t e0_block_data_begin = 0; + + do + { + // LDS double buffer: preload data + { + a_blockwise_copy.RunRead( + a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_even_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); + } + + __syncthreads(); + + if constexpr(HasMainE1BlockLoop) + { + index_t e1_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_odd_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_even_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + e1_block_data_begin += 2 * E1PerBlock; + + } while(e1_block_data_begin < E1 - 2 * E1PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left + { + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_odd_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + } + + a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k0_k1_e2_grid_desc, + a_block_slice_copy_step, + AGlobalMoveSliceWindowStepHacks{}); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - E1PerBlock), 0, 0)); + + b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + e0_block_data_begin += 1; + + } while(e0_block_data_begin < E0); + } + else + { + // LDS double buffer: preload data + { + a_blockwise_copy.RunRead( + a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_even_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); + } + + __syncthreads(); + + if constexpr(HasMainE1BlockLoop) + { + index_t e1_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_odd_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_even_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + e1_block_data_begin += 2 * E1PerBlock; + + } while(e1_block_data_begin < E1 - 2 * E1PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left + { + b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_odd_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + } + } + } + + template + __device__ static void + Conv(const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + const FloatC* __restrict__ p_bias_global, + FloatC* __restrict__ p_c_global, + FloatC* __restrict__ p_d_global, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant) + { + const auto bias_k0_k1_grid_desc = + MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); + auto d_global_buf = make_dynamic_buffer( + p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); + auto bias_global_buf = make_dynamic_buffer( + p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + // register allocation for output + StaticBuffer + c_thread_buf; + + const auto c_k_n_h_w_block_cluster_idx = + GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + + const auto c_thread_mtx_index = GetCThreadIndex(); + + // GemmOp + GemmOp(a_global_buf, + b_global_buf, + c_thread_buf, + p_shared_block, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc, + integral_constant{}); + + // Output + WriteOut(c_thread_buf, + c_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + } + + template + __device__ static void ConvBiasActiv( + const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + const FloatC* __restrict__ p_bias_global, + FloatC* __restrict__ p_c_global, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant, + integral_constant) + { + static constexpr auto activ_type = integral_constant{}; + + const auto bias_k0_k1_grid_desc = + MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); + auto bias_global_buf = make_dynamic_buffer( + p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + // register allocation for output + StaticBuffer + c_thread_buf; + + const auto c_k_n_h_w_block_cluster_idx = + GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + + const auto c_thread_mtx_index = GetCThreadIndex(); + + // GemmOp + GemmOp(a_global_buf, + b_global_buf, + c_thread_buf, + p_shared_block, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc, + integral_constant{}); + + // Bias + BiasOp(bias_global_buf, + c_thread_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + bias_k0_k1_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc); + + // Activ + Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); + + // Output + WriteOut(c_thread_buf, + c_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + } + + template + __device__ static void ConvBiasActivMaxpool( + const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + const FloatC* __restrict__ p_bias_global, + FloatC* __restrict__ p_c_global, + FloatC* __restrict__ p_d_global, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant, + integral_constant) + { + static constexpr auto activ_type = integral_constant{}; + + const auto bias_k0_k1_grid_desc = + MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); + auto d_global_buf = make_dynamic_buffer( + p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); + auto bias_global_buf = make_dynamic_buffer( + p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + // register allocation for output + StaticBuffer + c_thread_buf; + + const auto c_k_n_h_w_block_cluster_idx = + GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + + const auto c_thread_mtx_index = GetCThreadIndex(); + + // GemmOp + GemmOp(a_global_buf, + b_global_buf, + c_thread_buf, + p_shared_block, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc, + integral_constant{}); + + // Bias + BiasOp(bias_global_buf, + c_thread_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + bias_k0_k1_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc); + + // Activ + Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); + + // Output + WriteOut(c_thread_buf, + c_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + // MaxPool + MaxPool(c_thread_buf, + d_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k1_n_h2_w2_thread_gemm_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc); + } + + template + __device__ static void ConvBiasActivResizeAdd( + const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + const FloatC* __restrict__ p_bias_global, + FloatC* __restrict__ p_d_global, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant, + integral_constant) + { + static constexpr auto activ_type = integral_constant{}; + + const auto bias_k0_k1_grid_desc = + MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); + auto d_global_buf = make_dynamic_buffer( + p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); + auto bias_global_buf = make_dynamic_buffer( + p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + // register allocation for output + StaticBuffer + c_thread_buf; + + const auto c_k_n_h_w_block_cluster_idx = + GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + + const auto c_thread_mtx_index = GetCThreadIndex(); + + // GemmOp + GemmOp(a_global_buf, + b_global_buf, + c_thread_buf, + p_shared_block, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc, + integral_constant{}); + + // Bias + BiasOp(bias_global_buf, + c_thread_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + bias_k0_k1_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc); + + // Activ + Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); + + // Resize_Add + ResizeAdd(c_thread_buf, + d_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k1_n_h2_w2_thread_gemm_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc); + } +}; + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp new file mode 100644 index 00000000000..20c3a0b6185 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -0,0 +1,364 @@ +#pragma once +#include "common_header.hpp" +#include "tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" + +namespace ck { + +template +struct GridwiseGemmPipeline_v1; + +// 1-stage prefetch +template <> +struct GridwiseGemmPipeline_v1<1> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +// 2-stage prefetch +template <> +struct GridwiseGemmPipeline_v1<2> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + // TODO: improve applicability + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + { + // Read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Read 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Read i+2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i+1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + + // Read i+3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm i+1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + { + // Write num_loop - 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm num_loop - 2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Write num_loop - 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm num_loop - 1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +template +struct GridwiseGemmPipelineInterwave_v1; + +template <> +struct GridwiseGemmPipelineInterwave_v1<1> +{ + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // block_sync_lds(); // moved into blockwise_gemm + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +// Note: 2 stage prefetch not optimized for inter-wave loop scheduler +template <> +struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2> +{ +}; + +template +constexpr auto GridwiseGemmPipeline_v1_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return GridwiseGemmPipeline_v1{}; + } + else if constexpr(LoopSched == LoopScheduler::Interwave) + { + return GridwiseGemmPipelineInterwave_v1{}; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp new file mode 100644 index 00000000000..bc8850e4a6a --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -0,0 +1,871 @@ +#pragma once +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r1.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" +#include "reduction_functions_threadwise.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + DPtrsGlobal p_ds_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const DxsInElementwiseOperation dxs_in_element_op, + const DxsOutElementwiseOperation dxs_out_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_ds_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + dxs_in_element_op, + dxs_out_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + d_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_ds_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = dxs_in_element_op; + ignore = dxs_out_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = d_grid_desc_mblock_mperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto + MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m) + { + const auto M = d_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor( + d_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return d_grid_desc_mblock_mperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + DPtrsGlobal p_ds_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const DxsInElementwiseOperation& dxs_in_element_op, + const DxsOutElementwiseOperation& dxs_out_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C + reduction + write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + // TODO: this should be implemented as a blockwise reduction + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // VGPR d_reduce_thread_desc_mperblock + constexpr auto d_reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // VGPR d_reduce_thread_desc_mblock_mperblock + constexpr auto d_reduce_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_d_grid = p_ds_grid[I]; + auto d_out_element_op = dxs_out_element_op[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + remove_pointer_t, + decltype(d_reduce_thread_desc_mblock_mperblock), + decltype(d_grid_desc_mblock_mperblock), + decltype(d_out_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + DGlobalMemoryDataOperation::At(I), + 1, + false>{d_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + d_out_element_op}; + }, + Number{}); + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // TODO - extract following into reduction_blockwise + { + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) { + auto& p_d_grid = p_ds_grid[In]; + + auto d_grid_buf = make_dynamic_buffer( + p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); + + auto d_thread_buf = + make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& d_in_element_op = dxs_in_element_op[In]; + + auto& d_reduce_thread_copy_vgpr_to_global = + dxs_reduce_thread_copy_vgpr_to_global(In); + + using DReduceOperation = remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // Global write Gemm shuffle + reduction + const auto d_zeroVal = DReduceOperation::GetReductionZeroVal(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { d_thread_buf(I) = d_zeroVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + d_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); + }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf); + + // copy from VGPR to Global + d_reduce_thread_copy_vgpr_to_global.Run( + d_reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + d_thread_buf, + d_grid_desc_mblock_mperblock, + d_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + d_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } + }); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + + // Reduction + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp new file mode 100644 index 00000000000..55390dbc864 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -0,0 +1,645 @@ +#pragma once +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r1.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp new file mode 100644 index 00000000000..0d3f8ddefb2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -0,0 +1,974 @@ +#pragma once + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r1.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" + +namespace ck { + +// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to +// be used for low_lengths that are known at compile time and are power of 2, otherwise performance +// will be very bad +template +struct Merge_v4_no_carry +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v4_no_carry() = default; + + __host__ __device__ constexpr Merge_v4_no_carry(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // division and mod + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp %= this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_up_diff, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + constexpr auto INm1 = Number{}; + + index_t tmp = idx_up_new[I0]; + + idx_low(INm1) = tmp; + idx_diff_low(INm1) = idx_up_diff[I0]; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v3_direct_division_mod_wrw, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +__host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLengths& low_lengths) +{ + return Merge_v4_no_carry{low_lengths}; +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_bwd_weight(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + // M0/M1/M1Padding + static constexpr auto M1PerBlock = Number{}; + static constexpr auto M0PerBlock = Number{}; + static constexpr auto M1Padding = Number{}; + + // N0/N1/N1Padding + static constexpr auto N1PerBlock = Number{}; + static constexpr auto N0PerBlock = Number{}; + static constexpr auto N1Padding = Number{}; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + if constexpr(ABlockLdsExtraM1Wrw) + { + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor( + make_tuple( + Number{}, Number{}, Number{}, K1), + make_tuple(Number{} * (Number{} * K1 + M1Padding), + Number{} * K1 + M1Padding, + K1, + I1)); + + constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor( + a_block_desc_k0_m0_m1_k1, + make_tuple(make_pass_through_transform(Number{}), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_block_desc_k0_m_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_b_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + if constexpr(ABlockLdsExtraM1Wrw) + { + constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor( + make_tuple(Number<1>{}, + Number{}, + Number{}, + Number{}, + K1), + make_tuple(Number{} * Number{} * + (Number{} * K1 + M1Padding), + Number{} * (Number{} * K1 + M1Padding), + Number{} * K1 + M1Padding, + K1, + I1)); + + constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor( + a_block_desc_b_k0_m0_m1_k1, + make_tuple(make_pass_through_transform(Number<1>{}), + make_pass_through_transform(Number{}), + make_merge_transform_v4_no_carry( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return a_block_desc_b_k0_m_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + return a_block_desc_b_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + if constexpr(BBlockLdsExtraN1Wrw) + { + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor( + make_tuple( + Number{}, Number{}, Number{}, K1), + make_tuple(Number{} * (Number{} * K1 + N1Padding), + Number{} * K1 + N1Padding, + K1, + I1)); + + constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor( + b_block_desc_k0_n0_n1_k1, + make_tuple(make_pass_through_transform(Number{}), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_block_desc_k0_n_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_b_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + if constexpr(BBlockLdsExtraN1Wrw) + { + constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor( + make_tuple(Number<1>{}, + Number{}, + Number{}, + Number{}, + K1), + make_tuple(Number{} * Number{} * + (Number{} * K1 + N1Padding), + Number{} * (Number{} * K1 + N1Padding), + Number{} * K1 + N1Padding, + K1, + I1)); + + constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor( + b_block_desc_b_k0_n0_n1_k1, + make_tuple(make_pass_through_transform(Number<1>{}), + make_pass_through_transform(Number{}), + make_merge_transform_v4_no_carry( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return b_block_desc_b_k0_n_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + return b_block_desc_b_k0_n_k1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = math::integer_least_multiple( + a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = math::integer_least_multiple( + b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + // const bool has_main_k0_block_loop = K0 > K0PerBlock; + const index_t num_loop = K0 / K0PerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + + // return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N00_M01_N01( + c_m_n_grid_desc, M01, N01, KBatch); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t k_batch_id = block_work_idx[I0]; + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + constexpr index_t KPack = + math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared_block), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + static_assert(M1 == MWave, ""); + static_assert(N1 == NWave, ""); + static_assert(M2 * M3 * M4 == MPerXDL, ""); + static_assert(N2 == NPerXDL, ""); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; // namespace ck + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp new file mode 100644 index 00000000000..974455fa3b7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -0,0 +1,550 @@ +#pragma once + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp similarity index 68% rename from composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index f27fc73b3b5..a54906cfbc5 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -5,20 +5,22 @@ #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" namespace ck { -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE template __global__ void @@ -31,8 +33,12 @@ __global__ void const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor) { +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -45,63 +51,35 @@ __global__ void a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + a_element_op, + b_element_op, + c_element_op, c_block_cluster_adaptor); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_b_k0_m_k1_grid_desc, - const void CONSTANT* p_b_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const void CONSTANT* p_c_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - const auto a_b_k0_m_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_a_b_k0_m_k1_grid_desc)); - const auto b_b_k0_n_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_b_b_k0_n_k1_grid_desc)); - const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); - const auto c_block_cluster_adaptor = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); -} -#endif template + index_t CThreadTransferDstScalarPerVector> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 { static constexpr auto I0 = Number<0>{}; @@ -151,6 +121,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // K1 should be Number<...> static constexpr auto K1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { constexpr auto max_lds_align = K1; @@ -196,12 +168,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template __host__ __device__ static constexpr bool CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, const CMNGridDesc& c_m_n_grid_desc, - index_t M01, - index_t N01) + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -225,31 +197,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { return false; + } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr index_t - CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch; - - return grid_size; - } - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { const bool has_main_k0_block_loop = K0 > K0PerBlock; @@ -304,44 +260,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 NRepeat, K1>; - return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_m_n_grid_desc); } // return block_id to C matrix tile idx (m0, n0) mapping __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( - const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_pass_through_transform(KBatch), - make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); - - const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor); - - return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); } using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); @@ -355,13 +282,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, const CBlockClusterAdaptor& c_block_cluster_adaptor) { - const auto a_grid_buf = make_dynamic_buffer( + const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( + const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( + auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); @@ -370,7 +300,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 const auto block_work_idx = c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I0), + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I1)))) + { + return; + } + const index_t k_batch_id = block_work_idx[I0]; + // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); @@ -447,57 +386,63 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 }(); // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K0_M_K1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_b_k0_m_k1_grid_desc), - decltype(a_b_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - ABlockTransferSrcVectorDim, - 3, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( a_b_k0_m_k1_grid_desc, make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, a_b_k0_m_k1_block_desc, - make_multi_index(0, 0, 0, 0)); + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_b_k0_n_k1_grid_desc), - decltype(b_b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - BBlockTransferSrcVectorDim, - 3, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( b_b_k0_n_k1_grid_desc, make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, b_b_k0_n_k1_block_desc, - make_multi_index(0, 0, 0, 0)); + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -531,49 +476,38 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; - constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; - - auto a_block_buf = make_dynamic_buffer( + auto a_block_buf = make_dynamic_buffer( p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( + auto b_block_buf = make_dynamic_buffer( p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); // preload data into LDS { - a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); - b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); } + // Initialize C + c_thread_buf.Clear(); + // main body - index_t k_block_data_begin = 0; if constexpr(HasMainKBlockLoop) { + index_t k0_block_data_begin = 0; + do { - a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, - a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, - b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_step_hack); + a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step); - a_blockwise_copy.RunRead( - a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); block_sync_lds(); - b_blockwise_copy.RunRead( - b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); @@ -582,8 +516,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); - k_block_data_begin += K0PerBlock; - } while(k_block_data_begin < (K0 - K0PerBlock)); + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); } // tail @@ -596,7 +530,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // output: register to global memory { constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = - blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); @@ -622,8 +556,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), @@ -648,6 +580,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 FloatC, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc), + CElementwiseOperation, Sequence, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, @@ -664,14 +597,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 m_thread_data_on_grid_idx[I2], m_thread_data_on_grid_idx[I3], m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2])}; + n_thread_data_on_grid_idx[I2]), + c_element_op}; c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); + c_grid_buf); } } }; // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp new file mode 100644 index 00000000000..dbff1577e1f --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -0,0 +1,723 @@ +#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4R2_HPP +#define CK_GRIDWISE_GEMM_XDLOPS_V2R4R2_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r1.hpp" +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r4r2(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = K0 > K0PerBlock; + + return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t k_batch_id = block_work_idx[I0]; + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step); + + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared_block), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + static_assert(M1 == MWave, ""); + static_assert(N1 == NWave, ""); + static_assert(M2 * M3 * M4 == MPerXDL, ""); + static_assert(N2 == NPerXDL, ""); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; // namespace ck + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp new file mode 100644 index 00000000000..ffa82a75703 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -0,0 +1,716 @@ +#pragma once +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r1.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" +#include "tensor_space_filling_curve.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatCShuffle, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_AK0_M_AK1, + typename BGridDesc_BK0_N_BK1, + typename CGridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t AK1Value, + index_t BK1Value, + index_t MPerXdl, + index_t NPerXdl, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumGemmKPrefetchStage = 1> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + constexpr auto max_lds_align = AK1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(AK0, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + constexpr auto max_lds_align = BK1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(BK0, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t k_pack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_shuffle_block_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp new file mode 100644 index 00000000000..3a7a551181b --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -0,0 +1,758 @@ +#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R2_HPP +#define CK_GRIDWISE_GEMM_XDLOPS_V3R2_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r2.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename C0GridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerXdl, + index_t NPerXdl, + index_t K1Value, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumGemmKPrefetchStage = 1> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_c0_grid, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{}) + + ); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename Src0Data, + FloatC, // typename Src1Data, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_block_buf, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp new file mode 100644 index 00000000000..745dfde0ba3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -0,0 +1,794 @@ +#pragma once +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r3.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_c1_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_grid; + ignore = p_c1_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename C0GridDesc_M_N, + typename C1GridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerXdl, + index_t NPerXdl, + index_t K1Value, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumGemmKPrefetchStage = 1> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_c0_grid, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c1_grid_buf = make_dynamic_buffer( + p_c1_grid, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple(make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per + // shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per + // shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{}) + + ); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename Src0Data, + FloatC, // typename Src1Data, + FloatC, // typename Src2Data, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc2ResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_block_buf, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_buf, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp new file mode 100644 index 00000000000..6d95aec9384 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp @@ -0,0 +1,80 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_GRIDWISE_SET_BUFFER_VALUE_HPP +#define CK_GRIDWISE_SET_BUFFER_VALUE_HPP + +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffer_desc, + DataType* const __restrict__ p_global, + DataType value) + +{ + + using PassThroughOp = tensor_operation::element_wise::UnaryIdentic; + + constexpr auto I0 = Number<0>{}; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const index_t thread_global_id = block_global_id * BlockSize + thread_local_id; + + StaticBuffer value_buf; + + value_buf(I0) = value; + + constexpr auto val_buff_desc = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); + + auto global_buf = make_dynamic_buffer( + p_global, grid_1d_buffer_desc.GetElementSpaceSize()); + + if(thread_global_id < grid_1d_buffer_desc.GetElementSize()) + { + auto threadwise_store = ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + grid_1d_buffer_desc, make_multi_index(thread_global_id), PassThroughOp{}); + + threadwise_store.Run( + val_buff_desc, make_tuple(I0), value_buf, grid_1d_buffer_desc, global_buf); + } +}; + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp new file mode 100644 index 00000000000..3dcfe3a0309 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp @@ -0,0 +1,122 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_REDUCTION_FUNCTIONS_THREADWISE_HPP +#define CK_REDUCTION_FUNCTIONS_THREADWISE_HPP + +#include "reduction_functions_accumulate.hpp" + +namespace ck { + +// Assume +// 1) SrcDesc is known at compile-time +// 2) DstDesc is known at compile-time +// 3) SrcBuffer is static buffer +// 4) DstBuffer is static buffer +template +struct ThreadwiseReduction +{ + static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{}; + static constexpr auto dst_thread_desc_m = DstThreadDesc_M{}; + + static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{}); + static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{}); + static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{}); + + static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); + + using Accumulation = detail::AccumulateWithNanCheck; + + template + __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf) + { + static_for<0, src_length_m, 1>{}([&](auto iM) { + constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM)); + + static_for<0, src_length_k, 1>{}([&](auto iK) { + constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + Accumulation::Calculate(dst_buf(Number{}), src_buf[Number{}]); + }); + }); + }; +}; + +// Assume +// 1) SrcDesc is known at compile-time +// 2) DstDesc is known at compile-time +// 3) SrcBuffer is static buffer +// 4) DstBuffer is static buffer +template +struct ThreadwiseReductionWithIndex +{ + static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{}; + static constexpr auto dst_thread_desc_m = DstThreadDesc_M{}; + + static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{}); + static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{}); + static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{}); + + static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); + + using Accumulation = + detail::AccumulateWithIndexAndNanCheck; + + template + __device__ static void Reduce(const SrcValueBufferType& src_val_buf, + const SrcIndexBufferType& src_idx_buf, + DstValueBufferType& dst_val_buf, + DstIndexBufferType& dst_idx_buf) + { + static_for<0, src_length_m, 1>{}([&](auto iM) { + constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM)); + + static_for<0, src_length_k, 1>{}([&](auto iK) { + constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + Accumulation::Calculate(dst_val_buf(Number{}), + src_val_buf[Number{}], + dst_idx_buf(Number{}), + src_idx_buf[Number{}]); + }); + }); + }; +}; + +}; // end of namespace ck + +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp similarity index 96% rename from composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp rename to include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp index 8b753810268..6a532c79f9f 100644 --- a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp @@ -1,6 +1,4 @@ -#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP -#define CK_THREADWISE_CONTRACTION_DLOPS_HPP - +#pragma once #include "common_header.hpp" #include "math.hpp" @@ -25,9 +23,9 @@ template ::type = false> -struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 +struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1 { - __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() + __device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1() { static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && @@ -124,9 +122,9 @@ template ::type = false> -struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 +struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 { - __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() + __device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() { static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && @@ -220,4 +218,3 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp new file mode 100644 index 00000000000..360b115015a --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp @@ -0,0 +1,165 @@ +#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP +#define CK_THREADWISE_GEMM_DLOPS_V3_HPP + +#include "common_header.hpp" +#include "math.hpp" + +namespace ck { + +// C[M, N] += transpose(A[K, M]) * B[K, N] +// Element of matrix can be vectorized data +// Assume: +// 1. AThreadDesc_E1_K_E2, BThreadDesc_E1_N_Ho_Wo_E2, CThreadDesc_K_N_Ho_Wo are known at +// compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseGemmDlops_km_kn_mn_v3 +{ + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + + static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() && + BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() && + CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0); + constexpr auto K = AThreadDesc_E1_K_E2{}.GetLength(I1); + constexpr auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2); + + constexpr auto Ho = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2); + constexpr auto Wo = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3); + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + if constexpr((Ho % 2 == 0) && (Wo % 2 == 0)) + { + constexpr auto SubHW = 2; + + static_for<0, K, 1>{}([&](auto k) { + static_for<0, Ho, SubHW>{}([&](auto h) { + static_for<0, Wo, SubHW>{}([&](auto w) { + static_for<0, E1, 1>{}([&](auto e1) { + static_for<0, E2, 1>{}([&](auto e2) { + constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset( + a_origin_idx + make_tuple(e1, k, e2)); + + constexpr index_t b0_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w, e2)); + + constexpr index_t b1_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w + 1, e2)); + + constexpr index_t b2_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h + 1, w, e2)); + + constexpr index_t b3_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2)); + + constexpr index_t c0_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + + make_tuple(k, 0, h, w)); + + constexpr index_t c1_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h, w + 1)); + + constexpr index_t c2_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h + 1, w)); + + constexpr index_t c3_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h + 1, w + 1)); + + amd_assembly_outer_product_1x4(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{})); + }); + }); + }); + }); + }); + } + else + { + + static_for<0, K, 1>{}([&](auto k) { + static_for<0, Ho, 1>{}([&](auto h) { + static_for<0, Wo, 1>{}([&](auto w) { + static_for<0, E1, 1>{}([&](auto e1) { + static_for<0, E2, 1>{}([&](auto e2) { + constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset( + a_origin_idx + make_tuple(e1, k, e2)); + + constexpr index_t b_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w, e2)); + + constexpr index_t c_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + + make_tuple(k, 0, h, w)); + + inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); + }); + }); + }); + }); + }); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp similarity index 100% rename from composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp rename to include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp similarity index 73% rename from composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp rename to include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 7e3f6b3489a..7a75ca53808 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -4,6 +4,7 @@ #include "common_header.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" +#include "tensor_space_filling_curve.hpp" namespace ck { @@ -50,11 +51,12 @@ template ::type = false> @@ -69,11 +71,15 @@ struct ThreadwiseTensorSliceTransfer_v1r3 using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc, - const Index& dst_slice_origin_idx) - : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)) + const Index& dst_slice_origin_idx, + const ElementwiseOperation& element_op) + : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)), + element_op_{element_op} { static_assert(SrcDesc::IsKnownAtCompileTime(), "wrong! SrcDesc need to known at compile-time"); + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); } __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) @@ -81,16 +87,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); } - template + template __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, const SrcBuffer& src_buf, const DstDesc& dst_desc, - DstBuffer& dst_buf, - const DstStepHacks& dst_step_hacks) + DstBuffer& dst_buf) { static_assert(SrcDesc::IsKnownAtCompileTime(), "wrong! SrcDesc need to known at compile-time"); @@ -104,9 +106,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr auto src_desc = remove_cvref_t{}; constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( @@ -115,141 +114,52 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr auto dst_scalar_step_in_vector = generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); + using SpaceFillingCurve = SpaceFillingCurve>; - // make forward steps - const auto dst_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; + // TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector? + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + typename vector_type_maker::type dst_vector; + using dst_vector_t = typename vector_type_maker::type::type; - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - return make_tensor_coordinate_step( - dst_desc, forward_step_idx, dst_step_hacks[I0][i]); - }, - Number{}); - - // make backward steps - const auto dst_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step( - dst_desc, backward_step_idx, dst_step_hacks[I1][i]); - }, - Number{}); - - // loop over tensor and copy - static_ford{}([&](auto ordered_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_access_idx[I0]; - - static_for<0, i, 1>{}([&](auto j) { - tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] - ? ordered_access_idx[i] - : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dim_access_order) * - dst_scalar_per_access; - }(); - - typename vector_type_maker::type dst_vector; - - using dst_vector_t = - typename vector_type_maker::type::type; + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); // copy data from src_buf into dst_vector + // TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve? static_for<0, DstScalarPerVector, 1>{}([&](auto i) { constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector); + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v; - dst_vector.template AsType()(i) = - type_convert{}(src_buf[Number{}]); + // apply element-wise operation + element_op_(v, src_buf[Number{}]); + + // apply type convert + dst_vector.template AsType()(i) = type_convert(v); }); const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); // copy data from dst_vector into dst_buf - if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) - { - dst_buf.template Set( - dst_coord_.GetOffset(), - is_dst_valid, - dst_vector.template AsType()[Number<0>{}]); - } - else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd) - { - dst_buf.template AtomicAdd( - dst_coord_.GetOffset(), - is_dst_valid, - dst_vector.template AsType()[Number<0>{}]); - } + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); - constexpr auto move_on_dim = [&]() constexpr + if constexpr(idx_1d.value != num_access - 1) { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; - }); - }); + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); - return move_on_dim_; + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); } - (); - - // move - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]); - } - } - }); }); // move dst coordinate back to slice origin (or not) @@ -262,82 +172,27 @@ struct ThreadwiseTensorSliceTransfer_v1r3 } } - template - __device__ void Run(const SrcDesc&, - const SrcSliceOriginIdx&, - const SrcBuffer& src_buf, - const DstDesc& dst_desc, - DstBuffer& dst_buf) - { - constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); - - constexpr auto zeros = typename uniform_sequence_gen::type{}; - - constexpr auto dst_step_hacks = - make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), - generate_tuple([&](auto) { return zeros; }, Number{})); - - Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks); - } - __device__ static constexpr auto GetDstCoordinateResetStep() { - constexpr auto I0 = Number<0>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); + using SpaceFillingCurve = SpaceFillingCurve>; - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_access_lengths[I0] - 1; - - static_for<0, i, 1>{}([&](auto j) { - tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index after last iteration in Run(), if it has not being reset by - // RunWrite() - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dim_access_order) * - dst_scalar_per_access; - }(); - - // - constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - - return reset_dst_data_step_; - }(); + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - return reset_dst_data_step; + return reset_step; + } } // dst_slice_origin_step_idx need to be known at compile-time, for performance reason @@ -357,7 +212,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3 private: DstCoord dst_coord_; -}; // namespace ck + const ElementwiseOperation element_op_; +}; // namespace ThreadwiseTensorSliceTransfer_v1r3 // Assume: // 1. src: @@ -395,6 +251,8 @@ struct ThreadwiseTensorSliceTransfer_v2 { static_assert(DstDesc::IsKnownAtCompileTime(), "wrong! SrcDesc need to known at compile-time"); + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! Not divisible"); } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -402,16 +260,12 @@ struct ThreadwiseTensorSliceTransfer_v2 src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); } - template + template __device__ void Run(const SrcDesc& src_desc, const SrcBuffer& src_buf, const DstDesc&, const DstSliceOriginIdx&, - DstBuffer& dst_buf, - const SrcStepHacks& src_step_hacks) + DstBuffer& dst_buf) { static_assert(DstDesc::IsKnownAtCompileTime(), "wrong! DstDesc need to known at compile-time"); @@ -427,9 +281,6 @@ struct ThreadwiseTensorSliceTransfer_v2 constexpr auto dst_desc = remove_cvref_t{}; constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( @@ -438,80 +289,19 @@ struct ThreadwiseTensorSliceTransfer_v2 constexpr auto src_scalar_step_in_vector = generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); - - // make forward steps - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step( - src_desc, forward_step_idx, src_step_hacks[I0][i]); - }, - Number{}); - - // make backward steps - const auto src_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step( - src_desc, backward_step_idx, src_step_hacks[I1][i]); - }, - Number{}); + using SpaceFillingCurve = SpaceFillingCurve>; // loop over tensor and copy - static_ford{}([&](auto ordered_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_access_idx[I0]; - - static_for<0, i, 1>{}([&](auto j) { - tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] - ? ordered_access_idx[i] - : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dim_access_order) * - src_scalar_per_access; - }(); + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + static_for<0, num_access, 1>{}([&](auto idx_1d) { typename vector_type_maker::type src_vector; using src_vector_t = typename vector_type_maker::type::type; + constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d); const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); @@ -526,41 +316,17 @@ struct ThreadwiseTensorSliceTransfer_v2 dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + i * src_scalar_step_in_vector); - dst_buf(Number{}) = src_vector.template AsType()[i]; + dst_buf(Number{}) = + type_convert(src_vector.template AsType()[i]); }); - constexpr auto move_on_dim = [&]() constexpr + if constexpr(idx_1d.value != num_access - 1) { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); - static_for{}([&](auto j) { - move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); } - (); - - // move - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - src_desc, src_coord_, src_forward_steps[dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - src_desc, src_coord_, src_backward_steps[dim_access_order[i]]); - } - } - }); }); // move src coordinate back to slice origin (or not) @@ -573,82 +339,27 @@ struct ThreadwiseTensorSliceTransfer_v2 } } - template - __device__ void Run(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const DstDesc&, - const DstSliceOriginIdx&, - DstBuffer& dst_buf) - { - constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); - - constexpr auto zeros = typename uniform_sequence_gen::type{}; - - constexpr auto src_step_hacks = - make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), - generate_tuple([&](auto) { return zeros; }, Number{})); - - Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks); - } - __device__ static constexpr auto GetSrcCoordinateResetStep() { - constexpr auto I0 = Number<0>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_access_lengths[I0] - 1; - - static_for<0, i, 1>{}([&](auto j) { - tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); + using SpaceFillingCurve = SpaceFillingCurve>; - // calculate src data index after last iteration in Run(), if it has not being reset by - // RunWrite() - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dim_access_order) * - src_scalar_per_access; - }(); - - // - constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); - - return reset_src_data_step_; - }(); + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - return reset_src_data_step; + return reset_step; + } } // dst_slice_origin_step_idx need to be known at compile-time, for performance reason @@ -666,6 +377,25 @@ struct ThreadwiseTensorSliceTransfer_v2 move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + private: SrcCoord src_coord_; }; // namespace ck @@ -676,7 +406,7 @@ struct ThreadwiseTensorSliceTransfer_v2 // 3. src_slice_origin and dst_slice_origin are not known at compile-time, // 4. Use thread buffer template {}) % SrcScalarPerVector == 0, + "wrong! Not divisible"); + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -729,8 +463,8 @@ struct ThreadwiseTensorSliceTransfer_v3 __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) { - static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or - SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, "wrong!"); static_assert( @@ -794,7 +528,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_src_access_idx[I0]; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; }); @@ -886,8 +620,8 @@ struct ThreadwiseTensorSliceTransfer_v3 __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) { - static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or - DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, "wrong!"); static_assert( @@ -951,7 +685,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_dst_access_idx[I0]; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; }); @@ -983,7 +717,7 @@ struct ThreadwiseTensorSliceTransfer_v3 buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector); dst_tmp_vector.template AsType()(i) = - type_convert{}(buffer_[Number{}]); + type_convert(buffer_[Number{}]); }); using dst_vector_t = typename decltype(dst_tmp_vector)::type; @@ -1095,7 +829,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_src_access_lengths[I0] - 1; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; }); @@ -1155,7 +889,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_dst_access_lengths[I0] - 1; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; }); @@ -1244,7 +978,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); - StaticBuffer buffer_; + StaticBuffer buffer_; SrcCoord src_coord_; DstCoord dst_coord_; @@ -1290,7 +1024,8 @@ struct ThreadwiseTensorSliceTransfer_v4 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! SrcDesc and DstDesc need to known at compile-time"); - static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, "wrong!"); + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! Not divisible"); } template {}([&](auto i) { dst_tmp_vector.template AsType()(i) = - type_convert{}(src_tmp_vector.template AsType()[i]); + type_convert(src_tmp_vector.template AsType()[i]); }); // copy data from dst_tmp_vector into dst_buf diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp new file mode 100644 index 00000000000..4cd41ddb30d --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -0,0 +1,792 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R1_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R1_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "static_tensor.hpp" + +namespace ck { + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template +struct ThreadwiseTensorSliceTransfer_v3r1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + src_element_op_(src_element_op), + dst_element_op_(dst_element_op) + { + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + // apply SrcElementwiseOperation on src_vector_container + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + SrcData src_v; + + src_element_op_(src_v, src_vector_container.template AsType()[i]); + + src_vector_container.template AsType()(i) = src_v; + }); + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + }); +#else + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + // TODO type_convert is not used yet!!!!! + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + // TODO type_convert is not used yet!!!!! + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + }); + } +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + DstData dst_v; + + // apply DstElementwiseOperation + dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); + + dst_vector_container.template AsType()(i) = dst_v; + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + private: + static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; + + using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + + using DstThreadScratch = StaticTensorTupleOfVectorBuffer; + + StaticallyIndexedArray src_thread_scratch_tuple_; + + DstThreadScratch dst_thread_scratch_; + + SrcCoord src_coord_; + DstCoord dst_coord_; + const SrcElementwiseOperation src_element_op_; + const DstElementwiseOperation dst_element_op_; +}; + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp new file mode 100644 index 00000000000..1447f06f022 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp @@ -0,0 +1,883 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "static_tensor.hpp" + +namespace ck { + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseTensorSliceTransfer_v3r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{})); + using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{})); + using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r3( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Dst0Desc& dst0_desc, + const Dst1Desc& dst1_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin)), + dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin)), + src_element_op_(src_element_op), + dst_element_op_(dst_element_op) + { + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, + const Dst0Desc& dst0_desc, + const Dst1Desc& dst1_desc, + const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + dst0_coord_ = make_tensor_coordinate(dst0_desc, dst_slice_origin_idx); + dst1_coord_ = make_tensor_coordinate(dst1_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + // apply SrcElementwiseOperation on src_vector_container + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + src_vector_container.template AsType()(i) = + src_element_op_(src_vector_container.template AsType()[i]); + }); + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_.template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + __device__ void TransferDataFromSrcThreadScratchToDstThreadScratch() + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); + }); +#else + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + is_same>::value && + is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + // TODO type_convert is not used yet!!!!! + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + // TODO type_convert is not used yet!!!!! + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); + }); + } +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + const Dst0Desc& dst0_desc, + const Dst0Buffer& dst0_buf, + const Dst1Desc& dst1_desc, + const Dst1Buffer& dst1_buf) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make forward steps: dst0 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst0_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst0_desc, forward_step_idx); + }, + Number{}); + + // make forward steps: dst1 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst1_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst1_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // make backward steps: dst0 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst0_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst0_desc, backward_step_idx); + }, + Number{}); + + // make backward steps: dst1 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst1_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst1_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + // apply DstElementwiseOperation on dst_vector_container + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + dst_vector_container.template AsType()(i) = + dst_element_op_(dst_vector_container.template AsType()[i]); + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + // TODO: BUG: should start at 1 + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Dst0Desc dst0_desc, + const Dst1Desc dst1_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + move_tensor_coordinate(dst0_desc, dst0_coord_, adjusted_step); + move_tensor_coordinate(dst1_desc, dst1_coord_, adjusted_step); + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + private: + static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; + + StaticTensorTupleOfVectorBuffer + src_thread_scratch_; + + StaticTensorTupleOfVectorBuffer + dst_thread_scratch_; + + SrcCoord src_coord_; + DstCoord dst_coord_; + const SrcElementwiseOperation src_element_op_; + const DstElementwiseOperation dst_element_op_; +}; + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp new file mode 100644 index 00000000000..2504c928567 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp @@ -0,0 +1,174 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_ref_idx is known at run-time +// 4. SrcRefToOriginDisplacement is known at compile-time +// 5. use #-step +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. DstOriginIdx is known at compile-time +// 4. use direct address calculation +// 3. vector access on src +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v4r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx) + : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_for<0, nDim, 1>{}([](auto i) { + static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!"); + }); + } + + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = + make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * + src_vector_tensor_lengths; + + // src coordinate at starting point of src_vector + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); + + vector_type_maker_t src_vector; + + using src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + + // copy data from src_vector into dst_buf (also cast from SrcData to DstData) + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); + + dst_buf(Number{}) = type_convert( + src_vector.template AsType()[Number{}]); + }); + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); + + move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + + private: + SrcCoord src_ref_coord_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp similarity index 75% rename from composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp rename to include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index bbdaa5fa2bc..f0e9c7e7614 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -1,5 +1,4 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP +#pragma once #include "common_header.hpp" #include "tensor_descriptor.hpp" @@ -13,7 +12,7 @@ namespace ck { // 3. src_slice_origin and dst_slice_origin are not known at compile-time, // 4. Use thread buffer template // control whether to move back dst coordinate after each // RunWrite(), will be fused with MoveDstSliceWindow to // save addr computation -struct ThreadwiseTensorSliceTransfer_v3r1 +struct ThreadwiseTensorSliceTransfer_v5r1 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -44,7 +43,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc, + __device__ constexpr ThreadwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc, const Index& src_slice_origin, const DstDesc& dst_desc, const Index& dst_slice_origin) @@ -76,8 +75,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) { - static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or - SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, "wrong!"); static_assert( @@ -244,8 +243,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) { - static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or - DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, "wrong!"); static_assert( @@ -351,7 +350,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_vector_desc.CalculateOffset(dst_vector_idx); dst_vector.template AsType()(Number{}) = - type_convert{}(buffer_[Number{}]); + type_convert(buffer_[Number{}]); }); using dst_vector_t = typename decltype(dst_vector)::type; @@ -602,175 +601,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1 static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); - StaticBuffer buffer_; + StaticBuffer buffer_; SrcCoord src_coord_; DstCoord dst_coord_; }; -// Assume: -// 1. src: -// 1. SrcDesc is known at compile-time -// 2. SrcBuffer is DynamicBuffer -// 3. src_ref_idx is known at run-time -// 4. SrcRefToOriginDisplacement is known at compile-time -// 5. use #-step -// 2. dst: -// 1. DstDesc is known at compile-time -// 2. DstBuffer is StaticBuffer -// 3. DstOriginIdx is known at compile-time -// 4. use direct address calculation -// 3. vector access on src -template ::type = false> -struct ThreadwiseTensorSliceTransfer_v4r1 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static constexpr index_t nDim = SliceLengths::Size(); - - using Index = MultiIndex; - - using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - - __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx) - : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc and DstDesc need to known at compile-time"); - - static_for<0, nDim, 1>{}([](auto i) { - static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!"); - }); - } - - template - __device__ void Run(const SrcDesc&, - const SrcRefToOriginDisplacement&, - const SrcBuffer& src_buf, - const DstDesc&, - const DstOriginIdx&, - DstBuffer& dst_buf) const - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc and DstDesc need to known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); - - static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " - "at compile-time"); - - // SrcDesc and DstDesc are known at compile-time - constexpr auto src_desc = remove_cvref_t{}; - constexpr auto dst_desc = remove_cvref_t{}; - - // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time - constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); - constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); - - // tensor descriptor for src_vector - constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; - - constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( - container_reverse_exclusive_scan( - container_reorder_given_new2old(src_vector_tensor_lengths, - SrcVectorTensorContiguousDimOrder{}), - math::multiplies{}, - I1), - SrcVectorTensorContiguousDimOrder{}); - - constexpr auto src_vector_desc = - make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths), - sequence_to_tuple_of_number(src_vector_tensor_strides)); - - // access order and lengths - constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); - - static_ford{}([&](auto ordered_access_idx) { - // position in slice window - constexpr auto data_to_origin_disp_idx = - ordered_access_idx.ReorderGivenOld2New(dim_access_order) * - src_vector_tensor_lengths; - - // src coordinate at starting point of src_vector - constexpr auto src_ref_to_data_disp_idx = - src_ref_to_origin_disp_idx + data_to_origin_disp_idx; - - constexpr auto src_ref_to_data_disp_coord_step = - make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); - - auto src_data_coord = src_ref_coord_; - - move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); - - vector_type_maker_t src_vector; - - using src_vector_t = typename decltype(src_vector)::type; - - const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( - src_desc, src_data_coord); - - // copy data from src_buf into src_vector - src_vector.template AsType()(I0) = - src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); - - // copy data from src_vector into dst_buf (also cast from SrcData to DstData) - static_ford{}([&](auto src_vector_idx_) { - constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); - - constexpr index_t src_vector_offset = - src_vector_desc.CalculateOffset(src_vector_idx); - - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); - - dst_buf(Number{}) = type_convert{}( - src_vector.template AsType()[Number{}]); - }); - }); - } - - template - __device__ void MoveSrcSliceWindow(const SrcDesc&, - const SrcSliceMoveStepIdx& src_slice_move_step_idx) - { - constexpr auto src_desc = SrcDesc{}; - - const auto src_slice_move_step_iter = - make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); - - move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); - } - - private: - SrcCoord src_ref_coord_; -}; - } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp new file mode 100644 index 00000000000..042bc95f55e --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -0,0 +1,212 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over space-filling curve + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + SrcData v; + + // apply element-wise operation + element_op_(v, src_vector_container.template AsType()[i]); + + // apply type convert + dst_vector_container.template AsType()(i) = type_convert(v); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRun + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; // namespace ck + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp new file mode 100644 index 00000000000..ae85ba91e58 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp @@ -0,0 +1,259 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src0_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); + using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, + const Index& src0_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src0_coord_(make_tensor_coordinate(src0_desc, src0_slice_origin)), + src1_coord_(make_tensor_coordinate(src1_desc, src1_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrc0SliceOrigin(const Src0Desc& src0_desc, + const Index& src0_slice_origin_idx) + { + src0_coord_ = make_tensor_coordinate(src0_desc, src0_slice_origin_idx); + } + + __device__ void SetSrc1SliceOrigin(const Src1Desc& src1_desc, + const Index& src1_slice_origin_idx) + { + src1_coord_ = make_tensor_coordinate(src1_desc, src1_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src0_vector_type = vector_type_maker_t; + using src0_vector_t = typename src0_vector_type::type; + + using src1_vector_type = vector_type_maker_t; + using src1_vector_t = typename src1_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src0_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src0_desc, src0_coord_); + + const bool is_src1_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src1_desc, src1_coord_); + + // copy data from src0_buf into src0_vector_container + auto src0_vector_container = src0_vector_type{ + src0_buf.template Get(src0_coord_.GetOffset(), is_src0_valid)}; + + auto src1_vector_container = src1_vector_type{ + src1_buf.template Get(src1_coord_.GetOffset(), is_src1_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + element_op_(dst_vector_container.template AsType()(i), + src0_vector_container.template AsType()[i], + src1_vector_container.template AsType()[i]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step)); + move_tensor_coordinate( + src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(Src0ResetCoordinateAfterRun) + { + const auto src0_reset_step = + make_tensor_coordinate_step(src0_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src0_desc, src0_coord_, src0_reset_step); + } + + if constexpr(Src1ResetCoordinateAfterRun) + { + const auto src1_reset_step = + make_tensor_coordinate_step(src1_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src1_desc, src1_coord_, src1_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, + const Index& src0_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src0ResetCoordinateAfterRun + ? src0_slice_origin_step_idx + : src0_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); + + move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, + const Index& src1_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src1ResetCoordinateAfterRun + ? src1_slice_origin_step_idx + : src1_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); + + move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + Src0Coord src0_coord_; + Src1Coord src1_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp new file mode 100644 index 00000000000..47024d5e688 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp @@ -0,0 +1,309 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src0_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); + using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); + using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, + const Index& src0_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_slice_origin, + const Src2Desc& src2_desc, + const Index& src2_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src0_coord_(make_tensor_coordinate(src0_desc, src0_slice_origin)), + src1_coord_(make_tensor_coordinate(src1_desc, src1_slice_origin)), + src2_coord_(make_tensor_coordinate(src2_desc, src2_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrc0SliceOrigin(const Src0Desc& src0_desc, + const Index& src0_slice_origin_idx) + { + src0_coord_ = make_tensor_coordinate(src0_desc, src0_slice_origin_idx); + } + + __device__ void SetSrc1SliceOrigin(const Src1Desc& src1_desc, + const Index& src1_slice_origin_idx) + { + src1_coord_ = make_tensor_coordinate(src1_desc, src1_slice_origin_idx); + } + + __device__ void SetSrc2SliceOrigin(const Src2Desc& src2_desc, + const Index& src2_slice_origin_idx) + { + src2_coord_ = make_tensor_coordinate(src2_desc, src2_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const Src2Desc& src2_desc, + const Src2Buffer& src2_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src0_vector_type = vector_type_maker_t; + using src0_vector_t = typename src0_vector_type::type; + + using src1_vector_type = vector_type_maker_t; + using src1_vector_t = typename src1_vector_type::type; + + using src2_vector_type = vector_type_maker_t; + using src2_vector_t = typename src2_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src0_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src0_desc, src0_coord_); + + const bool is_src1_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src1_desc, src1_coord_); + + const bool is_src2_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src2_desc, src2_coord_); + + // copy data from src0_buf into src0_vector_container + auto src0_vector_container = src0_vector_type{ + src0_buf.template Get(src0_coord_.GetOffset(), is_src0_valid)}; + + auto src1_vector_container = src1_vector_type{ + src1_buf.template Get(src1_coord_.GetOffset(), is_src1_valid)}; + + auto src2_vector_container = src2_vector_type{ + src2_buf.template Get(src2_coord_.GetOffset(), is_src2_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + element_op_(dst_vector_container.template AsType()(i), + src0_vector_container.template AsType()[i], + src1_vector_container.template AsType()[i], + src2_vector_container.template AsType()[i]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step)); + move_tensor_coordinate( + src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step)); + move_tensor_coordinate( + src2_desc, src2_coord_, make_tensor_coordinate_step(src2_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(Src0ResetCoordinateAfterRun) + { + const auto src0_reset_step = + make_tensor_coordinate_step(src0_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src0_desc, src0_coord_, src0_reset_step); + } + + if constexpr(Src1ResetCoordinateAfterRun) + { + const auto src1_reset_step = + make_tensor_coordinate_step(src1_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src1_desc, src1_coord_, src1_reset_step); + } + + if constexpr(Src2ResetCoordinateAfterRun) + { + const auto src2_reset_step = + make_tensor_coordinate_step(src2_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src2_desc, src2_coord_, src2_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, + const Index& src0_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src0ResetCoordinateAfterRun + ? src0_slice_origin_step_idx + : src0_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); + + move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, + const Index& src1_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src1ResetCoordinateAfterRun + ? src1_slice_origin_step_idx + : src1_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); + + move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, + const Index& src2_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src2ResetCoordinateAfterRun + ? src2_slice_origin_step_idx + : src2_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src2_desc, adjusted_step_idx); + + move_tensor_coordinate(src2_desc, src2_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + Src0Coord src0_coord_; + Src1Coord src1_coord_; + Src2Coord src2_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp similarity index 76% rename from composable_kernel/include/tensor_operation/xdlops_gemm.hpp rename to include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 10633f8f328..9d72abb72ea 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -12,18 +12,19 @@ enum struct MfmaInstr mfma_f32_32x32x1xf32 = 0, mfma_f32_16x16x1xf32, mfma_f32_4x4x1xf32, - mfma_f32_32x32x2xf32, // k reduction - mfma_f32_16x16x4xf32, // k reduction + mfma_f32_32x32x2xf32, + mfma_f32_16x16x4xf32, mfma_f32_32x32x4f16, mfma_f32_16x16x4f16, mfma_f32_4x4x4f16, - mfma_f32_32x32x8f16, // k reduction - mfma_f32_16x16x16f16, // k reduction - mfma_f32_32x32x2bf16, - mfma_f32_16x16x2bf16, - mfma_f32_4x4x2bf16, - mfma_f32_32x32x4bf16, // k reduction - mfma_f32_16x16x8bf16, // k reduction + mfma_f32_32x32x8f16, + mfma_f32_16x16x16f16, + mfma_f32_32x32x8bf16_1k, + mfma_f32_16x16x16bf16_1k, + mfma_f32_32x32x4bf16, + mfma_f32_16x16x8bf16, + mfma_i32_32x32x8i8, + mfma_i32_16x16x16i8, }; template @@ -250,9 +251,8 @@ struct mfma_type } }; -#if 0 template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 4; @@ -260,26 +260,38 @@ struct mfma_type static constexpr index_t num_threads_per_blk = 32; static constexpr index_t wave_size = 64; static constexpr index_t num_input_blks = 2; - static constexpr index_t num_output_blks = 2; + static constexpr index_t num_output_blks = 1; static constexpr index_t m_per_blk = 32; static constexpr index_t n_per_blk = 32; - static constexpr index_t k_per_blk = 2; - static constexpr bool is_k_reduction = false; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); + intrin_mfma_f32_32x32x8bf16_1k::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; - return intrin_mfma_f32_32x32x2bf16::run( - p_a, p_b, reg_c); + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x16bf16_1k::Run(a, b, reg_c); } }; @@ -298,19 +310,10 @@ struct mfma_type static constexpr index_t k_per_blk = 2; static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); - - return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c); + intrin_mfma_f32_32x32x4bf16::Run(a, b, reg_c); } }; @@ -329,84 +332,56 @@ struct mfma_type static constexpr index_t k_per_blk = 2; static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); - - return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c); + intrin_mfma_f32_16x16x8bf16::Run(a, b, reg_c); } }; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; - static constexpr index_t num_groups_per_blk = 1; - static constexpr index_t num_regs_per_blk = 4; - static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 4; - static constexpr index_t num_output_blks = 4; - static constexpr index_t m_per_blk = 16; - static constexpr index_t n_per_blk = 16; - static constexpr index_t k_per_blk = 2; - static constexpr bool is_k_reduction = false; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); - - return intrin_mfma_f32_16x16x2bf16(p_a, p_b, reg_c); + intrin_mfma_i32_32x32x8i8::Run(a, b, reg_c); } }; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 1; static constexpr index_t num_regs_per_blk = 4; - static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t num_threads_per_blk = 16; static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; + static constexpr index_t num_input_blks = 4; static constexpr index_t num_output_blks = 1; - static constexpr index_t m_per_blk = 4; - static constexpr index_t n_per_blk = 64; - static constexpr index_t k_per_blk = 2; - static constexpr bool is_k_reduction = false; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); - - return intrin_mfma_f32_4x4x2bf16::run(p_a, p_b, reg_c); + intrin_mfma_i32_16x16x16i8::Run(a, b, reg_c); } }; -#endif template struct MfmaSelector @@ -498,77 +473,41 @@ struct MfmaSelector return MfmaInstr::mfma_f32_4x4x4f16; } -#if 0 - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - template <> - static constexpr auto GetMfma() + static constexpr auto GetMfma() { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; +#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_32x32x8bf16_1k; +#else + return MfmaInstr::mfma_f32_32x32x4bf16; +#endif } template <> - static constexpr auto GetMfma() + static constexpr auto GetMfma() { - return xdlops_info{}; +#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_16x16x16bf16_1k; +#else + return MfmaInstr::mfma_f32_16x16x8bf16; +#endif } template <> - static constexpr auto GetMfma() + static constexpr auto GetMfma() { - return xdlops_info{}; + return MfmaInstr::mfma_i32_32x32x8i8; } template <> - static constexpr auto GetMfma() + static constexpr auto GetMfma() { - return xdlops_info{}; + return MfmaInstr::mfma_i32_16x16x16i8; } -#endif static constexpr auto selected_mfma = mfma_type()>{}; - __host__ __device__ static constexpr void mfma_check() + __host__ __device__ constexpr MfmaSelector() { static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk == selected_mfma.num_regs_per_blk, @@ -594,8 +533,6 @@ struct MfmaSelector "is_k_reduction wrong!"); } - __host__ __device__ constexpr MfmaSelector() { mfma_check(); } - static constexpr bool IsABroadcast() { static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast"); @@ -608,7 +545,7 @@ struct MfmaSelector selected_mfma.k_per_blk; } - static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; } + static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; } }; template @@ -644,17 +581,17 @@ struct XdlopsGemm static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); } - template + template __host__ __device__ static constexpr auto - MakeCM0N0M1N1M2M3M4N2Descriptor(const CM0N0M1N1M2N2Desc& c_m0_n0_m1_n1_m2_n2_desc) + MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) { - const auto M0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I0); - const auto N0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I1); - const auto M1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I2); - const auto N1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I3); + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); return transform_tensor_descriptor( - c_m0_n0_m1_n1_m2_n2_desc, + c_desc_m0_n0_m1_n1_m2_n2, make_tuple(make_pass_through_transform(M0), make_pass_through_transform(N0), make_pass_through_transform(M1), @@ -677,17 +614,56 @@ struct XdlopsGemm Sequence<7>{})); } + template + __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) + { + const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_g_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(G), + make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, + mfma_instr.num_input_blks, + mfma_instr.group_size)), + make_pass_through_transform(mfma_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{}, + Sequence<8>{})); + } + __device__ static constexpr index_t GetRegSizePerXdlops() { return MPerXdlops * NPerXdlops / mfma_instr.wave_size; } + __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; } + template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { static_assert(is_same::value || is_same::value || - is_same::value, - "base base_type must be float, half, ushort!"); + is_same::value || is_same::value, + "base base_type must be float, half, bfloat16, and int8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { mfma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); @@ -769,7 +745,7 @@ struct XdlopsGemm static constexpr auto mfma_instr = mfma.selected_mfma; static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); - static constexpr auto K1PerXdlops = mfma.GetKPerThread(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() diff --git a/composable_kernel/include/utility/amd_address_space.hpp b/include/ck/utility/amd_address_space.hpp similarity index 76% rename from composable_kernel/include/utility/amd_address_space.hpp rename to include/ck/utility/amd_address_space.hpp index 24c95b27af0..3c5939aaf30 100644 --- a/composable_kernel/include/utility/amd_address_space.hpp +++ b/include/ck/utility/amd_address_space.hpp @@ -9,7 +9,7 @@ namespace ck { -enum AddressSpaceEnum_t +enum struct AddressSpaceEnum { Generic, Global, @@ -19,7 +19,7 @@ enum AddressSpaceEnum_t }; template -__device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p) +__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p) { // cast a pointer in "Constant" address space (4) to "Generic" address space (0) // only c-style pointer cast seems be able to be compiled @@ -30,13 +30,13 @@ __device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p) } template -__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p) +__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p) { // cast a pointer in "Generic" address space (0) to "Constant" address space (4) // only c-style pointer cast seems be able to be compiled #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wold-style-cast" - return (T CONSTANT*)p; // NOLINT(old-style-cast) + return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast) #pragma clang diagnostic pop } diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp similarity index 73% rename from composable_kernel/include/utility/amd_buffer_addressing.hpp rename to include/ck/utility/amd_buffer_addressing.hpp index 3df53bda443..6831658fc9b 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,6 +1,4 @@ -#ifndef CK_AMD_BUFFER_ADDRESSING_HPP -#define CK_AMD_BUFFER_ADDRESSING_HPP - +#pragma once #include "data_type.hpp" namespace ck { @@ -31,7 +29,7 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_ return wave_buffer_resource.content; } -// load +// buffer load i8 __device__ int8_t llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, index_t voffset, @@ -50,11 +48,26 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); -__device__ int16_t +// buffer load i16 +__device__ bhalf_t llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, index_t voffset, index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); + +__device__ bhalf2_t +llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16"); + +__device__ bhalf4_t +llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16"); + +// buffer load i32 __device__ int32_t llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, index_t voffset, @@ -72,7 +85,8 @@ llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); -// half + +// buffer load fp16 __device__ half_t llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, index_t voffset, @@ -91,7 +105,7 @@ llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16"); -// float +// buffer load fp32 __device__ float llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, index_t voffset, @@ -110,7 +124,7 @@ llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); -// store +// buffer store i8 __device__ void llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, int32x4_t rsrc, @@ -132,13 +146,29 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8"); +// buffer store i16 __device__ void -llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, +llvm_amdgcn_raw_buffer_store_i16(bhalf_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); +__device__ void +llvm_amdgcn_raw_buffer_store_i16x2(bhalf2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i16x4(bhalf4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); + +// buffer store i32 __device__ void llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, int32x4_t rsrc, @@ -160,7 +190,7 @@ llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); -// half +// buffer store fp16 __device__ void llvm_amdgcn_raw_buffer_store_fp16(half_t vdata, int32x4_t rsrc, @@ -181,7 +211,8 @@ llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16"); -// float + +// buffer store fp32 __device__ void llvm_amdgcn_raw_buffer_store_fp32(float vdata, int32x4_t rsrc, @@ -202,8 +233,16 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); -// atomic add -// int + +// buffer atomic-add fp16 +__device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( + half2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16"); + +// buffer atomic-add i32 __device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( int32_t vdata, int32x4_t rsrc, @@ -211,7 +250,7 @@ __device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); -// float +// buffer atomic-add fp32 __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( float vdata, int32x4_t rsrc, @@ -219,6 +258,14 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); +// buffer atomic-add fp32 +__device__ double +llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, + int32x4_t rsrc, // dst_wave_buffer_resource + int voffset, // dst_thread_addr_offset + int soffset, // dst_wave_addr_offset + int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); + template __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, @@ -228,6 +275,7 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); @@ -240,14 +288,14 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); } else if constexpr(N == 2) { const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); } else if constexpr(N == 4) { @@ -261,8 +309,8 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w 0); vector_type tmp; - tmp.AsType()(Number<0>{}) = as_type(f32_0); - tmp.AsType()(Number<1>{}) = as_type(f32_1); + tmp.AsType()(Number<0>{}) = bit_cast(f32_0); + tmp.AsType()(Number<1>{}) = bit_cast(f32_1); return tmp.AsType()(Number<0>{}); } @@ -323,7 +371,32 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i16( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_i16x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_i16x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return bit_cast(tmp); } } else if constexpr(is_same::value) @@ -374,7 +447,7 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w int16_t tmp = llvm_amdgcn_raw_buffer_load_i16( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); #endif } else if constexpr(N == 4) @@ -386,7 +459,7 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w int32_t tmp = llvm_amdgcn_raw_buffer_load_i32( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); #endif } else if constexpr(N == 8) @@ -408,7 +481,7 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); #endif } else if constexpr(N == 16) @@ -442,7 +515,7 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); #endif } } @@ -458,6 +531,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src (is_same::value && (N == 1 || N == 2)) || (is_same::value && (N == 1 || N == 2 || N == 4)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); @@ -467,7 +541,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src // use fp32 store to mimic fp64 store if constexpr(N == 1) { - llvm_amdgcn_raw_buffer_store_fp32x2(as_type(src_thread_data), + llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -475,7 +549,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src } else if constexpr(N == 2) { - llvm_amdgcn_raw_buffer_store_fp32x4(as_type(src_thread_data), + llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -537,6 +611,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src } else if constexpr(N == 8) { +#if 0 vector_type tmp{src_thread_data}; llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], @@ -550,6 +625,56 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_thread_addr_offset, dst_wave_addr_offset + 4 * sizeof(half_t), 0); +#else + llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#endif + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(bhalf_t), + 0); } } else if constexpr(is_same::value) @@ -598,7 +723,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_addr_offset, 0); #else - llvm_amdgcn_raw_buffer_store_i16(as_type(src_thread_data), + llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -614,7 +739,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_addr_offset, 0); #else - llvm_amdgcn_raw_buffer_store_i32(as_type(src_thread_data), + llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -623,7 +748,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src } else if constexpr(N == 8) { - llvm_amdgcn_raw_buffer_store_i32x2(as_type(src_thread_data), + llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -631,7 +756,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src } else if constexpr(N == 16) { - llvm_amdgcn_raw_buffer_store_i32x4(as_type(src_thread_data), + llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -647,6 +772,7 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type::typ index_t dst_wave_addr_offset) { static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4)), "wrong! not implemented"); @@ -705,6 +831,41 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type::typ 0); } } + else if constexpr(is_same::value) + { + if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + static_for<0, 2, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + static_for<0, 4, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + } else if constexpr(is_same::value) { if constexpr(N == 1) @@ -762,16 +923,81 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type::typ } } +template +__device__ void amd_buffer_atomic_max_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(double), + 0); + } + } +} + // buffer_load requires: // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template __device__ typename vector_type_maker::type::type -amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave, - index_t src_thread_element_offset, - bool src_thread_element_valid, - index_t src_element_space_size) +amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size); @@ -893,5 +1119,39 @@ amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thr #endif } -} // namespace ck +// buffer_atomic_max requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void +amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; + + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } #endif +} + +} // namespace ck diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp similarity index 94% rename from composable_kernel/include/utility/amd_inline_asm.hpp rename to include/ck/utility/amd_inline_asm.hpp index a2d9d5f062a..fc0a15bf849 100644 --- a/composable_kernel/include/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -211,14 +211,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0 v_dot4_i32_i8 %1, %2, %4, %1\n \ " : "=v"(c0), "=v"(c1) - : "v"(as_type(a)), - "v"(as_type(b0)), - "v"(as_type(b1)), + : "v"(bit_cast(a)), + "v"(bit_cast(b0)), + "v"(bit_cast(b1)), "0"(c0), "1"(c1)); #else - c0 = __builtin_amdgcn_sdot4(as_type(a), as_type(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(as_type(a), as_type(b1), c1, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); #endif } @@ -244,20 +244,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a, v_dot4_i32_i8 %3, %4, %8, %3\n \ " : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) - : "v"(as_type(a)), - "v"(as_type(b0)), - "v"(as_type(b1)), - "v"(as_type(b2)), - "v"(as_type(b3)), + : "v"(bit_cast(a)), + "v"(bit_cast(b0)), + "v"(bit_cast(b1)), + "v"(bit_cast(b2)), + "v"(bit_cast(b3)), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); #else - c0 = __builtin_amdgcn_sdot4(as_type(a), as_type(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(as_type(a), as_type(b1), c1, false); - c2 = __builtin_amdgcn_sdot4(as_type(a), as_type(b2), c2, false); - c3 = __builtin_amdgcn_sdot4(as_type(a), as_type(b3), c3, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); + c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); #endif } diff --git a/composable_kernel/include/utility/amd_llvm_intrinsic.hpp b/include/ck/utility/amd_llvm_intrinsic.hpp similarity index 100% rename from composable_kernel/include/utility/amd_llvm_intrinsic.hpp rename to include/ck/utility/amd_llvm_intrinsic.hpp diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp new file mode 100644 index 00000000000..94693f510e7 --- /dev/null +++ b/include/ck/utility/amd_xdlops.hpp @@ -0,0 +1,298 @@ +#ifndef CK_AMD_XDLOPS_HPP +#define CK_AMD_XDLOPS_HPP + +#include "data_type.hpp" + +namespace ck { + +// fp32 +template +struct intrin_mfma_f32_32x32x1f32; + +template <> +struct intrin_mfma_f32_32x32x1f32<64, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); + } +}; + +template <> +struct intrin_mfma_f32_32x32x1f32<32, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + } +}; + +template +struct intrin_mfma_f32_32x32x2f32; + +template <> +struct intrin_mfma_f32_32x32x2f32<32, 32> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f32; + +template <> +struct intrin_mfma_f32_16x16x4f32<16, 16> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x1f32; + +template <> +struct intrin_mfma_f32_16x16x1f32<16, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); + } +}; + +template +struct intrin_mfma_f32_4x4x1f32; + +template <> +struct intrin_mfma_f32_4x4x1f32<4, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + } +}; + +template <> +struct intrin_mfma_f32_4x4x1f32<8, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); + } +}; + +// fp16 +template +struct intrin_mfma_f32_32x32x4f16; + +template <> +struct intrin_mfma_f32_32x32x4f16<64, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); + } +}; + +template <> +struct intrin_mfma_f32_32x32x4f16<32, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + } +}; + +template +struct intrin_mfma_f32_32x32x8f16; + +template <> +struct intrin_mfma_f32_32x32x8f16<32, 32> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x16f16; + +template <> +struct intrin_mfma_f32_16x16x16f16<16, 16> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f16; + +template <> +struct intrin_mfma_f32_16x16x4f16<16, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); + } +}; + +template +struct intrin_mfma_f32_4x4x4f16; + +template <> +struct intrin_mfma_f32_4x4x4f16<4, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + } +}; + +template <> +struct intrin_mfma_f32_4x4x4f16<8, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); + } +}; + +// bfp16 +template +struct intrin_mfma_f32_32x32x8bf16_1k; + +template <> +struct intrin_mfma_f32_32x32x8bf16_1k<32, 32> +{ + template + __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x16bf16_1k; + +template <> +struct intrin_mfma_f32_16x16x16bf16_1k<16, 16> +{ + template + __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_32x32x4bf16; + +template <> +struct intrin_mfma_f32_32x32x4bf16<32, 32> +{ + template + __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x8bf16; + +template <> +struct intrin_mfma_f32_16x16x8bf16<16, 16> +{ + template + __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_i32_32x32x8i8; + +template <> +struct intrin_mfma_i32_32x32x8i8<32, 32> +{ + template + __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_i32_16x16x16i8; + +template <> +struct intrin_mfma_i32_16x16x16i8<16, 16> +{ + template + __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/array.hpp b/include/ck/utility/array.hpp similarity index 93% rename from composable_kernel/include/utility/array.hpp rename to include/ck/utility/array.hpp index 911cefd0571..4c9dfd9a934 100644 --- a/composable_kernel/include/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -49,7 +49,7 @@ template __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) { using data_type = remove_cvref_t; - return Array{{std::forward(x), std::forward(xs)...}}; + return Array{std::forward(x), std::forward(xs)...}; } // make empty array diff --git a/composable_kernel/include/utility/array_multi_index.hpp b/include/ck/utility/array_multi_index.hpp similarity index 100% rename from composable_kernel/include/utility/array_multi_index.hpp rename to include/ck/utility/array_multi_index.hpp diff --git a/composable_kernel/include/utility/c_style_pointer_cast.hpp b/include/ck/utility/c_style_pointer_cast.hpp similarity index 100% rename from composable_kernel/include/utility/c_style_pointer_cast.hpp rename to include/ck/utility/c_style_pointer_cast.hpp diff --git a/composable_kernel/include/utility/common_header.hpp b/include/ck/utility/common_header.hpp similarity index 77% rename from composable_kernel/include/utility/common_header.hpp rename to include/ck/utility/common_header.hpp index 85c02a1b99d..34c0a7821b3 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -1,6 +1,4 @@ -#ifndef CK_COMMON_HEADER_HPP -#define CK_COMMON_HEADER_HPP - +#pragma once #include "config.hpp" #include "array.hpp" #include "container_helper.hpp" @@ -15,32 +13,37 @@ #include "functional3.hpp" #include "functional4.hpp" #include "enable_if.hpp" +#include "ignore.hpp" #include "integral_constant.hpp" #include "math.hpp" #include "number.hpp" #include "sequence.hpp" #include "sequence_helper.hpp" -#include "synchronization.hpp" #include "tuple.hpp" #include "tuple_helper.hpp" #include "type.hpp" #include "magic_division.hpp" -#include "utility.hpp" #include "c_style_pointer_cast.hpp" -#include "amd_address_space.hpp" +#include "is_known_at_compile_time.hpp" +#include "transpose_vectors.hpp" +#include "inner_product.hpp" +#include "element_wise_operation.hpp" +#include "thread_group.hpp" +#include "debug.hpp" + #include "amd_buffer_addressing.hpp" +#include "generic_memory_space_atomic.hpp" +#include "get_id.hpp" +#include "synchronization.hpp" +#include "amd_address_space.hpp" #include "static_buffer.hpp" #include "dynamic_buffer.hpp" -#include "inner_product.hpp" - // TODO: remove this #if CK_USE_AMD_INLINE_ASM #include "amd_inline_asm.hpp" #endif -#if CK_USE_AMD_XDLOPS +#ifdef CK_USE_AMD_MFMA #include "amd_xdlops.hpp" #endif - -#endif diff --git a/composable_kernel/include/utility/container_element_picker.hpp b/include/ck/utility/container_element_picker.hpp similarity index 100% rename from composable_kernel/include/utility/container_element_picker.hpp rename to include/ck/utility/container_element_picker.hpp diff --git a/composable_kernel/include/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp similarity index 97% rename from composable_kernel/include/utility/container_helper.hpp rename to include/ck/utility/container_helper.hpp index a7ed8ec059e..a92e79908d9 100644 --- a/composable_kernel/include/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -373,19 +373,6 @@ set_container_subset(Tuple& y, Sequence picks, const Tuple& static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); } -template -__host__ __device__ constexpr auto to_tuple_of_number(const Container&) -{ - static_assert(is_known_at_compile_time::value, "wrong!"); - - return generate_tuple( - [&](auto i) { - constexpr index_t tmp = Container::At(i); - return Number{}; - }, - Container::Size()); -} - template __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence) { diff --git a/composable_kernel/include/utility/data_type.hpp b/include/ck/utility/data_type.hpp similarity index 87% rename from composable_kernel/include/utility/data_type.hpp rename to include/ck/utility/data_type.hpp index 07eceb84cff..bf8dc74f34c 100644 --- a/composable_kernel/include/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,11 +1,10 @@ -#ifndef CK_FLOAT_TYPE_AMD_HPP -#define CK_FLOAT_TYPE_AMD_HPP - +#pragma once #include "statically_indexed_array.hpp" namespace ck { -using half_t = _Float16; +using bhalf_t = ushort; +using half_t = _Float16; // vector_type template @@ -58,6 +57,18 @@ __host__ __device__ constexpr auto make_vector_type(Number) template struct scalar_type; +// is_scalar_type +template +struct is_scalar_type +{ + static constexpr bool value = (scalar_type>::vector_size == 1); +}; + +// has_same_scalar_type +template +using has_same_scalar_type = is_same>::type, + typename scalar_type>::type>; + template struct scalar_type { @@ -95,9 +106,9 @@ struct scalar_type }; template <> -struct scalar_type +struct scalar_type { - using type = ushort; + using type = bhalf_t; static constexpr index_t vector_size = 1; }; @@ -892,12 +903,12 @@ using half32_t = typename vector_type::type; using half64_t = typename vector_type::type; // bfp16 -using ushort2_t = typename vector_type::type; -using ushort4_t = typename vector_type::type; -using ushort8_t = typename vector_type::type; -using ushort16_t = typename vector_type::type; -using ushort32_t = typename vector_type::type; -using ushort64_t = typename vector_type::type; +using bhalf2_t = typename vector_type::type; +using bhalf4_t = typename vector_type::type; +using bhalf8_t = typename vector_type::type; +using bhalf16_t = typename vector_type::type; +using bhalf32_t = typename vector_type::type; +using bhalf64_t = typename vector_type::type; // i32 using int32x2_t = typename vector_type::type; @@ -915,97 +926,71 @@ using int8x16_t = typename vector_type::type; using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; -// data type conversion -template -struct type_convert +// Convert X to Y +template +__host__ __device__ Y type_convert(X x) { - template - __device__ T operator()(X x) const - { - return static_cast(x); - } -}; - -template <> -template <> -__device__ float type_convert::operator()(ushort x) const -{ - return bfloat16_to_float(x); + return static_cast(x); } +// convert bfp16 to fp32 template <> -template <> -__device__ ushort type_convert::operator()(float x) const +inline __host__ __device__ float type_convert(bhalf_t x) { - return float_to_bfloat16(x); -} - -// TODO: deprecate this -template -struct inner_product_with_conversion -{ - static constexpr auto convert = type_convert(); - - template - __device__ T operator()(typename vector_type::type a, - typename vector_type::type b) const + union { - const vector_type a_vector{a}; - const vector_type b_vector{b}; - - T acc = 0; + uint32_t int32; + float fp32; + } u = {uint32_t(x) << 16}; - static_for<0, N, 1>{}([&](auto i) { - acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); - }); - - return acc; - } - - __device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); } + return u.fp32; +} - __device__ T operator()(int8x4_t a, int8x4_t b) const +// convert fp32 to bfp16 +template <> +inline __host__ __device__ bhalf_t type_convert(float x) +{ + union { - const vector_type a_vector{a}; - const vector_type b_vector{b}; - - T acc = 0; - - static_for<0, 4, 1>{}([&](auto i) { - acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); - }); - - return acc; + float fp32; + uint32_t int32; + } u = {x}; + + if(~u.int32 & 0x7f800000) + { + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even } - - __device__ T operator()(int8x8_t a, int8x8_t b) const - { - const vector_type a_vector{a}; - const vector_type b_vector{b}; - - T acc = 0; - - static_for<0, 8, 1>{}([&](auto i) { - acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); - }); - - return acc; + else if(u.int32 & 0xffff) + { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bloat16's mantissa bits are all 0. + u.int32 |= 0x10000; // Preserve signaling NaN } - __device__ T operator()(int8x16_t a, int8x16_t b) const - { - const vector_type a_vector{a}; - const vector_type b_vector{b}; - - T acc = 0; - - static_for<0, 16, 1>{}([&](auto i) { - acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); - }); - - return acc; - } -}; + return uint16_t(u.int32 >> 16); +} template struct NumericLimits @@ -1024,12 +1009,11 @@ struct NumericLimits static constexpr unsigned short binary_max = 0x7BFF; static constexpr unsigned short binary_lowest = 0xFBFF; - __host__ __device__ static constexpr half_t Min() { return as_type(binary_min); } + __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } - __host__ __device__ static constexpr half_t Max() { return as_type(binary_max); } + __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } - __host__ __device__ static constexpr half_t Lowest() { return as_type(binary_lowest); } + __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } }; } // namespace ck -#endif diff --git a/composable_kernel/include/utility/data_type_enum.hpp b/include/ck/utility/data_type_enum.hpp similarity index 91% rename from composable_kernel/include/utility/data_type_enum.hpp rename to include/ck/utility/data_type_enum.hpp index 35df0067a9a..fda6a2b05cf 100644 --- a/composable_kernel/include/utility/data_type_enum.hpp +++ b/include/ck/utility/data_type_enum.hpp @@ -3,7 +3,7 @@ namespace ck { -enum DataTypeEnum_t +enum struct DataTypeEnum { Half = 0, Float = 1, diff --git a/composable_kernel/include/utility/data_type_enum_helper.hpp b/include/ck/utility/data_type_enum_helper.hpp similarity index 55% rename from composable_kernel/include/utility/data_type_enum_helper.hpp rename to include/ck/utility/data_type_enum_helper.hpp index 451ce992b1f..9c8e01a7e38 100644 --- a/composable_kernel/include/utility/data_type_enum_helper.hpp +++ b/include/ck/utility/data_type_enum_helper.hpp @@ -6,35 +6,35 @@ namespace ck { -template +template struct get_datatype_from_enum; template <> -struct get_datatype_from_enum +struct get_datatype_from_enum { using type = int8_t; }; template <> -struct get_datatype_from_enum +struct get_datatype_from_enum { using type = int32_t; }; template <> -struct get_datatype_from_enum +struct get_datatype_from_enum { using type = half_t; }; template <> -struct get_datatype_from_enum +struct get_datatype_from_enum { using type = float; }; template <> -struct get_datatype_from_enum +struct get_datatype_from_enum { using type = double; }; @@ -45,31 +45,31 @@ struct get_datatype_enum_from_type; template <> struct get_datatype_enum_from_type { - static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int8; + static constexpr DataTypeEnum value = DataTypeEnum::Int8; }; template <> struct get_datatype_enum_from_type { - static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int32; + static constexpr DataTypeEnum value = DataTypeEnum::Int32; }; template <> struct get_datatype_enum_from_type { - static constexpr DataTypeEnum_t value = DataTypeEnum_t::Half; + static constexpr DataTypeEnum value = DataTypeEnum::Half; }; template <> struct get_datatype_enum_from_type { - static constexpr DataTypeEnum_t value = DataTypeEnum_t::Float; + static constexpr DataTypeEnum value = DataTypeEnum::Float; }; template <> struct get_datatype_enum_from_type { - static constexpr DataTypeEnum_t value = DataTypeEnum_t::Double; + static constexpr DataTypeEnum value = DataTypeEnum::Double; }; } // namespace ck diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp new file mode 100644 index 00000000000..a5b34fce74a --- /dev/null +++ b/include/ck/utility/debug.hpp @@ -0,0 +1,77 @@ +#ifndef UTILITY_DEBUG_HPP +#define UTILITY_DEBUG_HPP + +namespace ck { +namespace debug { + +namespace detail { +template +struct PrintAsType; + +template +struct PrintAsType::value>::value> +{ + using type = float; +}; + +template <> +struct PrintAsType +{ + using type = float; +}; + +template +struct PrintAsType::value>::value> +{ + using type = int; +}; +} // namespace detail + +// Print at runtime the data in shared memory in 128 bytes per row format given shared mem pointer +// and the number of elements. Can optionally specify strides between elements and how many bytes' +// worth of data per row. +// +// Usage example: +// +// debug::print_shared(a_block_buf.p_data_, index_t(a_block_desc_k0_m_k1.GetElementSpaceSize())); +// +template +__device__ void print_shared(T const* p_shared, index_t num_elements) +{ + using PrintType = typename detail::PrintAsType::type; + constexpr index_t row_elements = row_bytes / sizeof(T); + static_assert((element_stride >= 1 && element_stride <= row_elements), + "element_stride should between [1, row_elements]"); + + index_t wgid = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + index_t tid = + (threadIdx.z * (blockDim.x * blockDim.y)) + (threadIdx.y * blockDim.x) + threadIdx.x; + + __syncthreads(); + + if(tid == 0) + { + printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n", + wgid, + row_bytes, + element_stride); + for(index_t i = 0; i < num_elements; i += row_elements) + { + printf("elem %5d: ", i); + for(index_t j = 0; j < row_elements; j += element_stride) + { + printf("%.0f ", static_cast(p_shared[i + j])); + } + + printf("\n"); + } + printf("\n"); + } + + __syncthreads(); +} + +} // namespace debug +} // namespace ck + +#endif // UTILITY_DEBUG_HPP diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp new file mode 100644 index 00000000000..0ad78423fe5 --- /dev/null +++ b/include/ck/utility/dynamic_buffer.hpp @@ -0,0 +1,393 @@ +#pragma once +#include "config.hpp" +#include "enable_if.hpp" +#include "c_style_pointer_cast.hpp" +#include "amd_buffer_addressing.hpp" +#include "generic_memory_space_atomic.hpp" + +namespace ck { + +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +template +struct DynamicBuffer +{ + using type = T; + + T* p_data_; + ElementSpaceSize element_space_size_; + T invalid_element_value_ = T{0}; + + __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) + : p_data_{p_data}, element_space_size_{element_space_size} + { + } + + __host__ __device__ constexpr DynamicBuffer(T* p_data, + ElementSpaceSize element_space_size, + T invalid_element_value) + : p_data_{p_data}, + element_space_size_{element_space_size}, + invalid_element_value_{invalid_element_value} + { + } + + __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() + { + return BufferAddressSpace; + } + + __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + +#if CK_USE_AMD_BUFFER_LOAD + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + if constexpr(InvalidElementUseNumericalZeroValue) + { + return amd_buffer_load_invalid_element_return_zero, t_per_x>( + p_data_, i, is_valid_element, element_space_size_); + } + else + { + return amd_buffer_load_invalid_element_return_customized_value, + t_per_x>( + p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); + } + } + else + { + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + return *c_style_pointer_cast(&p_data_[i]); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{0}; + } + else + { + return X{invalid_element_value_}; + } + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x) + { + if constexpr(Op == InMemoryDataOperationEnum::Set) + { + this->template Set(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd) + { + this->template AtomicAdd(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax) + { + this->template AtomicMax(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::Add) + { + auto tmp = this->template Get(i, is_valid_element); + this->template Set(i, is_valid_element, x + tmp); + // tmp += x; + // this->template Set(i, is_valid_element, tmp); + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + +#if CK_USE_AMD_BUFFER_STORE + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + +#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE + bool constexpr workaround_int8_ds_write_issue = true; +#else + bool constexpr workaround_int8_ds_write_issue = false; +#endif + + if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_store, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); + } + else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && + is_same>::type, int8_t>::value && + workaround_int8_ds_write_issue) + { + if(is_valid_element) + { + // HACK: compiler would lower IR "store address_space(3)" into inefficient + // ISA, so I try to let compiler emit IR "store" which would be lower to + // ds_write_b128 + // TODO: remove this after compiler fix + static_assert((is_same, int8_t>::value && + is_same, int8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x2_t>::value) || + (is_same, int8_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x16_t>::value) || + (is_same, int8x4_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8x8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8x16_t>::value && + is_same, int8x16_t>::value), + "wrong! not implemented for this combination, please add " + "implementation"); + + if constexpr(is_same, int8_t>::value && + is_same, int8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x2_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x4_t>::value && + is_same, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x16_t>::value && + is_same, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + } + } + else + { + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp = x; + + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + *c_style_pointer_cast(&p_data_[i]) = x; +#endif + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x) + { + using scalar_t = typename scalar_type>::type; + + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT + bool constexpr use_amd_buffer_addressing = + is_same_v, int32_t> || + is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); +#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) + bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; +#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT + bool constexpr use_amd_buffer_addressing = + is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_add, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); + } + else + { + if(is_valid_element) + { + atomic_add(c_style_pointer_cast(&p_data_[i]), x); + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 + using scalar_t = typename scalar_type>::type; + bool constexpr use_amd_buffer_addressing = is_same_v, double>; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_max, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); + } + else if(is_valid_element) + { + atomic_max(c_style_pointer_cast(&p_data_[i]), x); + } + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } +}; + +template +__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) +{ + return DynamicBuffer{p, element_space_size}; +} + +template < + AddressSpaceEnum BufferAddressSpace, + typename T, + typename ElementSpaceSize, + typename X, + typename enable_if, remove_cvref_t>::value, bool>::type = false> +__host__ __device__ constexpr auto +make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) +{ + return DynamicBuffer{ + p, element_space_size, invalid_element_value}; +} + +} // namespace ck diff --git a/composable_kernel/include/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp similarity index 100% rename from composable_kernel/include/utility/enable_if.hpp rename to include/ck/utility/enable_if.hpp diff --git a/composable_kernel/include/utility/functional.hpp b/include/ck/utility/functional.hpp similarity index 100% rename from composable_kernel/include/utility/functional.hpp rename to include/ck/utility/functional.hpp diff --git a/composable_kernel/include/utility/functional2.hpp b/include/ck/utility/functional2.hpp similarity index 100% rename from composable_kernel/include/utility/functional2.hpp rename to include/ck/utility/functional2.hpp diff --git a/composable_kernel/include/utility/functional3.hpp b/include/ck/utility/functional3.hpp similarity index 100% rename from composable_kernel/include/utility/functional3.hpp rename to include/ck/utility/functional3.hpp diff --git a/composable_kernel/include/utility/functional4.hpp b/include/ck/utility/functional4.hpp similarity index 100% rename from composable_kernel/include/utility/functional4.hpp rename to include/ck/utility/functional4.hpp diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp new file mode 100644 index 00000000000..1a2dacb5c50 --- /dev/null +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -0,0 +1,120 @@ +#pragma once +#include "data_type.hpp" + +namespace ck { + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_add explicit for +// each datatype. +template +__device__ X atomic_add(X* p_dst, const X& x); + +template <> +__device__ int32_t atomic_add(int32_t* p_dst, const int32_t& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ uint32_t atomic_add(uint32_t* p_dst, const uint32_t& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ float atomic_add(float* p_dst, const float& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ double atomic_add(double* p_dst, const double& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ float2_t atomic_add(float2_t* p_dst, const float2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +template <> +__device__ double2_t atomic_add(double2_t* p_dst, const double2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_max explicit for +// each datatype. + +template +__device__ X atomic_max(X* p_dst, const X& x); + +template <> +__device__ int32_t atomic_max(int32_t* p_dst, const int32_t& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ uint32_t atomic_max(uint32_t* p_dst, const uint32_t& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ float atomic_max(float* p_dst, const float& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ double atomic_max(double* p_dst, const double& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ float2_t atomic_max(float2_t* p_dst, const float2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicMax(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicMax(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +} // namespace ck diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp new file mode 100644 index 00000000000..7c62b890c75 --- /dev/null +++ b/include/ck/utility/get_id.hpp @@ -0,0 +1,24 @@ +#pragma once +#include "config.hpp" + +namespace ck { + +__host__ __device__ constexpr index_t get_warp_size() +{ + // warpSize is defined by HIP + return warpSize; +} + +__device__ index_t get_thread_local_1d_id() { return threadIdx.x; } + +__device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; } + +__device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); } + +__device__ index_t get_block_1d_id() { return blockIdx.x; } + +__device__ index_t get_grid_size() { return gridDim.x; } + +__device__ index_t get_block_size() { return blockDim.x; } + +} // namespace ck diff --git a/include/ck/utility/ignore.hpp b/include/ck/utility/ignore.hpp new file mode 100644 index 00000000000..8a199159b3e --- /dev/null +++ b/include/ck/utility/ignore.hpp @@ -0,0 +1,21 @@ +#ifndef CK_IGNORE_HPP +#define CK_IGNORE_HPP + +// https://en.cppreference.com/w/cpp/utility/tuple/ignore + +namespace ck { + +namespace detail { +struct ignore_t +{ + template + constexpr void operator=(T&&) const noexcept + { + } +}; +} // namespace detail + +inline constexpr detail::ignore_t ignore; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp similarity index 91% rename from composable_kernel/include/utility/inner_product.hpp rename to include/ck/utility/inner_product.hpp index 51753accf3d..59fe17e8675 100644 --- a/composable_kernel/include/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -1,6 +1,4 @@ -#ifndef CK_INNER_PRODUCT_HPP -#define CK_INNER_PRODUCT_HPP - +#pragma once #include "data_type.hpp" namespace ck { @@ -84,13 +82,12 @@ __device__ void inner_product(const half2_t& a, const h c = __builtin_amdgcn_sdot2(a, b, c, false); #endif #else - const auto convert = type_convert{}; - const vector_type a_vector{a}; const vector_type b_vector{b}; static_for<0, 2, 1>{}([&](auto i) { - c += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); }); #endif } @@ -139,24 +136,23 @@ template <> __device__ void inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) { -#if defined(CK_USE_DOT4_I32_I8) +#if defined(CK_USE_AMD_V_DOT4_I32_I8) #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM asm volatile("\n \ v_dot4_i32_i8 %0, %1, %2, %0\n \ " : "=v"(c) - : "v"(as_type(a)), "v"(as_type(b)), "0"(c)); + : "v"(bit_cast(a)), "v"(bit_cast(b)), "0"(c)); #else - c = __builtin_amdgcn_sdot4(as_type(a), as_type(b), c, false); + c = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b), c, false); #endif #else - const auto convert = type_convert{}; - const vector_type a_vector{a}; const vector_type b_vector{b}; static_for<0, 4, 1>{}([&](auto i) { - c += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); }); #endif } @@ -204,4 +200,3 @@ inner_product(const int8x16_t& a, const int8x16_t } } // namespace ck -#endif diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp new file mode 100644 index 00000000000..3d9c0472e7f --- /dev/null +++ b/include/ck/utility/integral_constant.hpp @@ -0,0 +1,50 @@ +#ifndef CK_INTEGRAL_CONSTANT_HPP +#define CK_INTEGRAL_CONSTANT_HPP + +namespace ck { + +template +struct integral_constant +{ + static constexpr T value = v; + typedef T value_type; + typedef integral_constant type; + __host__ __device__ constexpr operator value_type() const noexcept { return value; } + __host__ __device__ constexpr value_type operator()() const noexcept { return value; } +}; + +template +__host__ __device__ constexpr auto operator+(integral_constant, integral_constant) +{ + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator-(integral_constant, integral_constant) +{ + static_assert(Y <= X, "wrong!"); + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator*(integral_constant, integral_constant) +{ + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator/(integral_constant, integral_constant) +{ + static_assert(Y > 0, "wrong!"); + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator%(integral_constant, integral_constant) +{ + static_assert(Y > 0, "wrong!"); + return integral_constant{}; +} + +} // namespace ck +#endif diff --git a/include/ck/utility/is_known_at_compile_time.hpp b/include/ck/utility/is_known_at_compile_time.hpp new file mode 100644 index 00000000000..dc440279017 --- /dev/null +++ b/include/ck/utility/is_known_at_compile_time.hpp @@ -0,0 +1,55 @@ +#ifndef IS_KNOWN_AT_COMPILE_TIME_HPP +#define IS_KNOWN_AT_COMPILE_TIME_HPP + +#include "config.hpp" +#include "integral_constant.hpp" +#include "sequence.hpp" +#include "tuple.hpp" + +namespace ck { + +template +struct is_known_at_compile_time; + +template <> +struct is_known_at_compile_time +{ + static constexpr bool value = false; +}; + +template <> +struct is_known_at_compile_time +{ + static constexpr bool value = false; +}; + +template +struct is_known_at_compile_time> +{ + static constexpr bool value = true; +}; + +template +struct is_known_at_compile_time> +{ + static constexpr bool value = true; +}; + +template +struct is_known_at_compile_time> +{ + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return container_reduce( + Tuple{}, + [](auto x, bool r) { + return is_known_at_compile_time>::value & r; + }, + true); + } + + static constexpr bool value = IsKnownAtCompileTime(); +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp similarity index 71% rename from composable_kernel/include/utility/magic_division.hpp rename to include/ck/utility/magic_division.hpp index 612aceea2a1..61025767170 100644 --- a/composable_kernel/include/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -25,21 +25,30 @@ struct MagicDivision // uint32_t __host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor) { - // assert(divisior >= 1 && divisior <= INT32_MAX); - uint32_t shift = 0; - for(shift = 0; shift < 32; ++shift) + // WARNING: magic division is only applicable for division inside this range. + // You should use the return value of CalculateMagicNumbers, if division is not inside this + // range. The "else" logic below is to quiet down run-time error. + if(divisor >= 1 && divisor <= INT32_MAX) { - if((1U << shift) >= divisor) + uint32_t shift = 0; + for(shift = 0; shift < 32; ++shift) { - break; + if((1U << shift) >= divisor) + { + break; + } } - } - uint64_t one = 1; - uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; - // assert(multiplier <= 0xffffffffUL); + uint64_t one = 1; + uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; + // assert(multiplier <= 0xffffffffUL); - return make_tuple(uint32_t(multiplier), shift); + return make_tuple(uint32_t(multiplier), shift); + } + else + { + return make_tuple(uint32_t(0), uint32_t(0)); + } } __host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor) @@ -111,24 +120,39 @@ struct MagicDivision } // magic division for uint32_t - __host__ __device__ static constexpr uint32_t + __device__ static constexpr uint32_t DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) { uint32_t tmp = __umulhi(dividend, multiplier); return (tmp + dividend) >> shift; } + __host__ static constexpr uint32_t + DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = static_cast(dividend) * multiplier >> 32; + return (tmp + dividend) >> shift; + } + // magic division for int32_t // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be // non-negative for result to be correct // TODO: figure out how to do magic number divison for int32_t as dividended - __host__ __device__ static constexpr int32_t + __device__ static constexpr int32_t DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) { - uint32_t dividend_u32 = as_type(dividend_i32); + uint32_t dividend_u32 = bit_cast(dividend_i32); uint32_t tmp = __umulhi(dividend_u32, multiplier); return (tmp + dividend_u32) >> shift; } + + __host__ static constexpr int32_t + DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = static_cast(dividend_u32) * multiplier >> 32; + return (tmp + dividend_u32) >> shift; + } }; } // namespace ck diff --git a/composable_kernel/include/utility/math.hpp b/include/ck/utility/math.hpp similarity index 100% rename from composable_kernel/include/utility/math.hpp rename to include/ck/utility/math.hpp diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp new file mode 100644 index 00000000000..572d576e7ac --- /dev/null +++ b/include/ck/utility/math_v2.hpp @@ -0,0 +1,66 @@ +#ifndef CK_MATH_V2_HPP +#define CK_MATH_V2_HPP + +#include +#include "data_type.hpp" +#include "half.hpp" + +namespace ck { +namespace math { + +static inline __host__ float abs(float x) { return std::abs(x); }; + +static inline __host__ double abs(double x) { return std::abs(x); }; + +static inline __host__ int8_t abs(int8_t x) +{ + int8_t sgn = x >> (8 - 1); + + return (x ^ sgn) - sgn; +}; + +static inline __host__ int32_t abs(int32_t x) +{ + int32_t sgn = x >> (32 - 1); + + return (x ^ sgn) - sgn; +}; + +static inline __host__ half_t abs(half_t x) +{ + half_float::half xx = *reinterpret_cast(&x); + + half_float::half abs_xx = half_float::abs(xx); + + half_t abs_x = *reinterpret_cast(&abs_xx); + + return abs_x; +}; + +static inline __host__ float isnan(float x) { return std::isnan(x); }; + +static inline __host__ double isnan(double x) { return std::isnan(x); }; + +static inline __host__ int8_t isnan(int8_t x) +{ + (void)x; + return false; +}; + +static inline __host__ int32_t isnan(int32_t x) +{ + (void)x; + return false; +}; + +static inline __host__ bool isnan(half_t x) +{ + half_float::half xx = *reinterpret_cast(&x); + + return half_float::isnan(xx); +}; + +} // namespace math +} // namespace ck + +#endif diff --git a/composable_kernel/include/utility/multi_index.hpp b/include/ck/utility/multi_index.hpp similarity index 77% rename from composable_kernel/include/utility/multi_index.hpp rename to include/ck/utility/multi_index.hpp index 0bb34fb1e2a..f395b5ee715 100644 --- a/composable_kernel/include/utility/multi_index.hpp +++ b/include/ck/utility/multi_index.hpp @@ -3,7 +3,7 @@ #include "common_header.hpp" -#if CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX +#if CK_EXPERIMENTAL_USE_DYNAMICALLY_INDEXED_MULTI_INDEX #include "array_multi_index.hpp" #else #include "statically_indexed_array_multi_index.hpp" diff --git a/include/ck/utility/number.hpp b/include/ck/utility/number.hpp new file mode 100644 index 00000000000..97a71f8a411 --- /dev/null +++ b/include/ck/utility/number.hpp @@ -0,0 +1,15 @@ +#ifndef CK_NUMBER_HPP +#define CK_NUMBER_HPP + +#include "integral_constant.hpp" + +namespace ck { + +template +using Number = integral_constant; + +template +using LongNumber = integral_constant; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/print.hpp b/include/ck/utility/print.hpp similarity index 100% rename from composable_kernel/include/utility/print.hpp rename to include/ck/utility/print.hpp diff --git a/composable_kernel/include/utility/reduction_common.hpp b/include/ck/utility/reduction_common.hpp similarity index 85% rename from composable_kernel/include/utility/reduction_common.hpp rename to include/ck/utility/reduction_common.hpp index ff574c315c1..a34cfce8377 100644 --- a/composable_kernel/include/utility/reduction_common.hpp +++ b/include/ck/utility/reduction_common.hpp @@ -33,7 +33,7 @@ namespace ck { struct float_equal_one { template - __device__ inline bool operator()(T x) + __host__ __device__ inline bool operator()(T x) { return x <= static_cast(1.0f) and x >= static_cast(1.0f); }; @@ -42,12 +42,24 @@ struct float_equal_one struct float_equal_zero { template - __device__ inline bool operator()(T x) + __host__ __device__ inline bool operator()(T x) { return x <= static_cast(0.0f) and x >= static_cast(0.0f); }; }; +template +static constexpr __device__ index_t get_shift() +{ + return (get_shift() + 1); +}; + +template <> +constexpr __device__ index_t get_shift<1>() +{ + return (0); +} + }; // end of namespace ck #endif diff --git a/composable_kernel/include/utility/reduction_enums.hpp b/include/ck/utility/reduction_enums.hpp similarity index 94% rename from composable_kernel/include/utility/reduction_enums.hpp rename to include/ck/utility/reduction_enums.hpp index e97108179ea..9089fd6116c 100644 --- a/composable_kernel/include/utility/reduction_enums.hpp +++ b/include/ck/utility/reduction_enums.hpp @@ -28,7 +28,7 @@ namespace ck { -enum class ReduceTensorOp_t +enum struct ReduceTensorOp { ADD = 0, MUL = 1, @@ -41,19 +41,19 @@ enum class ReduceTensorOp_t // MUL_NO_ZEROS = 8, }; -enum class NanPropagation_t +enum struct NanPropagation { NOT_PROPAGATE_NAN = 0, PROPAGATE_NAN = 1, }; -enum class ReduceTensorIndices_t +enum struct ReduceTensorIndices { NO_INDICES = 0, FLATTENED_INDICES = 1, }; -enum class IndicesType_t +enum struct IndicesType { INDICES_32BIT = 0, INDICES_64BIT = 1, diff --git a/composable_kernel/include/utility/reduction_functions_binop.hpp b/include/ck/utility/reduction_functions_accumulate.hpp similarity index 51% rename from composable_kernel/include/utility/reduction_functions_binop.hpp rename to include/ck/utility/reduction_functions_accumulate.hpp index 5285abee81e..4e8636e5b2a 100644 --- a/composable_kernel/include/utility/reduction_functions_binop.hpp +++ b/include/ck/utility/reduction_functions_accumulate.hpp @@ -34,50 +34,79 @@ namespace ck { namespace detail { -static inline __device__ bool isnan(half_t x) { return __hisnan(x); }; +template +static inline __device__ bool is_nan(T x) +{ + return (isnan(x)); +}; + +template <> +inline __device__ bool is_nan(half_t x) +{ + return (__hisnan(x)); +}; -template -struct binop_with_nan_check; +template +struct AccumulateWithNanCheck; -template -struct binop_with_nan_check +template +struct AccumulateWithNanCheck { // cppcheck-suppress constParameter - __device__ static inline void calculate(compType& accuVal, compType currVal) + __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) + { + ReduceOperation{}(accuVal, currVal); + }; +}; + +template +struct AccumulateWithNanCheck +{ + __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) { - opReduce{}(accuVal, currVal); + if(is_nan(currVal)) + { + accuVal = currVal; + } + else + { + ReduceOperation{}(accuVal, currVal); + }; }; +}; + +template +struct AccumulateWithIndexAndNanCheck; - // The method is called when the opReduce is indexable and the user asked for indices +template +struct AccumulateWithIndexAndNanCheck +{ __device__ static inline void // cppcheck-suppress constParameter - calculate(compType& accuVal, compType currVal, int& accuIndex, int currIndex) + Calculate(AccDataType& accuVal, + AccDataType currVal, + IndexDataType& accuIndex, + IndexDataType currIndex) { bool changed = false; - opReduce{}(accuVal, currVal, changed); + ReduceOperation{}(accuVal, currVal, changed); if(changed) accuIndex = currIndex; }; }; -template -struct binop_with_nan_check +template +struct AccumulateWithIndexAndNanCheck { - __device__ static inline void calculate(compType& accuVal, compType currVal) - { - if(isnan(currVal)) - accuVal = currVal; - else - opReduce{}(accuVal, currVal); - }; - - // The method is called when the opReduce is indexable and the user asked for indices - __device__ static inline void - calculate(compType& accuVal, compType currVal, int& accuIndex, int currIndex) + // The method is called when the ReduceOperation is indexable and the user asked for indices + __device__ static inline void Calculate(AccDataType& accuVal, + AccDataType currVal, + IndexDataType& accuIndex, + IndexDataType currIndex) { - if(isnan(currVal)) + if(is_nan(currVal)) { accuVal = currVal; accuIndex = currIndex; @@ -86,7 +115,7 @@ struct binop_with_nan_check { bool changed = false; - opReduce{}(accuVal, currVal, changed); + ReduceOperation{}(accuVal, currVal, changed); if(changed) accuIndex = currIndex; diff --git a/include/ck/utility/reduction_operator.hpp b/include/ck/utility/reduction_operator.hpp new file mode 100644 index 00000000000..e7a8db8c011 --- /dev/null +++ b/include/ck/utility/reduction_operator.hpp @@ -0,0 +1,201 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_REDUCTION_OPERATOR_HPP +#define CK_REDUCTION_OPERATOR_HPP + +#include "config.hpp" +#include "data_type.hpp" + +namespace ck { + +namespace reduce { + +// Every binary operator used in reduction is represented by a templated functor class. Each functor +// class must provide at least +// three members: +// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary +// operator, "identity element" is the unique +// element in the algebraic space that doesn't affect the value of other elements +// when operated against them, and the concept is similar to zero vector in +// vector space +// (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf). +// 2) IsCompatibleInMemoryDataOperation() -- return true if the reduction task corresponding to this +// operator can use the InMemoryDataOperation to finalize, or else it return false 3) operator() -- +// the first argument of the operator must be both an input & output, and the corresponding variable +// usually stores +// the accumulated result of many operator() calls; the second argument is only an +// input. For indexable binary +// operator, the second version of operator() has third argument (which is an +// output) to indicate whether the +// accumulated value (the first argument) has changed, in which case the recorded +// accumulated index also need be +// changed. + +template +struct Add +{ + using dataType = T; + + __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast(0.0f); }; + + __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + return operation == InMemoryDataOperationEnum::AtomicAdd || + operation == InMemoryDataOperationEnum::Set; + }; + + __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; } +}; + +template +struct Mul +{ + using dataType = T; + + __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast(1.0f); }; + + __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + return operation == InMemoryDataOperationEnum::Set; + }; + + __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; } +}; + +template +struct Max +{ + using dataType = T; + + __host__ __device__ static constexpr T GetReductionZeroVal() + { + return NumericLimits::Lowest(); + }; + + __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + // ToChange: atomic_max to be added + return operation == InMemoryDataOperationEnum::Set; + }; + + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + if(a < b) + a = b; + } + + __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const + { + if(a < b) + { + a = b; + changed = true; + } + } +}; + +template +struct Min +{ + using dataType = T; + + __host__ __device__ static constexpr T GetReductionZeroVal() + { + return NumericLimits::Max(); + }; + + __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + // ToChange: atomic_min to be added + return operation == InMemoryDataOperationEnum::Set; + }; + + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + if(a > b) + a = b; + } + + __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const + { + if(a > b) + { + a = b; + changed = true; + } + } +}; + +template +struct AMax +{ + using dataType = T; + + __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast(0.0f); }; + + __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + // ToChange: atomic_max to be added + return operation == InMemoryDataOperationEnum::Set; + }; + + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + if(a < b) + a = b; + } + + __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const + { + if(a < b) + { + a = b; + changed = true; + } + } +}; + +template +T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operation) +{ + T result = ck::type_convert(0.0f); + + if(operation == InMemoryDataOperationEnum::AtomicMax) + result = ck::NumericLimits::Lowest(); + + return (result); +}; + +}; // end of namespace reduce + +} // end of namespace ck + +#endif diff --git a/composable_kernel/include/utility/sequence.hpp b/include/ck/utility/sequence.hpp similarity index 99% rename from composable_kernel/include/utility/sequence.hpp rename to include/ck/utility/sequence.hpp index b35999d56ff..c2adfc5063f 100644 --- a/composable_kernel/include/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -606,6 +606,12 @@ struct sequence_map_inverse SeqMap::Size()>::type; }; +template +__host__ __device__ constexpr bool operator==(Sequence, Sequence) +{ + return ((Xs == Ys) && ...); +} + template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { diff --git a/composable_kernel/include/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp similarity index 100% rename from composable_kernel/include/utility/sequence_helper.hpp rename to include/ck/utility/sequence_helper.hpp diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp new file mode 100644 index 00000000000..ef177e96976 --- /dev/null +++ b/include/ck/utility/static_buffer.hpp @@ -0,0 +1,173 @@ +#ifndef CK_STATIC_BUFFER_HPP +#define CK_STATIC_BUFFER_HPP + +#include "statically_indexed_array.hpp" + +namespace ck { + +// static buffer for scalar +template // TODO remove this bool, no longer needed +struct StaticBuffer : public StaticallyIndexedArray +{ + using type = T; + using base = StaticallyIndexedArray; + + __host__ __device__ constexpr StaticBuffer() : base{} {} + + __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } + + // read access + template + __host__ __device__ constexpr const T& operator[](Number i) const + { + return base::operator[](i); + } + + // write access + template + __host__ __device__ constexpr T& operator()(Number i) + { + return base::operator()(i); + } + + __host__ __device__ void Clear() + { + static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; }); + } +}; + +// static buffer for vector +template ::value, bool>::type = false> +struct StaticBufferTupleOfVector + : public StaticallyIndexedArray, NumOfVector> +{ + using V = typename vector_type::type; + using base = StaticallyIndexedArray, NumOfVector>; + + static constexpr auto s_per_v = Number{}; + static constexpr auto num_of_v_ = Number{}; + + __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {} + + __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } + + // Get S + // i is offset of S + template + __host__ __device__ constexpr const S& operator[](Number i) const + { + constexpr auto i_v = i / s_per_v; + constexpr auto i_s = i % s_per_v; + + return base::operator[](i_v).template AsType()[i_s]; + } + + // Set S + // i is offset of S + template + __host__ __device__ constexpr S& operator()(Number i) + { + constexpr auto i_v = i / s_per_v; + constexpr auto i_s = i % s_per_v; + + return base::operator()(i_v).template AsType()(i_s); + } + + // Get X + // i is offset of S, not X. i should be aligned to X + template ::value, bool>::type = false> + __host__ __device__ constexpr auto GetAsType(Number i) const + { + constexpr auto s_per_x = Number>::vector_size>{}; + + static_assert(s_per_v % s_per_x == 0, "wrong! V must one or multiple X"); + static_assert(i % s_per_x == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + constexpr auto i_x = (i % s_per_v) / s_per_x; + + return base::operator[](i_v).template AsType()[i_x]; + } + + // Set X + // i is offset of S, not X. i should be aligned to X + template ::value, bool>::type = false> + __host__ __device__ constexpr void SetAsType(Number i, X x) + { + constexpr auto s_per_x = Number>::vector_size>{}; + + static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X"); + static_assert(i % s_per_x == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + constexpr auto i_x = (i % s_per_v) / s_per_x; + + base::operator()(i_v).template AsType()(i_x) = x; + } + + // Get read access to vector_type V + // i is offset of S, not V. i should be aligned to V + template + __host__ __device__ constexpr const auto& GetVectorTypeReference(Number i) const + { + static_assert(i % s_per_v == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + + return base::operator[](i_v); + } + + // Get write access to vector_type V + // i is offset of S, not V. i should be aligned to V + template + __host__ __device__ constexpr auto& GetVectorTypeReference(Number i) + { + static_assert(i % s_per_v == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + + return base::operator()(i_v); + } + + __host__ __device__ void Clear() + { + constexpr index_t NumScalars = NumOfVector * ScalarPerVector; + + static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); }); + } +}; + +template +__host__ __device__ constexpr auto make_static_buffer(Number) +{ + return StaticBuffer{}; +} + +template +__host__ __device__ constexpr auto make_static_buffer(LongNumber) +{ + return StaticBuffer{}; +} + +} // namespace ck +#endif diff --git a/include/ck/utility/statically_indexed_array.hpp b/include/ck/utility/statically_indexed_array.hpp new file mode 100644 index 00000000000..526be2a07ac --- /dev/null +++ b/include/ck/utility/statically_indexed_array.hpp @@ -0,0 +1,102 @@ +#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP +#define CK_STATICALLY_INDEXED_ARRAY_HPP + +#include "functional2.hpp" +#include "sequence.hpp" +#include "tuple.hpp" + +namespace ck { + +namespace detail { +template +struct tuple_concat; + +template +struct tuple_concat, Tuple> +{ + using type = Tuple; +}; + +template +struct StaticallyIndexedArrayImpl +{ + using type = + typename tuple_concat::type, + typename StaticallyIndexedArrayImpl::type>::type; +}; + +template +struct StaticallyIndexedArrayImpl +{ + using type = Tuple<>; +}; + +template +struct StaticallyIndexedArrayImpl +{ + using type = Tuple; +}; +} // namespace detail + +template +using StaticallyIndexedArray = typename detail::StaticallyIndexedArrayImpl::type; + +template +__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs) +{ + return StaticallyIndexedArray(x, static_cast(xs)...); +} + +// make empty StaticallyIndexedArray +template +__host__ __device__ constexpr auto make_statically_indexed_array() +{ + return StaticallyIndexedArray(); +} + +template +struct StaticallyIndexedArray_v2 +{ + __host__ __device__ constexpr StaticallyIndexedArray_v2() = default; + + __host__ __device__ static constexpr index_t Size() { return N; } + + // read access + template + __host__ __device__ constexpr const auto& At(Number) const + { + static_assert(I < N, "wrong! out of range"); + + return data_[I]; + } + + // write access + template + __host__ __device__ constexpr auto& At(Number) + { + static_assert(I < N, "wrong! out of range"); + + return data_[I]; + } + + // read access + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + // write access + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return At(i); + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + T data_[N]; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/statically_indexed_array_multi_index.hpp b/include/ck/utility/statically_indexed_array_multi_index.hpp similarity index 95% rename from composable_kernel/include/utility/statically_indexed_array_multi_index.hpp rename to include/ck/utility/statically_indexed_array_multi_index.hpp index 9e96f06d737..e0ee9d04fdb 100644 --- a/composable_kernel/include/utility/statically_indexed_array_multi_index.hpp +++ b/include/ck/utility/statically_indexed_array_multi_index.hpp @@ -93,6 +93,13 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple& x) return r; } +// MultiIndex = MultiIndex * index_t +template +__host__ __device__ constexpr auto operator*(const Tuple& x, index_t a) +{ + return a * x; +} + template __host__ __device__ void print_multi_index(const Tuple& x) { diff --git a/composable_kernel/include/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp similarity index 84% rename from composable_kernel/include/utility/synchronization.hpp rename to include/ck/utility/synchronization.hpp index da74f2074db..d46628d9133 100644 --- a/composable_kernel/include/utility/synchronization.hpp +++ b/include/ck/utility/synchronization.hpp @@ -7,7 +7,7 @@ namespace ck { __device__ void block_sync_lds() { -#if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM +#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM asm volatile("\ s_waitcnt lgkmcnt(0) \n \ s_barrier \ diff --git a/include/ck/utility/tensor_space_filling_curve.hpp b/include/ck/utility/tensor_space_filling_curve.hpp new file mode 100644 index 00000000000..b5f1a34d837 --- /dev/null +++ b/include/ck/utility/tensor_space_filling_curve.hpp @@ -0,0 +1,159 @@ +#ifndef TENSOR_SPACE_FILLING_CURVE_HPP +#define TENSOR_SPACE_FILLING_CURVE_HPP + +#include "math.hpp" +#include "sequence.hpp" +#include "sequence_helper.hpp" +#include "tensor_adaptor.hpp" +#include "statically_indexed_array_multi_index.hpp" +#include "tuple_helper.hpp" + +namespace ck { + +template // # of scalars per access in each dimension +struct SpaceFillingCurve +{ + static constexpr index_t nDim = TensorLengths::Size(); + + using Index = MultiIndex; + + static constexpr index_t ScalarPerVector = + reduce_on_sequence(ScalarsPerAccess{}, math::multiplies{}, Number<1>{}); + + static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{}; + static constexpr auto dim_access_order = DimAccessOrder{}; + static constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(ordered_access_lengths)), + make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr index_t GetNumOfAccess() + { + static_assert(TensorLengths::Size() == ScalarsPerAccess::Size()); + static_assert(TensorLengths{} % ScalarsPerAccess{} == + typename uniform_sequence_gen::type{}); + + return reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) / + ScalarPerVector; + } + + template + static __device__ __host__ constexpr auto GetStepBetween(Number, + Number) + { + static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative"); + static_assert(AccessIdx1dBegin < GetNumOfAccess(), "1D index should be larger than 0"); + static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative"); + static_assert(AccessIdx1dEnd < GetNumOfAccess(), "1D index should be larger than 0"); + + constexpr auto idx_begin = GetIndex(Number{}); + constexpr auto idx_end = GetIndex(Number{}); + return idx_end - idx_begin; + } + + template + static __device__ __host__ constexpr auto GetForwardStep(Number) + { + static_assert(AccessIdx1d < GetNumOfAccess(), "1D index should be larger than 0"); + return GetStepBetween(Number{}, Number{}); + } + + template + static __device__ __host__ constexpr auto GetBackwardStep(Number) + { + static_assert(AccessIdx1d > 0, "1D index should be larger than 0"); + + return GetStepBetween(Number{}, Number{}); + } + + template + static __device__ __host__ constexpr Index GetIndex(Number) + { +#if 0 + /* + * \todo: TensorAdaptor::CalculateBottomIndex does NOT return constexpr as expected. + */ + constexpr auto ordered_access_idx = to_index_adaptor.CalculateBottomIndex(make_multi_index(Number{})); +#else + + constexpr auto access_strides = container_reverse_exclusive_scan( + ordered_access_lengths, math::multiplies{}, Number<1>{}); + + constexpr auto idx_1d = Number{}; + // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the + // idim-th element of multidimensional index. + // All constexpr variables have to be captured by VALUE. + constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr + { + constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr + { + auto res = idx_1d.value; + auto id = 0; + + static_for<0, jdim.value + 1, 1>{}([&](auto kdim) { + id = res / access_strides[kdim].value; + res -= id * access_strides[kdim].value; + }); + + return id; + }; + + constexpr auto id = compute_index_impl(idim); + return Number{}; + }; + + constexpr auto ordered_access_idx = generate_tuple(compute_index, Number{}); +#endif + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto idim) { + index_t tmp = ordered_access_idx[I0]; + + static_for<1, idim, 1>{}( + [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); + + forward_sweep_(idim) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate multi-dim tensor index + auto idx_md = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto idim) { + ordered_idx(idim) = forward_sweep[idim] ? ordered_access_idx[idim] + : ordered_access_lengths[idim] - 1 - + ordered_access_idx[idim]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + ScalarsPerAccess{}; + }(); + return idx_md; + } + + // FIXME: rename this function + template + static __device__ __host__ constexpr auto GetIndexTupleOfNumber(Number) + { + constexpr auto idx = GetIndex(Number{}); + + return generate_tuple([&](auto i) { return Number{}; }, Number{}); + } +}; + +} // namespace ck +#endif diff --git a/include/ck/utility/thread_group.hpp b/include/ck/utility/thread_group.hpp new file mode 100644 index 00000000000..bd3563c5f10 --- /dev/null +++ b/include/ck/utility/thread_group.hpp @@ -0,0 +1,18 @@ +#pragma once +#include "get_id.hpp" + +namespace ck { + +template +struct ThisThreadBlock +{ + static constexpr index_t kNumThread_ = ThreadPerBlock; + + __device__ static constexpr index_t GetNumOfThread() { return kNumThread_; } + + __device__ static constexpr bool IsBelong() { return true; } + + __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); } +}; + +} // namespace ck diff --git a/include/ck/utility/transpose_vectors.hpp b/include/ck/utility/transpose_vectors.hpp new file mode 100644 index 00000000000..31f9c02c74f --- /dev/null +++ b/include/ck/utility/transpose_vectors.hpp @@ -0,0 +1,168 @@ +#ifndef CK_TRANSPOSE_VECTORS_AMD_HPP +#define CK_TRANSPOSE_VECTORS_AMD_HPP + +#include "config.hpp" +#include "statically_indexed_array.hpp" +#include "data_type.hpp" + +namespace ck { + +template ::value, bool>::type = false> +struct transpose_vectors; + +// transpose fp16 2x2 +__device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1) +{ +#if 0 + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + const vector_type vx0{x0}, vx1{x1}; + vector_type vy0, vy1; + + vy0.template AsType()(I0) = vx0.template AsType()[I0]; + vy0.template AsType()(I1) = vx1.template AsType()[I0]; + + vy1.template AsType()(I0) = vx0.template AsType()[I1]; + vy1.template AsType()(I1) = vx1.template AsType()[I1]; + + y0 = vy0.template AsType()[I0]; + y1 = vy1.template AsType()[I0]; +#else + asm volatile("\n \ + v_pack_b32_f16 %0, %1, %2 \n \ + " + : "=v"(y0) + : "v"(x0), "v"(x1)); + + asm volatile("\n \ + v_pack_b32_f16 %0, %1, %2, op_sel:[1, 1] \n \ + " + : "=v"(y1) + : "v"(x0), "v"(x1)); +#endif +} + +template +struct transpose_vectors +{ + // we got [NY * NX] amount of S data to be transposed + static constexpr index_t s_per_x = NY; + static constexpr index_t s_per_y = NX; + + using S = half_t; + using VX = vector_type; + using VY = vector_type; + + __device__ void operator()(const StaticallyIndexedArray& vx_tuple, + StaticallyIndexedArray& vy_tuple) + { + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!"); + + // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 2>{}([&](auto iy) { + static_for<0, NX, 2>{}([&](auto ix) { + // reference to 2 half2_t data from vx_tuple + const auto& x_s2_0 = vx_tuple[ix].template AsType()[iy / I2]; + const auto& x_s2_1 = vx_tuple[ix + I1].template AsType()[iy / I2]; + + // reference to 2 half2_t data from vy_tuple + auto& y_s2_0 = vy_tuple(iy).template AsType()(ix / I2); + auto& y_s2_1 = vy_tuple(iy + I1).template AsType()(ix / I2); + + // transpose + transpose_fp16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1); + }); + }); + } +}; + +// transpose int8 4x4 +__device__ void transpose_int8_4x4(const int8x4_t& x0, + const int8x4_t& x1, + const int8x4_t& x2, + const int8x4_t& x3, + int8x4_t& y0, + int8x4_t& y1, + int8x4_t& y2, + int8x4_t& y3) +{ + int32_t t0, t1; + int32_t z0, z1, z2, z3; + constexpr int32_t m0 = 0x05010400; + constexpr int32_t m1 = 0x05040100; + constexpr int32_t m2 = 0x07060302; + constexpr int32_t m3 = 0x07030602; + + // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488 + // -- -- -- -- -- -- -- -- - - - - + // index 7 6 5 4 3 2 1 0 33 77 44 88 + // index is reversed because of little endianness (least significant bits first) + // clang-format off + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast(x1)), "v"(bit_cast(x0)), "s"(m0)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast(x3)), "v"(bit_cast(x2)), "s"(m0)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z0) : "v"(bit_cast(t1)), "v"(bit_cast(t0)), "s"(m1)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z1) : "v"(bit_cast(t1)), "v"(bit_cast(t0)), "s"(m2)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast(x1)), "v"(bit_cast(x0)), "s"(m3)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast(x3)), "v"(bit_cast(x2)), "s"(m3)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z2) : "v"(bit_cast(t1)), "v"(bit_cast(t0)), "s"(m1)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z3) : "v"(bit_cast(t1)), "v"(bit_cast(t0)), "s"(m2)); + // clang-format on + + y0 = bit_cast(z0); + y1 = bit_cast(z1); + y2 = bit_cast(z2); + y3 = bit_cast(z3); +} + +template +struct transpose_vectors +{ + // we got [NY * NX] amount of S data to be transposed + static constexpr index_t s_per_x = NY; + static constexpr index_t s_per_y = NX; + + using S = int8_t; + using VX = vector_type; + using VY = vector_type; + + __device__ void operator()(const StaticallyIndexedArray& vx_tuple, + StaticallyIndexedArray& vy_tuple) + { + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!"); + + // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 4>{}([&](auto iy) { + static_for<0, NX, 4>{}([&](auto ix) { + // reference to 4 int8 data from vx_tuple + const auto& x_s4_0 = vx_tuple[ix].template AsType()[iy / I4]; + const auto& x_s4_1 = vx_tuple[ix + I1].template AsType()[iy / I4]; + const auto& x_s4_2 = vx_tuple[ix + I2].template AsType()[iy / I4]; + const auto& x_s4_3 = vx_tuple[ix + I3].template AsType()[iy / I4]; + + // reference to 4 int8 data from vy_tuple + auto& y_s4_0 = vy_tuple(iy).template AsType()(ix / I4); + auto& y_s4_1 = vy_tuple(iy + I1).template AsType()(ix / I4); + auto& y_s4_2 = vy_tuple(iy + I2).template AsType()(ix / I4); + auto& y_s4_3 = vy_tuple(iy + I3).template AsType()(ix / I4); + + // transpose + transpose_int8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3); + }); + }); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/tuple.hpp b/include/ck/utility/tuple.hpp similarity index 89% rename from composable_kernel/include/utility/tuple.hpp rename to include/ck/utility/tuple.hpp index 70f4d77d874..766a78240bd 100644 --- a/composable_kernel/include/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -21,9 +21,9 @@ struct TupleElement { __host__ __device__ constexpr TupleElement() = default; - template >, TupleElement>::value, - bool>::type = false> + template < + typename T, + typename enable_if, TupleElement>::value, bool>::type = false> __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward(v)) { } @@ -60,7 +60,7 @@ struct TupleImpl, Xs...> : TupleElement, Xs> template >, TupleImpl>::value, + !is_same, TupleImpl>::value, bool>::type = false> __host__ __device__ constexpr TupleImpl(Y&& y) : TupleElement, Xs>(std::forward(y))... @@ -101,8 +101,7 @@ struct Tuple : detail::TupleImpl>, Tuple>::value, + typename enable_if, Tuple>::value, bool>::type = false> __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward(y)) { @@ -117,6 +116,7 @@ struct Tuple : detail::TupleImpl __host__ __device__ constexpr const auto& At(Number) const { @@ -124,6 +124,7 @@ struct Tuple : detail::TupleImpl{}); } + // write access template __host__ __device__ constexpr auto& At(Number) { @@ -131,12 +132,14 @@ struct Tuple : detail::TupleImpl{}); } + // read access template __host__ __device__ constexpr const auto& operator[](Number i) const { return At(i); } + // write access template __host__ __device__ constexpr auto& operator()(Number i) { @@ -162,5 +165,12 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs) return Tuple...>(std::forward(xs)...); } +// https://en.cppreference.com/w/cpp/utility/tuple/tie +template +constexpr Tuple tie(Args&... args) noexcept +{ + return {args...}; +} + } // namespace ck #endif diff --git a/composable_kernel/include/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp similarity index 81% rename from composable_kernel/include/utility/tuple_helper.hpp rename to include/ck/utility/tuple_helper.hpp index 55a79d2594e..4e5b9cf97c8 100644 --- a/composable_kernel/include/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -6,22 +6,6 @@ namespace ck { -template -struct is_known_at_compile_time> -{ - __host__ __device__ static constexpr bool IsKnownAtCompileTime() - { - return container_reduce( - Tuple{}, - [](auto x, bool r) { - return is_known_at_compile_time>::value & r; - }, - true); - } - - static constexpr bool value = IsKnownAtCompileTime(); -}; - template __host__ __device__ constexpr auto generate_tuple(F&& f, Number) { @@ -29,6 +13,13 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number) typename arithmetic_sequence_gen<0, N, 1>::type{}); } +template +__host__ __device__ constexpr auto generate_tie(F&& f, Number) +{ + return unpack([&f](auto&&... xs) { return tie(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + namespace detail { template diff --git a/composable_kernel/include/utility/type.hpp b/include/ck/utility/type.hpp similarity index 71% rename from composable_kernel/include/utility/type.hpp rename to include/ck/utility/type.hpp index 89a2bdbde63..ee3189ebe5f 100644 --- a/composable_kernel/include/utility/type.hpp +++ b/include/ck/utility/type.hpp @@ -1,6 +1,7 @@ #ifndef CK_TYPE_HPP #define CK_TYPE_HPP +#include "config.hpp" #include "integral_constant.hpp" #include "enable_if.hpp" @@ -16,6 +17,9 @@ struct is_same : public integral_constant { }; +template +inline constexpr bool is_same_v = is_same::value; + template using remove_reference_t = typename std::remove_reference::type; @@ -26,26 +30,21 @@ template using remove_cvref_t = remove_cv_t>; template -inline constexpr bool is_pointer_v = std::is_pointer::value; +using remove_pointer_t = typename std::remove_pointer::type; template -struct is_known_at_compile_time; +inline constexpr bool is_pointer_v = std::is_pointer::value; -template <> -struct is_known_at_compile_time +template ::type = false> +__host__ __device__ constexpr Y bit_cast(const X& x) { - static constexpr bool value = false; -}; +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST + Y y; -template -struct is_known_at_compile_time> -{ - static constexpr bool value = true; -}; + __builtin_memcpy(&y, &x, sizeof(X)); -template ::type = false> -__host__ __device__ constexpr Y as_type(X x) -{ + return y; +#else union AsType { X x; @@ -53,6 +52,7 @@ __host__ __device__ constexpr Y as_type(X x) }; return AsType{x}.y; +#endif } } // namespace ck diff --git a/library/CMakeLists.txt b/library/CMakeLists.txt new file mode 100644 index 00000000000..aa18026932b --- /dev/null +++ b/library/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(src/host_tensor) +add_subdirectory(src/tensor_operation_instance/gpu) +add_subdirectory(src/utility) diff --git a/library/include/ck/library/host/host_interface.hpp b/library/include/ck/library/host/host_interface.hpp new file mode 100644 index 00000000000..955da0f4bee --- /dev/null +++ b/library/include/ck/library/host/host_interface.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +#include "stream_config.hpp" +#include "config.hpp" +#include "device_base.hpp" + +struct DeviceConvFwdPtr_t +{ + using BaseArgument = ck::tensor_operation::device::BaseArgument; + using BaseInvoker = ck::tensor_operation::device::BaseInvoker; + + struct DeviceConvFwdPtrImpl; + std::unique_ptr pImpl; + DeviceConvFwdPtr_t(); + ~DeviceConvFwdPtr_t(); + DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&); + DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&); + DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&) = delete; + DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&) = delete; + std::unique_ptr + MakeArgumentPointer(void* in_ptr, + void* wei_ptr, + void* out_ptr, + size_t N, + size_t K, + size_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + const; // in,wei and out element ops are ignored for now since even if we change them, they + // cant be linked + std::unique_ptr + MakeInvokerPointer() const; // requires including BaseInvoker headers + std::string GetTypeString(); + bool IsSupportedArgument(const BaseArgument* arg_ptr); +}; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t( + std::vector& instances); +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t( + std::vector& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t( + std::vector& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t( + std::vector& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t( + std::vector& instances); diff --git a/host/host_tensor/include/conv_common.hpp b/library/include/ck/library/host_tensor/conv_common.hpp similarity index 86% rename from host/host_tensor/include/conv_common.hpp rename to library/include/ck/library/host_tensor/conv_common.hpp index 4bf2c234941..b60af7d664f 100644 --- a/host/host_tensor/include/conv_common.hpp +++ b/library/include/ck/library/host_tensor/conv_common.hpp @@ -3,15 +3,6 @@ #include "tensor_descriptor.hpp" -enum ConvTensorLayout -{ - NCHW, - NHWC, - CHWN, - NCHWc, - NHWCc -}; - template +inline auto activ(T v, const ck::ActivTypeEnum activ_type) +{ + const T alpha = 0.3; + switch(activ_type) + { + case ck::ActivTypeEnum::None: return v; + case ck::ActivTypeEnum::LeakyRelu: return (v >= 0 ? v : alpha * v); + case ck::ActivTypeEnum::Sigmoid: return (1 / (1 + exp(-v))); + default: throw std::runtime_error("unsupported activ type"); break; + } +} + #endif diff --git a/library/include/ck/library/host_tensor/device.hpp b/library/include/ck/library/host_tensor/device.hpp new file mode 100644 index 00000000000..990d2f98b37 --- /dev/null +++ b/library/include/ck/library/host_tensor/device.hpp @@ -0,0 +1,123 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "stream_config.hpp" +#include "ck/options.hpp" + +template +__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) +{ + for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x) + { + p[i] = x; + } +} + +inline void hip_check_error(hipError_t x) +{ + if(x != hipSuccess) + { + std::ostringstream ss; + ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__ + << "in function: " << __func__; + throw std::runtime_error(ss.str()); + } +} + +struct DeviceMem +{ + DeviceMem() = delete; + DeviceMem(std::size_t mem_size); + void* GetDeviceBuffer(); + std::size_t GetBufferSize(); + void ToDevice(const void* p); + void FromDevice(void* p); + void SetZero(); + template + void SetValue(T x) + { + if(mMemSize % sizeof(T) != 0) + { + throw std::runtime_error("wrong! not entire DeviceMem will be set"); + } + + set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + } + ~DeviceMem(); + + void* mpDeviceBuf; + std::size_t mMemSize; +}; + +struct KernelTimerImpl; + +struct KernelTimer +{ + KernelTimer(); + ~KernelTimer(); + void Start(); + void End(); + float GetElapsedTime() const; + + std::unique_ptr impl; +}; + +template +float launch_and_time_kernel(const StreamConfig& stream_config, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + Args... args) +{ +#if CK_TIME_KERNEL + if(stream_config.time_kernel_) + { + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + const int nrepeat = 10; + + printf("Warm up 1 time\n"); + + // warm up + kernel<<>>(args...); + + printf("Start running %d times...\n", nrepeat); + + KernelTimer timer; + timer.Start(); + + for(int i = 0; i < nrepeat; ++i) + { + kernel<<>>(args...); + } + + timer.End(); + + return timer.GetElapsedTime() / nrepeat; + } + else + { + kernel<<>>(args...); + + return 0; + } +#else + kernel<<>>(args...); + + return 0; +#endif +} diff --git a/host/host_tensor/include/device_tensor.hpp b/library/include/ck/library/host_tensor/device_tensor.hpp similarity index 88% rename from host/host_tensor/include/device_tensor.hpp rename to library/include/ck/library/host_tensor/device_tensor.hpp index 1a7a34a4cf3..b8d3ccc8a0b 100644 --- a/host/host_tensor/include/device_tensor.hpp +++ b/library/include/ck/library/host_tensor/device_tensor.hpp @@ -1,6 +1,5 @@ #pragma once #include "host_tensor.hpp" -#include "common_header.hpp" template void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout) diff --git a/library/include/ck/library/host_tensor/host_common_util.hpp b/library/include/ck/library/host_tensor/host_common_util.hpp new file mode 100644 index 00000000000..8fc1d364304 --- /dev/null +++ b/library/include/ck/library/host_tensor/host_common_util.hpp @@ -0,0 +1,102 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef GUARD_HOST_COMMON_UTIL_HPP +#define GUARD_HOST_COMMON_UTIL_HPP + +#include +#include +#include +#include + +#include "config.hpp" + +namespace ck { + +namespace host_common { + +template +static inline void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems) +{ + std::ofstream outFile(fileName, std::ios::binary); + if(outFile) + { + outFile.write(reinterpret_cast(data), dataNumItems * sizeof(T)); + outFile.close(); + std::cout << "Write output to file " << fileName << std::endl; + } + else + { + std::cout << "Could not open file " << fileName << " for writing" << std::endl; + } +}; + +template +static inline T getSingleValueFromString(const std::string& valueStr) +{ + std::istringstream iss(valueStr); + + T val; + + iss >> val; + + return (val); +}; + +template +static inline std::vector getTypeValuesFromString(const char* cstr_values) +{ + std::string valuesStr(cstr_values); + + std::vector values; + std::size_t pos = 0; + std::size_t new_pos; + + new_pos = valuesStr.find(',', pos); + while(new_pos != std::string::npos) + { + const std::string sliceStr = valuesStr.substr(pos, new_pos - pos); + + T val = getSingleValueFromString(sliceStr); + + values.push_back(val); + + pos = new_pos + 1; + new_pos = valuesStr.find(',', pos); + }; + + std::string sliceStr = valuesStr.substr(pos); + T val = getSingleValueFromString(sliceStr); + + values.push_back(val); + + return (values); +} + +}; // namespace host_common + +}; // namespace ck + +#endif diff --git a/library/include/ck/library/host_tensor/host_conv.hpp b/library/include/ck/library/host_tensor/host_conv.hpp new file mode 100644 index 00000000000..3d2588c08b4 --- /dev/null +++ b/library/include/ck/library/host_tensor/host_conv.hpp @@ -0,0 +1,149 @@ +#pragma once +#include "host_tensor.hpp" +#include "conv_common.hpp" + +template +void host_conv_nchw_kcyx_nkhw(const Tensor& in, + const Tensor& wei, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&) +{ + constexpr auto I0 = ck::Number<0>{}; + constexpr auto I1 = ck::Number<1>{}; + + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + v += ck::type_convert(in(n, c, hi, wi)) * + ck::type_convert(wei(k, c, y, x)); + } + } + } + } + out(n, k, ho, wo) = ck::type_convert(v); + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); +} + +template +void host_conv3d_ndhwc_kzyxc_ndhwk(const Tensor& in, + const Tensor& wei, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + const auto Di = in.mDesc.GetLengths()[1]; + const auto Hi = in.mDesc.GetLengths()[2]; + const auto Wi = in.mDesc.GetLengths()[3]; + const auto Z = wei.mDesc.GetLengths()[1]; + const auto Y = wei.mDesc.GetLengths()[2]; + const auto X = wei.mDesc.GetLengths()[3]; + const auto C = wei.mDesc.GetLengths()[4]; + + auto f_ndhwc = [&](auto n, auto do_tmp, auto ho_tmp, auto wo_tmp, auto k) { + // do__ must be converted to signed integer, otherwise zmin might be wrong in cases + // negative values. + const int do_ = static_cast(do_tmp); + const int ho = static_cast(ho_tmp); + const int wo = static_cast(wo_tmp); + const int zmin = + std::max(0, + (in_left_pads[I0] - do_ * conv_strides[I0] + conv_dilations[I0] - 1) / + conv_dilations[I0]); + const int ymin = + std::max(0, + (in_left_pads[I1] - ho * conv_strides[I1] + conv_dilations[I1] - 1) / + conv_dilations[I1]); + const int xmin = + std::max(0, + (in_left_pads[I2] - wo * conv_strides[I2] + conv_dilations[I2] - 1) / + conv_dilations[I2]); + const int zmax = + std::min(Z, (in_left_pads[I0] - do_ * conv_strides[I0] + Di) / conv_dilations[I0]); + const int ymax = + std::min(Y, (in_left_pads[I1] - ho * conv_strides[I1] + Hi) / conv_dilations[I1]); + const int xmax = + std::min(X, (in_left_pads[I2] - wo * conv_strides[I2] + Wi) / conv_dilations[I2]); + const int di_min = do_ * conv_strides[I0] + zmin * conv_dilations[I0] - in_left_pads[I0]; + const int hi_min = ho * conv_strides[I1] + ymin * conv_dilations[I1] - in_left_pads[I1]; + const int wi_min = wo * conv_strides[I2] + xmin * conv_dilations[I2] - in_left_pads[I2]; + + double v = 0; + + const TIn* in_n = in.mData.data() + n * Di * Hi * Wi * C; + const TWei* wei_k = wei.mData.data() + k * Z * Y * X * C; + + int di = di_min; + for(int z = zmin; z < zmax; ++z, di += conv_dilations[I0]) + { + const TIn* in_n_di = in_n + di * Hi * Wi * C; + const TWei* wei_k_z = wei_k + z * Y * X * C; + int hi = hi_min; + + for(int y = ymin; y < ymax; ++y, hi += conv_dilations[I1]) + { + const TIn* in_n_di_hi = in_n_di + hi * Wi * C; + const TWei* wei_k_z_y = wei_k_z + y * X * C; + int wi = wi_min; + + for(int x = xmin; x < xmax; ++x, wi += conv_dilations[I2]) + { + const TIn* in_n_di_hi_wi = in_n_di_hi + wi * C; + const TWei* wei_k_z_y_x = wei_k_z_y + x * C; + + for(int c = 0; c < C; ++c) + { + v += static_cast(in_n_di_hi_wi[c]) * + static_cast(wei_k_z_y_x[c]); + } + } + } + } + + out(n, do_, ho, wo, k) = v; + }; + + make_ParallelTensorFunctor(f_ndhwc, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3], + out.mDesc.GetLengths()[4])(std::thread::hardware_concurrency() - 4); +} diff --git a/library/include/ck/library/host_tensor/host_gemm.hpp b/library/include/ck/library/host_tensor/host_gemm.hpp new file mode 100644 index 00000000000..211c01c01a7 --- /dev/null +++ b/library/include/ck/library/host_tensor/host_gemm.hpp @@ -0,0 +1,43 @@ +#pragma once +#include "host_tensor.hpp" + +template +void host_gemm_mk_kn_mn(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op) +{ + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = a_m_k.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + a_element_op(v_a, static_cast(a_m_k(m, k))); + b_element_op(v_b, static_cast(b_k_n(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + c_element_op(v_c, v_acc); + + c_m_n(m, n) = v_c; + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, + c_m_n.mDesc.GetLengths()[0], + c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency()); +} diff --git a/library/include/ck/library/host_tensor/host_reduce_util.hpp b/library/include/ck/library/host_tensor/host_reduce_util.hpp new file mode 100644 index 00000000000..095bb034263 --- /dev/null +++ b/library/include/ck/library/host_tensor/host_reduce_util.hpp @@ -0,0 +1,257 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef GUARD_HOST_REDUCE_UTIL_HPP +#define GUARD_HOST_REDUCE_UTIL_HPP + +#include +#include +#include + +#include "reduction_enums.hpp" +#include "data_type.hpp" +#include "math_v2.hpp" + +namespace ck { + +namespace host_reduce { + +using ck::NanPropagation; +using ck::ReduceTensorOp; + +template +__host__ static inline std::function PreUnaryOpFn(int) +{ + using ck::math::abs; + + if constexpr(ReduceOpId == ReduceTensorOp::NORM1) + { + return ([&](AccDataType& a_) { a_ = abs(a_); }); + } + else if constexpr(ReduceOpId == ReduceTensorOp::NORM2) + { + return ([&](AccDataType& a_) { a_ = a_ * a_; }); + } + else if constexpr(ReduceOpId == ReduceTensorOp::AMAX) + { + return ([&](AccDataType& a_) { a_ = abs(a_); }); + } + else + { + // ReduceTensorOp::AVG: + // ReduceTensorOp::ADD: + // ReduceTensorOp::MUL: + // ReduceTensorOp::MIN: + // ReduceTensorOp::MAX: + return ([&](AccDataType&) {}); + }; +}; + +template +__host__ static inline std::function PosUnaryOpFn(int32_t divider) +{ + using std::sqrt; + + if constexpr(ReduceOpId == ReduceTensorOp::NORM2) + { + return ([&](AccDataType& a_) { a_ = sqrt(a_); }); + } + else if constexpr(ReduceOpId == ReduceTensorOp::AVG) + { + return ([&, divider](AccDataType& a_) { + a_ = a_ / static_cast(static_cast(divider)); + }); + } + else + { + // ReduceTensorOp::ADD: + // ReduceTensorOp::NORM1: + // ReduceTensorOp::MUL: + // ReduceTensorOp::MIN: + // ReduceTensorOp::MAX: + // ReduceTensorOp::AMAX: + return ([&](AccDataType&) {}); + } +}; + +template +__host__ static inline std::function ReduceOpFn() +{ + if constexpr(ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::AVG || + ReduceOpId == ReduceTensorOp::NORM1 || ReduceOpId == ReduceTensorOp::NORM2) + { + return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ + b_; }); + } + else if constexpr(ReduceOpId == ReduceTensorOp::MUL) + { + return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ * b_; }); + } + else if constexpr(ReduceOpId == ReduceTensorOp::MIN) + { + return ([&](AccDataType& a_, AccDataType b_) { + if(a_ > b_) + a_ = b_; + }); + } + else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX) + { + return ([&](AccDataType& a_, AccDataType b_) { + if(a_ < b_) + a_ = b_; + }); + } +}; + +template +__host__ static inline std::function ReduceOpFn2() +{ + if constexpr(ReduceOpId == ReduceTensorOp::MIN) + { + return ([&](AccDataType& a_, AccDataType b_, bool& changed) { + if(a_ > b_) + { + a_ = b_; + changed = true; + } + else + changed = false; + }); + } + else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX) + { + return ([&](AccDataType& a_, AccDataType b_, bool& changed) { + if(a_ < b_) + { + a_ = b_; + changed = true; + } + else + changed = false; + }); + } + else + { + // ReduceTensorOp::ADD: + // ReduceTensorOp::MUL: + // ReduceTensorOp::AVG: + // ReduceTensorOp::NORM1: + // ReduceTensorOp::NORM2: + return (std::function{}); + }; +}; + +template +__host__ static inline AccDataType ReduceOpZeroVal() +{ + if constexpr(ReduceOpId == ReduceTensorOp::MUL) + { + return (static_cast(1.0f)); + } + else if constexpr(ReduceOpId == ReduceTensorOp::MIN) + { + return (ck::NumericLimits::Max()); + } + else if constexpr(ReduceOpId == ReduceTensorOp::MAX) + { + return (ck::NumericLimits::Lowest()); + } + else if constexpr(ReduceOpId == ReduceTensorOp::AMAX) + { + return (static_cast(0.0f)); + } + else + { + // ReduceTensorOp::ADD + // ReduceTensorOp::AVG + // ReduceTensorOp::NORM1 + // ReduceTensorOp::NORM2 + return (static_cast(0.0f)); + }; +}; + +template +__host__ static inline void +binop_with_nan_check(std::function opReduce, + AccDataType& accuVal, + AccDataType currVal) +{ + using ck::math::isnan; + + if constexpr(!PropagateNan) + { + opReduce(accuVal, currVal); + } + else + { + if(isnan(currVal)) + accuVal = currVal; + else + opReduce(accuVal, currVal); + }; +}; + +template +__host__ static inline void +binop_with_index_and_nan_check(std::function opReduce, + AccDataType& accuVal, + AccDataType currVal, + IndexDataType& accuIndex, + IndexDataType currIndex) +{ + using ck::math::isnan; + + if constexpr(!PropagateNan) + { + bool changed; + + opReduce(accuVal, currVal, changed); + + if(changed) + accuIndex = currIndex; + } + else + { + if(isnan(currVal)) + { + accuVal = currVal; + accuIndex = currIndex; + } + else + { + bool changed; + + opReduce(accuVal, currVal, changed); + + if(changed) + accuIndex = currIndex; + }; + }; +}; + +}; // namespace host_reduce + +}; // namespace ck + +#endif diff --git a/library/include/ck/library/host_tensor/host_reduction.hpp b/library/include/ck/library/host_tensor/host_reduction.hpp new file mode 100644 index 00000000000..1add62d1b5f --- /dev/null +++ b/library/include/ck/library/host_tensor/host_reduction.hpp @@ -0,0 +1,404 @@ + +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef HOST_REDUCTION_HPP_ +#define HOST_REDUCTION_HPP_ + +#include +#include +#include + +#include "reduction_enums.hpp" +#include "reduction_common.hpp" +#include "host_reduce_util.hpp" +#include "host_common_util.hpp" +#include "host_tensor.hpp" +#include "data_type.hpp" + +template +static void get_all_indexes(const std::array& dimLengths, + std::vector>& indexes) +{ + static_assert(NDim >= 1, "NDim >= 1 is required to use this function!"); + + if constexpr(NDim == 1) + { + for(size_t i = 0; i < dimLengths[0]; i++) + { + std::array index{i}; + + indexes.push_back(index); + }; + } + else + { + std::array partial_dim_lengths; + + for(int i = 0; i < NDim - 1; i++) + partial_dim_lengths[i] = dimLengths[i + 1]; + + std::vector> partial_indexes; + + get_all_indexes(partial_dim_lengths, partial_indexes); + + for(size_t i = 0; i < dimLengths[0]; i++) + for(const auto& index : partial_indexes) + { + std::array extIndex; + + extIndex[0] = i; + + for(int k = 0; k < NDim - 1; k++) + extIndex[k + 1] = index[k]; + + indexes.push_back(extIndex); + }; + }; +}; + +template +static size_t get_offset_from_index(const std::array& strides, + const std::array& index) +{ + size_t offset = 0; + + for(int i = 0; i < NDim; i++) + offset += strides[i] * index[i]; + + return (offset); +}; + +template +static size_t get_offset_from_index(const std::vector& strides, + const std::array& index) +{ + size_t offset = 0; + + for(int i = 0; i < NDim; i++) + offset += strides[i] * index[i]; + + return (offset); +}; + +template +struct ReductionHost +{ + using IndexDataType = int32_t; + + static constexpr int NumInvariantDim = Rank - NumReduceDim; + + std::vector outStrides; + std::vector invariantDims; + std::vector reduceDims; + + IndexDataType divider; + std::function preUnaryOp; + std::function posUnaryOp; + std::array reduceLengths; + std::array reduceStrides; + std::array invariantLengths; + std::array invariantStrides; + + std::vector> reduce_dim_indexes; + std::vector> invariant_dim_indexes; + + ReductionHost(HostTensorDescriptor& inDesc, + HostTensorDescriptor& outDesc, + const std::vector& invariantDims_, + const std::vector& reduceDims_) + { + using ck::host_reduce::PosUnaryOpFn; + using ck::host_reduce::PreUnaryOpFn; + + // this->outLengths = to_int_vector(outDesc.GetLengths()); + this->outStrides = outDesc.GetStrides(); + + this->invariantDims = invariantDims_; + this->reduceDims = reduceDims_; + + int product = 1; + + for(int i = 0; i < NumReduceDim; i++) + { + reduceLengths[i] = inDesc.GetLengths()[reduceDims[i]]; + reduceStrides[i] = inDesc.GetStrides()[reduceDims[i]]; + product *= inDesc.GetLengths()[reduceDims[i]]; + }; + + divider = product; + + for(int i = 0; i < NumInvariantDim; i++) + { + invariantLengths[i] = inDesc.GetLengths()[invariantDims[i]]; + invariantStrides[i] = inDesc.GetStrides()[invariantDims[i]]; + }; + + reduce_dim_indexes.clear(); + get_all_indexes(reduceLengths, reduce_dim_indexes); + + if constexpr(NumInvariantDim > 0) + { + invariant_dim_indexes.clear(); + get_all_indexes(invariantLengths, invariant_dim_indexes); + }; + + preUnaryOp = PreUnaryOpFn(divider); + posUnaryOp = PosUnaryOpFn(divider); + }; + + void Run(float alpha, + const InDataType* in_data, + float beta, + OutDataType* out_data, + IndexDataType* out_indices) + { + if constexpr(NeedIndices) + { + RunImpl_with_index(alpha, in_data, beta, out_data, out_indices); + } + else + { + RunImpl_no_index(alpha, in_data, beta, out_data); + }; + }; + + void RunImpl_with_index(float alpha, + const InDataType* in_data, + float beta, + OutDataType* out_data, + IndexDataType* out_indices) + { + using ck::float_equal_one; + using ck::float_equal_zero; + using ck::type_convert; + using ck::host_reduce::binop_with_index_and_nan_check; + using ck::host_reduce::ReduceOpFn2; + using ck::host_reduce::ReduceOpZeroVal; + + auto opReduce2 = ReduceOpFn2(); + + if constexpr(NumInvariantDim == 0) + { + AccDataType accuVal = ReduceOpZeroVal(); + IndexDataType accuIndex = 0; + + for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++) + { + auto offset_reduce = + get_offset_from_index(reduceStrides, reduce_dim_indexes[i]); + + auto currVal = type_convert(in_data[offset_reduce]); + + preUnaryOp(currVal); + + auto currIndex = static_cast(i); + + binop_with_index_and_nan_check( + opReduce2, accuVal, currVal, accuIndex, currIndex); + }; + + posUnaryOp(accuVal); + + if(!float_equal_one{}(alpha)) + accuVal *= type_convert(alpha); + + if(!float_equal_zero{}(beta)) + accuVal += type_convert(out_data[0]) * type_convert(beta); + + out_data[0] = type_convert(accuVal); + out_indices[0] = accuIndex; + } + else + { + auto thread_reduce_func = [&](auto invariant_index) { + AccDataType accuVal = ReduceOpZeroVal(); + IndexDataType accuIndex = 0; + + auto offset_invariant = + get_offset_from_index(invariantStrides, invariant_index); + + for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++) + { + auto offset_reduce = + get_offset_from_index(reduceStrides, reduce_dim_indexes[i]); + + auto currVal = + type_convert(in_data[offset_invariant + offset_reduce]); + + preUnaryOp(currVal); + + auto currIndex = static_cast(i); + + binop_with_index_and_nan_check( + opReduce2, accuVal, currVal, accuIndex, currIndex); + }; + + posUnaryOp(accuVal); + + if(!float_equal_one{}(alpha)) + accuVal *= type_convert(alpha); + + auto dst_offset = + get_offset_from_index(outStrides, invariant_index); + + if(!float_equal_zero{}(beta)) + accuVal += type_convert(out_data[dst_offset]) * + type_convert(beta); + + out_data[dst_offset] = type_convert(accuVal); + out_indices[dst_offset] = accuIndex; + }; + + std::size_t num_thread = 1; + std::size_t work_per_thread = + (invariant_dim_indexes.size() + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * work_per_thread; + std::size_t iw_end = + std::min((it + 1) * work_per_thread, invariant_dim_indexes.size()); + + auto f = [=] { + for(std::size_t iw = iw_begin; iw < iw_end; ++iw) + { + thread_reduce_func(invariant_dim_indexes[iw]); + } + }; + + threads[it] = joinable_thread(f); + } + }; + }; + + void RunImpl_no_index(float alpha, const InDataType* in_data, float beta, OutDataType* out_data) + { + using ck::float_equal_one; + using ck::float_equal_zero; + using ck::type_convert; + using ck::host_reduce::binop_with_nan_check; + using ck::host_reduce::ReduceOpFn; + using ck::host_reduce::ReduceOpZeroVal; + + auto opReduce = ReduceOpFn(); + + if constexpr(NumInvariantDim == 0) + { + AccDataType accuVal = ReduceOpZeroVal(); + + for(const auto& reduce_index : reduce_dim_indexes) + { + auto offset_reduce = + get_offset_from_index(reduceStrides, reduce_index); + + auto currVal = type_convert(in_data[offset_reduce]); + + preUnaryOp(currVal); + + binop_with_nan_check(opReduce, accuVal, currVal); + }; + + posUnaryOp(accuVal); + + if(!float_equal_one{}(alpha)) + accuVal *= type_convert(alpha); + + if(!float_equal_zero{}(beta)) + accuVal += type_convert(out_data[0]) * type_convert(beta); + + out_data[0] = type_convert(accuVal); + } + else + { + auto thread_reduce_func = [&](auto invariant_index) { + AccDataType accuVal = ReduceOpZeroVal(); + + auto offset_invariant = + get_offset_from_index(invariantStrides, invariant_index); + + for(const auto& reduce_index : reduce_dim_indexes) + { + auto offset_reduce = + get_offset_from_index(reduceStrides, reduce_index); + + auto currVal = + type_convert(in_data[offset_invariant + offset_reduce]); + + preUnaryOp(currVal); + + binop_with_nan_check(opReduce, accuVal, currVal); + }; + + posUnaryOp(accuVal); + + if(!float_equal_one{}(alpha)) + accuVal *= type_convert(alpha); + + auto dst_offset = + get_offset_from_index(outStrides, invariant_index); + + if(!float_equal_zero{}(beta)) + accuVal += type_convert(out_data[dst_offset]) * + type_convert(beta); + + out_data[dst_offset] = type_convert(accuVal); + }; + + std::size_t num_thread = 1; + std::size_t work_per_thread = + (invariant_dim_indexes.size() + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * work_per_thread; + std::size_t iw_end = + std::min((it + 1) * work_per_thread, invariant_dim_indexes.size()); + + auto f = [=] { + for(std::size_t iw = iw_begin; iw < iw_end; ++iw) + { + thread_reduce_func(invariant_dim_indexes[iw]); + } + }; + + threads[it] = joinable_thread(f); + } + }; + }; +}; + +#endif diff --git a/host/host_tensor/include/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp similarity index 73% rename from host/host_tensor/include/host_tensor.hpp rename to library/include/ck/library/host_tensor/host_tensor.hpp index 06aed0a0c11..ad6aeecb505 100644 --- a/host/host_tensor/include/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -8,6 +8,7 @@ #include #include #include +#include "data_type.hpp" template std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) @@ -39,20 +40,6 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) return os; } -typedef enum -{ - Half = 0, - Float = 1, -} DataType_t; - -template -struct DataType; - -template <> -struct DataType : std::integral_constant -{ -}; - template auto call_f_unpack_args_impl(F f, T args, std::index_sequence) { @@ -86,10 +73,10 @@ struct HostTensorDescriptor HostTensorDescriptor() = delete; template - HostTensorDescriptor(std::vector lens); + HostTensorDescriptor(const std::vector& lens); template - HostTensorDescriptor(std::vector lens, std::vector strides); + HostTensorDescriptor(const std::vector& lens, const std::vector& strides); void CalculateStrides(); @@ -120,6 +107,8 @@ struct HostTensorDescriptor return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } + friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); + private: std::vector mLens; std::vector mStrides; @@ -165,7 +154,7 @@ struct ParallelTensorFunctor { std::array indices; - for(int idim = 0; idim < NDIM; ++idim) + for(std::size_t idim = 0; idim < NDIM; ++idim) { indices[idim] = i / mStrides[idim]; i -= indices[idim] * mStrides[idim]; @@ -174,7 +163,7 @@ struct ParallelTensorFunctor return indices; } - void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const + void operator()(std::size_t num_thread = 1) const { std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread; @@ -255,6 +244,18 @@ struct Tensor mDesc.GetLengths()[3])(num_thread); break; } + case 5: { + auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) { + (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4); + }; + make_ParallelTensorFunctor(f, + mDesc.GetLengths()[0], + mDesc.GetLengths()[1], + mDesc.GetLengths()[2], + mDesc.GetLengths()[3], + mDesc.GetLengths()[4])(num_thread); + break; + } default: throw std::runtime_error("unspported dimension"); } } @@ -284,39 +285,69 @@ struct Tensor }; template -HostTensorDescriptor::HostTensorDescriptor(std::vector lens) : mLens(lens) +HostTensorDescriptor::HostTensorDescriptor(const std::vector& lens) : mLens(lens) { this->CalculateStrides(); } template -HostTensorDescriptor::HostTensorDescriptor(std::vector lens, std::vector strides) +HostTensorDescriptor::HostTensorDescriptor(const std::vector& lens, + const std::vector& strides) : mLens(lens), mStrides(strides) { } void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout); +#if 1 +// FIXME: remove +void bf16_to_f32_(const Tensor& src, Tensor& dst); +#endif + template -void check_error(const Tensor& ref, const Tensor& result) +float check_error(const Tensor& ref, const Tensor& result) { - float error = 0; - float max_diff = -1; - float ref_value = 0, result_value = 0; - for(int i = 0; i < ref.mData.size(); ++i) + float l1_error = 0; + float linf_error = -1; + float linf_rel_error = -1; + + float linf_ref_value = 0, linf_result_value = 0; + float linf_rel_ref_value = 0, linf_rel_result_value = 0; + + constexpr float eps = 1e-10; + + for(std::size_t i = 0; i < ref.mData.size(); ++i) { - error += std::abs(double(ref.mData[i]) - double(result.mData[i])); - float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); - if(max_diff < diff) + float ref_v = ck::type_convert(ref.mData[i]); + float result_v = ck::type_convert(result.mData[i]); + + float diff = std::abs(ref_v - result_v); + float rel_diff = diff / std::max(std::abs(ref_v), eps); + + l1_error += diff; + + if(linf_error < diff) + { + linf_error = diff; + linf_ref_value = ref_v; + linf_result_value = result_v; + } + + if(linf_rel_error < rel_diff) { - max_diff = diff; - ref_value = ref.mData[i]; - result_value = result.mData[i]; + linf_rel_error = rel_diff; + linf_rel_ref_value = ref_v; + linf_rel_result_value = result_v; } } - std::cout << "error: " << error << std::endl; - std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl; + std::cout << "Absolute Error L1 Norm (sum of abs diff): " << l1_error << std::endl; + std::cout << "Absolute Error L-inf Norm (max abs diff): " << linf_error << ", ref " + << linf_ref_value << ", result " << linf_result_value << std::endl; + std::cout << "Relative Error L-inf Norm (max relative abs diff): " << linf_rel_error << ", ref " + << linf_rel_ref_value << ", result " << linf_rel_result_value << std::endl; + + return linf_error; } #endif diff --git a/library/include/ck/library/host_tensor/host_tensor_generator.hpp b/library/include/ck/library/host_tensor/host_tensor_generator.hpp new file mode 100644 index 00000000000..17e20351f04 --- /dev/null +++ b/library/include/ck/library/host_tensor/host_tensor_generator.hpp @@ -0,0 +1,150 @@ +#pragma once + +#include +#include + +#include "config.hpp" + +template +struct GeneratorTensor_0 +{ + template + T operator()(Is...) + { + return T{0}; + } +}; + +template +struct GeneratorTensor_1 +{ + int value = 1; + + template + T operator()(Is...) + { + return ck::type_convert(value); + } +}; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bhalf_t operator()(Is...) + { + return ck::type_convert(value); + } +}; + +template <> +struct GeneratorTensor_1 +{ + int8_t value = 1; + + template + int8_t operator()(Is...) + { + return value; + } +}; + +template +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + T operator()(Is...) + { + return static_cast((std::rand() % (max_value - min_value)) + min_value); + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::bhalf_t operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + int8_t operator()(Is...) + { + return (std::rand() % (max_value - min_value)) + min_value; + } +}; + +template +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + T operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + return static_cast(min_value + tmp * (max_value - min_value)); + } +}; + +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::bhalf_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; + +struct GeneratorTensor_Checkboard +{ + template + float operator()(Ts... Xs) const + { + std::array dims = {static_cast(Xs)...}; + return std::accumulate(dims.begin(), + dims.end(), + true, + [](bool init, ck::index_t x) -> int { return init != (x % 2); }) + ? 1 + : -1; + } +}; + +template +struct GeneratorTensor_Sequential +{ + template + float operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + return dims[Dim]; + } +}; diff --git a/host/driver_offline/include/debug.hpp b/library/include/ck/library/obselete_driver_offline/debug.hpp similarity index 100% rename from host/driver_offline/include/debug.hpp rename to library/include/ck/library/obselete_driver_offline/debug.hpp diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp new file mode 100644 index 00000000000..debb5058e72 --- /dev/null +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp @@ -0,0 +1,220 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" + +template +void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1( + const InLengths& in_n_c0_hi_wi_c1_lengths, + const WeiLengths& wei_k_c0_y_x_c1_lengths, + const AddLengths& add_n_k0_hox2_wox2_k1_lengths, + const OutLengths& out_n_k0_ho_wo_k1_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c0_hi_wi_c1, + const Tensor& wei_k_c0_y_x_c1, + const Tensor& bias_k0_k1, + const Tensor& add_n_k0_hox2_wox2_k1, + Tensor& add_n_k0_hox2_wox2_k1_out, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + const auto N = out_n_k0_ho_wo_k1_lengths[I0]; + const auto K0 = out_n_k0_ho_wo_k1_lengths[I1]; + const auto Ho = out_n_k0_ho_wo_k1_lengths[I2]; + const auto Wo = out_n_k0_ho_wo_k1_lengths[I3]; + const auto K1 = out_n_k0_ho_wo_k1_lengths[I4]; + + const auto C0 = in_n_c0_hi_wi_c1_lengths[I1]; + const auto Hi = in_n_c0_hi_wi_c1_lengths[I2]; + const auto Wi = in_n_c0_hi_wi_c1_lengths[I3]; + const auto C1 = in_n_c0_hi_wi_c1_lengths[I4]; + + const auto K = wei_k_c0_y_x_c1_lengths[I0]; + const auto Y = wei_k_c0_y_x_c1_lengths[I2]; + const auto X = wei_k_c0_y_x_c1_lengths[I3]; + + const auto Hox2 = add_n_k0_hox2_wox2_k1_lengths[I2]; + const auto Wox2 = add_n_k0_hox2_wox2_k1_lengths[I3]; + + DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) * + in_n_c0_hi_wi_c1.mDesc.GetElementSpace()); + DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace()); + DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace()); + DeviceMem add_n_k0_hox2_wox2_k1_device_buf(sizeof(TOut) * + add_n_k0_hox2_wox2_k1.mDesc.GetElementSpace()); + + in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); + wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); + bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data()); + add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data()); + + constexpr index_t InWeiVectorSize = 8; + + if(C1 % InWeiVectorSize != 0) + { + throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize"); + } + +#if 0 + constexpr index_t BlockSize = 256; + + constexpr index_t KPerBlock = 32; + constexpr index_t HoPerBlock = 8; + constexpr index_t WoPerBlock = 64; + + constexpr index_t E1 = C0 * 9; + constexpr index_t E2 = 1; + constexpr index_t E1PerBlock = C0; + + constexpr index_t KPerThread = 16; + constexpr index_t HoPerThread = 2; + constexpr index_t WoPerThread = 2; + constexpr index_t EPerThread = 1; + + using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>; + using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; + constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; + + constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2; + + constexpr index_t CThreadTransferDstScalarPerVector_K = K1; +#elif 1 + constexpr auto BlockSize = 64; + + constexpr auto KPerBlock = 8; + constexpr auto HoPerBlock = 8; + constexpr auto WoPerBlock = 32; + + constexpr auto E1 = 2 * 9; + constexpr auto E2 = 1; + constexpr auto K2 = 2; + constexpr auto E1PerBlock = 2; + + constexpr auto KPerThread = KPerBlock; + constexpr auto HoPerThread = 2; + constexpr auto WoPerThread = 2; + constexpr auto EPerThread = 1; + + using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>; + using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = + Sequence<1, E1PerBlock, 1, KPerBlock, 1>; + + constexpr auto ABlockTransferSrcScalarPerVector_E2 = E2; + constexpr auto ABlockTransferDstScalarPerVector_E2 = E2; + constexpr auto BThreadTransferSrcScalarPerVector_E2 = E2; + constexpr auto CThreadTransferDstScalarPerVector_K = InWeiVectorSize; +#endif + + const auto in_n_c0_hi_wi_c1_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)); + const auto wei_k_c0_y_x_c1_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2)); + const auto add_n_k0_hox2_wox2_k1_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hox2, Wox2, K1)); + const auto out_n_k0_ho_wo_k1_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)); + + constexpr auto conv_driver = + DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add< + BlockSize, + typename vector_type::type, + TAcc, + TOut, + E1, + E2, + K2, + KPerBlock, + HoPerBlock, + WoPerBlock, + E1PerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, + ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, + ABlockTransferSrcScalarPerVector_E2, + ABlockTransferDstScalarPerVector_E2, + BThreadTransferSrcScalarPerVector_E2, + CThreadTransferDstScalarPerVector_K, + activ_type>{}; + + std::cerr << "conv_bias_activ_resize_add_input_" + << "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K + << "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_addout_n" << N << "k" << K0 + << "h" << Ho * 2 << "w" << Wo * 2 << "k" << K1 << std::endl; + + for(int i = 0; i < 5; i++) + { + + const auto ave_time = + conv_driver.Run(wei_k_c0_y_x_c1_desc, + in_n_c0_hi_wi_c1_desc, + out_n_k0_ho_wo_k1_desc, + add_n_k0_hox2_wox2_k1_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + static_cast::type*>( + wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), + static_cast::type*>( + in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), + static_cast(bias_k0_k1_device_buf.GetDeviceBuffer()), + static_cast(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()), + nrepeat); + + { + float perf = static_cast(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data()); + + conv_driver.Run(wei_k_c0_y_x_c1_desc, + in_n_c0_hi_wi_c1_desc, + out_n_k0_ho_wo_k1_desc, + add_n_k0_hox2_wox2_k1_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + static_cast::type*>( + wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), + static_cast::type*>( + in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), + static_cast(bias_k0_k1_device_buf.GetDeviceBuffer()), + static_cast(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()), + 0); + + add_n_k0_hox2_wox2_k1_device_buf.FromDevice(add_n_k0_hox2_wox2_k1_out.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp similarity index 99% rename from host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index 8258aa0e663..79d31ba2467 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -231,7 +231,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( TInWei, TAcc, TOut, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp similarity index 98% rename from host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp index 28d6226f1b4..e3b6a6c8c29 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -303,14 +303,14 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - const auto YTilda = ConvStrideH / GcdStrideDilationH; - const auto XTilda = ConvStrideW / GcdStrideDilationW; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; float ave_time = 0; - for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda) + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) { - for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda) + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) { const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( @@ -321,8 +321,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk conv_dilations, in_left_pads, in_right_pads, - i_ytilda, - i_xtilda, + i_ytilde, + i_xtilde, Number{}); const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; @@ -338,7 +338,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk TInWei, TAcc, TOut, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp similarity index 99% rename from host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp index d6955ec0005..9cc4052f778 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp @@ -104,7 +104,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 +#elif 0 // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; @@ -132,7 +132,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 +#elif 1 // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; @@ -307,7 +307,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk TInWei, TAcc, TOut, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp similarity index 99% rename from host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp index 8207e2cb281..993630f3f8a 100644 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp @@ -171,7 +171,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_ TIn, TAcc, TWei, - InMemoryDataOperationEnum_t::AtomicAdd, + InMemoryDataOperationEnum::AtomicAdd, decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp similarity index 99% rename from host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index ac75c56bf5a..dfb612f690e 100644 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -168,7 +168,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk TIn, TAcc, TWei, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp similarity index 99% rename from host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp index 6381ce8bb44..06d0ea684f9 100644 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp @@ -200,7 +200,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_ TIn, TAcc, TWei, - InMemoryDataOperationEnum_t::AtomicAdd, + InMemoryDataOperationEnum::AtomicAdd, decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp similarity index 99% rename from host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index bc5d5996041..5221ec582d2 100644 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -199,7 +199,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh TIn, TAcc, TWei, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp similarity index 99% rename from host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp index 603f8726622..1bdad6e97b3 100644 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp @@ -367,7 +367,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ TIn, TAcc, TWei, - InMemoryDataOperationEnum_t::AtomicAdd, + InMemoryDataOperationEnum::AtomicAdd, decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp similarity index 99% rename from host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp index e6554cf0fe4..a9df58bedda 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp @@ -138,7 +138,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( TInWei, TAcc, TOut, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(wei_gemmk_gemmm_grid_desc), decltype(in_gemmk_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp similarity index 98% rename from host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp index 40685e81cfa..843df27a88a 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp @@ -141,14 +141,14 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( #endif const auto descs = - transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; @@ -202,7 +202,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( TInWei, TAcc, TOut, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp similarity index 99% rename from host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index d65ecadb4df..e4cf4dd25cd 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -167,7 +167,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( TInWei, TAcc, TOut, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp similarity index 65% rename from host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 1b23aa1a8c9..18e712fb47c 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -4,6 +4,131 @@ #include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" #include "driver_gemm_xdlops_v2r3.hpp" +#if 0 +__host__ __device__ static constexpr auto +MakePaddedGridDescriptors(const AGridDesc_K0Raw_MRaw_K1& a_grid_desc_k0raw_mraw_k1, + const BGridDesc_K0Raw_NRaw_K1& b_grid_desc_k0raw_nraw_k1, + const CGridDesc_MRaw_NRaw& c_grid_desc_mraw_nraw) +{ + const auto K0Raw = a_grid_desc_k0raw_mraw_k1.GetLength(I0); + const auto K1 = a_grid_desc_k0raw_mraw_k1.GetLength(I2); + const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0); + const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1); + + const auto K0Pad = math::integer_least_multiple(K0Raw, K0PerBlock) - K0Raw; + const auto MPad = math::integer_least_multiple(MRaw, MPerBlock) - MRaw; + const auto NPad = math::integer_least_multiple(NRaw, NPerBlock) - NRaw; + + // A + const auto a_grid_desc_k0_m_k1 = [&]() { + if constexpr(DoPad_K0 && DoPad_M) + { + return transform_tensor_descriptor( + a_grid_desc_k0_m_k1, + make_tuple(make_right_pad_transform(K0Raw, K0Pad), + make_right_pad_transform(MRaw, MPad), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else if constexpr(DoPad_K0 && !DoPad_M) + { + return transform_tensor_descriptor( + a_grid_desc_k0_m_k1, + make_tuple(make_right_pad_transform(K0Raw, K0Pad), + make_pass_through_transform(MRaw), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else if constexpr(!DoPad_K0 && DoPad_M) + { + return transform_tensor_descriptor( + a_grid_desc_k0_m_k1, + make_tuple(make_pass_through_transform(K0Raw), + make_right_pad_transform(MRaw, MPad), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + return a_grid_desc_k0raw_mraw_k1; + } + }(); + + // B + const auto b_grid_desc_k0_n_k1 = [&]() { + if constexpr(DoPad_K0 && DoPad_N) + { + return transform_tensor_descriptor( + b_grid_desc_k0_n_k1, + make_tuple(make_right_pad_transform(K0Raw, K0Pad), + make_right_pad_transform(NRaw, NPad), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else if constexpr(DoPad_K0 && !DoPad_N) + { + return transform_tensor_descriptor( + b_grid_desc_k0_n_k1, + make_tuple(make_right_pad_transform(K0Raw, K0Pad), + make_pass_through_transform(NRaw), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else if constexpr(!DoPad_K0 && DoPad_N) + { + return transform_tensor_descriptor( + b_grid_desc_k0_n_k1, + make_tuple(make_pass_through_transform(K0Raw), + make_right_pad_transform(NRaw, NPad), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + return b_grid_desc_k0raw_nraw_k1; + } + }(); + + // C + const auto c_grid_desc_m_n = [&]() { + if constexpr(DoPad_M && DoPad_N) + { + return transform_tensor_descriptor(c_grid_desc_m_n, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(DoPad_M && !DoPad_N) + { + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(!DoPad_M && DoPad_N) + { + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + reutnr c_grid_desc_m_n; + } + }(); +} +#endif + template {}); - + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + +#if 0 // debug const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - // HACK: hacks that control index calculation when iterating over A, B, C matrix + // HACK: hacks that control index calculation when iterating over A matrix constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM @@ -297,7 +421,39 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = + constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; +#else + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = descs[I0]; + + const auto GemmK0 = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I0); + const auto GemmMRaw = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I1); + const auto GemmMPad = math::integer_least_multiple(GemmMRaw, GemmMPerBlock) - GemmMRaw; + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // HACK: hacks that control index calculation when iterating over A matrix + constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; +#endif + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + + const auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 @@ -305,6 +461,12 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + +#if 0 + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 @@ -322,12 +484,36 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 +#else + const auto out_gemmmraw_gemmn_grid_desc = descs[I2]; - constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + const auto GemmN = out_gemmmraw_gemmn_grid_desc.GetLength(I1); - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 +#endif for(index_t i = 0; i < 5; ++i) { @@ -336,7 +522,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( TInWei, TAcc, TOut, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmm_gemmn_grid_desc), diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp new file mode 100644 index 00000000000..af4676f2a24 --- /dev/null +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp @@ -0,0 +1,196 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" + +template +void device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1( + const InLengths& in_n_c0_hi_wi_c1_lengths, + const WeiLengths& wei_k_c0_y_x_c1_lengths, + const OutLengths& out_n_k0_ho_wo_k1_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c0_hi_wi_c1, + const Tensor& wei_k_c0_y_x_c1, + const Tensor& bias_k0_k1, + Tensor& out_n_k0_ho_wo_k1, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + const auto N = out_n_k0_ho_wo_k1_lengths[I0]; + const auto K0 = out_n_k0_ho_wo_k1_lengths[I1]; + const auto Ho = out_n_k0_ho_wo_k1_lengths[I2]; + const auto Wo = out_n_k0_ho_wo_k1_lengths[I3]; + const auto K1 = out_n_k0_ho_wo_k1_lengths[I4]; + + const auto C0 = in_n_c0_hi_wi_c1_lengths[I1]; + const auto Hi = in_n_c0_hi_wi_c1_lengths[I2]; + const auto Wi = in_n_c0_hi_wi_c1_lengths[I3]; + const auto C1 = in_n_c0_hi_wi_c1_lengths[I4]; + + const auto K = wei_k_c0_y_x_c1_lengths[I0]; + const auto Y = wei_k_c0_y_x_c1_lengths[I2]; + const auto X = wei_k_c0_y_x_c1_lengths[I3]; + + DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) * + in_n_c0_hi_wi_c1.mDesc.GetElementSpace()); + DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace()); + DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace()); + DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) * + out_n_k0_ho_wo_k1.mDesc.GetElementSpace()); + in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); + wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); + bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data()); + + constexpr index_t InWeiVectorSize = 8; + + if(C1 % InWeiVectorSize != 0) + { + throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize"); + } + +#if 0 + constexpr index_t BlockSize = 256; + + constexpr index_t KPerBlock = 32; + constexpr index_t HoPerBlock = 8; + constexpr index_t WoPerBlock = 64; + + constexpr index_t E1 = C0 * 9; + constexpr index_t E2 = 1; + constexpr index_t E1PerBlock = C0; + + constexpr index_t KPerThread = 16; + constexpr index_t HoPerThread = 2; + constexpr index_t WoPerThread = 2; + constexpr index_t EPerThread = 1; + + using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>; + using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; + constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; + + constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2; + + constexpr index_t CThreadTransferDstScalarPerVector_K = K1; +#elif 1 + constexpr index_t BlockSize = 64; + + constexpr index_t KPerBlock = 8; + constexpr index_t HoPerBlock = 8; + constexpr index_t WoPerBlock = 32; + + constexpr index_t E1 = 2 * 9; + constexpr index_t E2 = 1; + constexpr index_t K2 = 2; + constexpr index_t E1PerBlock = 2; + + constexpr index_t KPerThread = KPerBlock; + constexpr index_t HoPerThread = 2; + constexpr index_t WoPerThread = 2; + constexpr index_t EPerThread = 1; + + using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>; + using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = + Sequence<1, E1PerBlock, 1, KPerBlock, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; + constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; + constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2; + constexpr index_t CThreadTransferDstScalarPerVector_K = InWeiVectorSize; +#endif + + if(KPerThread % InWeiVectorSize != 0) + { + throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize"); + } + + const auto in_n_c0_hi_wi_c1_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)); + const auto wei_k_c0_y_x_c1_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2)); + const auto out_n_k0_ho_wo_k1_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)); + + constexpr auto conv_driver = + DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad< + BlockSize, + typename vector_type::type, + TAcc, + TOut, + E1, + E2, + K2, + KPerBlock, + HoPerBlock, + WoPerBlock, + E1PerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, + ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, + ABlockTransferSrcScalarPerVector_E2, + ABlockTransferDstScalarPerVector_E2, + BThreadTransferSrcScalarPerVector_E2, + CThreadTransferDstScalarPerVector_K, + activ_type>{}; + + std::cerr << "conv_bias_activ_input_" + << "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K + << "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_convout_n" << N << "k" << K0 + << "h" << Ho << "w" << Wo << "k" << K1 << std::endl; + + for(int i = 0; i < 5; i++) + { + + const auto ave_time = + conv_driver.Run(wei_k_c0_y_x_c1_desc, + in_n_c0_hi_wi_c1_desc, + out_n_k0_ho_wo_k1_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + static_cast::type*>( + wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), + static_cast::type*>( + in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), + static_cast(bias_k0_k1_device_buf.GetDeviceBuffer()), + static_cast(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()), + nrepeat); + + { + float perf = static_cast(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp similarity index 99% rename from host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp rename to library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp index e1b7c5486cd..31925f0511c 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -182,7 +182,7 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( TInWei, TAcc, TOut, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(wei_grid_desc_gk0_gm0_gm1_gk1), decltype(in_grid_desc_gk0_gn0_gn1_gk1), decltype(out_grid_desc_gm0_gm1_gn0_gn1), diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp new file mode 100644 index 00000000000..2cb2e109152 --- /dev/null +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp @@ -0,0 +1,212 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" + +template +void device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1( + const InLengths& in_n_c0_hi_wi_c1_lengths, + const WeiLengths& wei_k_c0_y_x_c1_lengths, + const MaxLengths& max_n_k0_hx_wx_k1_lengths, + const OutLengths& out_n_k0_ho_wo_k1_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c0_hi_wi_c1, + const Tensor& wei_k_c0_y_x_c1, + const Tensor& bias_k0_k1, + Tensor& out_n_k0_ho_wo_k1, + Tensor& max_n_k0_hx_wx_k1, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + const auto N = out_n_k0_ho_wo_k1_lengths[I0]; + const auto K0 = out_n_k0_ho_wo_k1_lengths[I1]; + const auto Ho = out_n_k0_ho_wo_k1_lengths[I2]; + const auto Wo = out_n_k0_ho_wo_k1_lengths[I3]; + const auto K1 = out_n_k0_ho_wo_k1_lengths[I4]; + + const auto C0 = in_n_c0_hi_wi_c1_lengths[I1]; + const auto Hi = in_n_c0_hi_wi_c1_lengths[I2]; + const auto Wi = in_n_c0_hi_wi_c1_lengths[I3]; + const auto C1 = in_n_c0_hi_wi_c1_lengths[I4]; + + const auto K = wei_k_c0_y_x_c1_lengths[I0]; + const auto Y = wei_k_c0_y_x_c1_lengths[I2]; + const auto X = wei_k_c0_y_x_c1_lengths[I3]; + + const auto Hx = max_n_k0_hx_wx_k1_lengths[I2]; + const auto Wx = max_n_k0_hx_wx_k1_lengths[I3]; + + DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) * + in_n_c0_hi_wi_c1.mDesc.GetElementSpace()); + DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace()); + DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace()); + DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) * + out_n_k0_ho_wo_k1.mDesc.GetElementSpace()); + DeviceMem max_n_k0_hx_wx_k1_device_buf(sizeof(TOut) * + max_n_k0_hx_wx_k1.mDesc.GetElementSpace()); + + in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); + wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); + bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data()); + max_n_k0_hx_wx_k1_device_buf.ToDevice(max_n_k0_hx_wx_k1.mData.data()); + + constexpr index_t InWeiVectorSize = 8; + + if(C1 % InWeiVectorSize != 0) + { + throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize"); + } + +#if 0 + constexpr index_t BlockSize = 256; + + constexpr index_t KPerBlock = 32; + constexpr index_t HoPerBlock = 8; + constexpr index_t WoPerBlock = 64; + + constexpr index_t E1 = C0 * 9; + constexpr index_t E2 = 1; + constexpr index_t E1PerBlock = C0; + + constexpr index_t KPerThread = 16; + constexpr index_t HoPerThread = 2; + constexpr index_t WoPerThread = 2; + constexpr index_t EPerThread = 1; + + using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>; + using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; + constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; + + constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2; + + constexpr index_t CThreadTransferDstScalarPerVector_K = K1; +#elif 1 + constexpr index_t BlockSize = 64; + + constexpr index_t KPerBlock = 8; + constexpr index_t HoPerBlock = 8; + constexpr index_t WoPerBlock = 32; + + constexpr index_t E1 = 2 * 9; + constexpr index_t E2 = 1; + constexpr index_t K2 = 2; + constexpr index_t E1PerBlock = 2; + + constexpr index_t KPerThread = KPerBlock; + constexpr index_t HoPerThread = 2; + constexpr index_t WoPerThread = 2; + constexpr index_t EPerThread = 1; + + using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>; + using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = + Sequence<1, E1PerBlock, 1, KPerBlock, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; + constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; + constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2; + constexpr index_t CThreadTransferDstScalarPerVector_K = InWeiVectorSize; +#endif + + if(KPerThread % InWeiVectorSize != 0) + { + throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize"); + } + + const auto in_n_c0_hi_wi_c1_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)); + const auto wei_k_c0_y_x_c1_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2)); + const auto max_n_k0_hx_wx_k1_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hx, Wx, K1)); + const auto out_n_k0_ho_wo_k1_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)); + + constexpr auto conv_driver = + DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool< + BlockSize, + typename vector_type::type, + TAcc, + TOut, + E1, + E2, + K2, + KPerBlock, + HoPerBlock, + WoPerBlock, + E1PerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, + ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, + ABlockTransferSrcScalarPerVector_E2, + ABlockTransferDstScalarPerVector_E2, + BThreadTransferSrcScalarPerVector_E2, + CThreadTransferDstScalarPerVector_K, + activ_type>{}; + + std::cerr << "conv_bias_activ_maxpool_input_" + << "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K + << "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_convout_n" << N << "k" << K0 + << "h" << Ho << "w" << Wo << "k" << K1 << "_maxpoolout_n" << N << "k" << K0 << "h" + << Ho / 2 << "w" << Wo / 2 << "k" << K1 << std::endl; + + for(int i = 0; i < 5; i++) + { + + const auto ave_time = + conv_driver.Run(wei_k_c0_y_x_c1_desc, + in_n_c0_hi_wi_c1_desc, + out_n_k0_ho_wo_k1_desc, + max_n_k0_hx_wx_k1_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + static_cast::type*>( + wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), + static_cast::type*>( + in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), + static_cast(bias_k0_k1_device_buf.GetDeviceBuffer()), + static_cast(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()), + static_cast(max_n_k0_hx_wx_k1_device_buf.GetDeviceBuffer()), + nrepeat); + + { + float perf = static_cast(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); + max_n_k0_hx_wx_k1_device_buf.FromDevice(max_n_k0_hx_wx_k1.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_kn_mn.hpp similarity index 99% rename from host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp rename to library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_kn_mn.hpp index c44aa7d9a27..f54ff181dd9 100644 --- a/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_kn_mn.hpp @@ -398,7 +398,7 @@ void device_gemm_xdlops_km_kn_mn(const Tensor& a_k_m, ABType, AccType, CType, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(a_k0_m_k1_grid_desc), decltype(b_k0_n_k1_grid_desc), decltype(c_m_n_grid_desc), diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_kn_nm.hpp similarity index 99% rename from host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp rename to library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_kn_nm.hpp index abaaf321136..eb78ba96d8b 100644 --- a/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_kn_nm.hpp @@ -202,7 +202,7 @@ void device_gemm_xdlops_km_kn_nm(const Tensor& a_k_m, ABType, AccType, CType, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(a_k0_m_k1_grid_desc), decltype(b_k0_n_k1_grid_desc), decltype(c_m_n_grid_desc), diff --git a/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_nk_mn.hpp similarity index 99% rename from host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp rename to library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_nk_mn.hpp index 0a97d361d4e..dbd318ce4dc 100644 --- a/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_nk_mn.hpp @@ -398,7 +398,7 @@ void device_gemm_xdlops_km_nk_mn(const Tensor& a_k_m, ABType, AccType, CType, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(a_k0_m_k1_grid_desc), decltype(b_k0_n_k1_grid_desc), decltype(c_m_n_grid_desc), diff --git a/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_nk_nm.hpp similarity index 99% rename from host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp rename to library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_nk_nm.hpp index d51caa38477..5b819fd1af4 100644 --- a/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_km_nk_nm.hpp @@ -202,7 +202,7 @@ void device_gemm_xdlops_km_nk_nm(const Tensor& a_k_m, ABType, AccType, CType, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(a_k0_m_k1_grid_desc), decltype(b_k0_n_k1_grid_desc), decltype(c_m_n_grid_desc), diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_mn.hpp similarity index 99% rename from host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp rename to library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_mn.hpp index 30ede2517b2..4b041777c3e 100644 --- a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_mn.hpp @@ -398,7 +398,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, ABType, AccType, CType, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(a_k0_m_k1_grid_desc), decltype(b_k0_n_k1_grid_desc), decltype(c_m_n_grid_desc), diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_nm.hpp similarity index 99% rename from host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp rename to library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_nm.hpp index 58ac3880d6f..c848cd79361 100644 --- a/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_nm.hpp @@ -230,7 +230,7 @@ void device_gemm_xdlops_mk_kn_nm(const Tensor& a_m_k, ABType, AccType, CType, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(a_k0_m_k1_grid_desc), decltype(b_k0_n_k1_grid_desc), decltype(c_m_n_grid_desc), diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_mn.hpp similarity index 99% rename from host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp rename to library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_mn.hpp index e99d5704136..557624026d5 100644 --- a/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_mn.hpp @@ -499,7 +499,7 @@ void device_gemm_xdlops_mk_nk_mn(const Tensor& a_m_k, ABType, AccType, CType, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(a_k0_m_k1_grid_desc), decltype(b_k0_n_k1_grid_desc), decltype(c_m_n_grid_desc), diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_nm.hpp similarity index 99% rename from host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp rename to library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_nm.hpp index a12cf0733a8..06d8ed29404 100644 --- a/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_nm.hpp @@ -286,7 +286,7 @@ void device_gemm_xdlops_mk_nk_nm(const Tensor& a_m_k, ABType, AccType, CType, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, decltype(a_k0_m_k1_grid_desc), decltype(b_k0_n_k1_grid_desc), decltype(c_m_n_grid_desc), diff --git a/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp b/library/include/ck/library/obselete_driver_offline/driver_contraction_dlops_v1r2.hpp similarity index 99% rename from host/driver_offline/include/driver_contraction_dlops_v1r2.hpp rename to library/include/ck/library/obselete_driver_offline/driver_contraction_dlops_v1r2.hpp index d207728a2e6..000098f4fca 100644 --- a/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp +++ b/library/include/ck/library/obselete_driver_offline/driver_contraction_dlops_v1r2.hpp @@ -10,7 +10,7 @@ template +struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add +{ + template + __host__ float Run(const ck::TensorDescriptor& wei_k_c0_y_x_c1_global_desc, + const ck::TensorDescriptor& in_n_c0_hi_wi_c1_global_desc, + const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, + const ck::TensorDescriptor& add_n_k0_hox2_wox2_k1_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatC* __restrict__ p_bias_grid, + FloatC* __restrict__ p_d_grid, + const int nrepeat) const + { + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0); + const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1); + const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2); + const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3); + // const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4); + + const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); + const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); + const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); + const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); + + const auto Hox2 = add_n_k0_hox2_wox2_k1_global_desc.GetLength(I2); + const auto Wox2 = add_n_k0_hox2_wox2_k1_global_desc.GetLength(I3); + + const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0); + const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2); + const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; + const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; + + const auto OutRightPadH = Hop - Ho; + const auto OutRightPadW = Wop - Wo; + + const auto OutRightPadHx = OutRightPadH * 2; + const auto OutRightPadWx = OutRightPadW * 2; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH; + const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW; + + const auto E = C0 * Y * X; + + constexpr auto E1 = Number{}; + constexpr auto E2 = Number{}; + constexpr auto K2 = Number{}; + + const auto E0 = E / E1; + + // weight tensor + const auto a_e_k_e2_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)), + make_tuple(make_pass_through_transform(K), + make_pass_through_transform(C0 * Y * X), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{})); + + const auto a_e0_e1_k_e2_grid_desc = + transform_tensor_descriptor(a_e_k_e2_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(E0, E1)), + make_pass_through_transform(K), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + // input tensor + const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)), + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C0), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor( + in_n_c0_hip_wip_e2_global_desc, + make_tuple( + make_pass_through_transform(N), + make_pass_through_transform(C0), + make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{})); + + const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor( + in_n_c0_y_ho_x_wo_e2_global_desc, + make_tuple(make_merge_transform(make_tuple(C0, Y, X)), + make_pass_through_transform(N), + make_pass_through_transform(Hop), + make_pass_through_transform(Wop), + make_pass_through_transform(E2)), + make_tuple( + Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor( + in_e_n_ho_wo_e2_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(E0, E1)), + make_pass_through_transform(N), + make_pass_through_transform(Hop), + make_pass_through_transform(Wop), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{})); + + // output tensor + const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_pad_transform(Ho, I0, OutRightPadH), + make_pad_transform(Wo, I0, OutRightPadW)), + make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // add tensor + const auto d_k_n_hopx2_wopx2_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hox2, Wox2, K1)), + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_pad_transform(Hox2, I0, OutRightPadHx), + make_pad_transform(Wox2, I0, OutRightPadWx)), + make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; + + if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && + (E1 % E1PerBlock) == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + // clang-format off + + // hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor + constexpr auto a_e0_e1_k_e2_global_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; + + // hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor + constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = + make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}) + ); + + constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; + + // hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor + constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + // clang-format on + + // GEMM + using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperationEnum::Set, + decltype(a_e0_e1_k_e2_grid_desc), + decltype(b_e0_e1_n_ho_wo_e2_grid_desc), + decltype(c_k_n_hop_wop_grid_desc), + decltype(d_k_n_hopx2_wopx2_grid_desc), + E1, + E2, + K2, + KPerBlock, + HoPerBlock, + WoPerBlock, + E1PerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, + ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, + Sequence<2, 3, 0, 1, 4>, + Sequence<0, 1, 2, 3, 4>, + 4, + ABlockTransferSrcScalarPerVector_E2, + ABlockTransferDstScalarPerVector_E2, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2 + 9, + BThreadTransferSrcScalarPerVector_E2, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, I2, H2, W0, W1, I2, W2 + 1, + CThreadTransferDstScalarPerVector_K, + decltype(a_e0_e1_k_e2_global_step_hacks), + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks), + decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks), + decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks), + decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack), + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>; + + const auto a_e0_e1_k0_k1_e2_grid_desc = + GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc); + const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = + GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc); + const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = + GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc); + const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc = + GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd( + d_k_n_hopx2_wopx2_grid_desc); + + using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc); + using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc); + using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 = + decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc); + + const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; + + const bool has_main_e0_block_loop = E0 > 1; + + std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl; + + const auto cblockid_to_k_n_h_w_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc); + + using CBlockIdToBlockClusterAdaptor_K_N_H_W = + decltype(cblockid_to_k_n_h_w_block_cluster_adaptor); + + float ave_time = 0; + + if(has_main_e0_block_loop) + { + const auto kernel = kernel_gemm_dlops_v3_resize_add< + GridwiseGemm, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + activ_type>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_bias_grid, + p_d_grid, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor); + } + else + { + const auto kernel = kernel_gemm_dlops_v3_resize_add< + GridwiseGemm, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + activ_type>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_bias_grid, + p_d_grid, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor); + } + + return ave_time; + } +}; +#endif diff --git a/library/include/ck/library/obselete_driver_offline/driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp b/library/include/ck/library/obselete_driver_offline/driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp new file mode 100644 index 00000000000..34296405d49 --- /dev/null +++ b/library/include/ck/library/obselete_driver_offline/driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp @@ -0,0 +1,386 @@ +#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP +#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_dlops_v3.hpp" + +template +struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad +{ + template + __host__ float Run(const ck::TensorDescriptor& wei_k_c0_y_x_c1_global_desc, + const ck::TensorDescriptor& in_n_c0_hi_wi_c1_global_desc, + const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatC* __restrict__ p_bias_grid, + FloatC* __restrict__ p_c_grid, + const int nrepeat) const + { + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0); + const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1); + const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2); + const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3); + // const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4); + + const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); + const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); + const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); + const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); + + const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0); + const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2); + const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + +#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR + const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{}; + const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{}; +#else + const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; + const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; +#endif + + const auto OutRightPadH = Hop - Ho; + const auto OutRightPadW = Wop - Wo; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH; + const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW; + + const auto E = C0 * Y * X; + + constexpr auto E1 = Number{}; + constexpr auto E2 = Number{}; + constexpr auto K2 = Number{}; + + const auto E0 = E / E1; + + // weight tensor + const auto a_e_k_e2_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)), + make_tuple(make_pass_through_transform(K), + make_pass_through_transform(C0 * Y * X), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{})); + + const auto a_e0_e1_k_e2_grid_desc = + transform_tensor_descriptor(a_e_k_e2_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(E0, E1)), + make_pass_through_transform(K), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + // input tensor + const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)), + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C0), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor( + in_n_c0_hip_wip_e2_global_desc, + make_tuple( + make_pass_through_transform(N), + make_pass_through_transform(C0), + make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{})); + + const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor( + in_n_c0_y_ho_x_wo_e2_global_desc, + make_tuple(make_merge_transform(make_tuple(C0, Y, X)), + make_pass_through_transform(N), + make_pass_through_transform(Hop), + make_pass_through_transform(Wop), + make_pass_through_transform(E2)), + make_tuple( + Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor( + in_e_n_ho_wo_e2_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(E0, E1)), + make_pass_through_transform(N), + make_pass_through_transform(Hop), + make_pass_through_transform(Wop), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{})); + + // output tensor + const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_pad_transform(Ho, I0, OutRightPadH), + make_pad_transform(Wo, I0, OutRightPadW)), + make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; + + if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && + (E1 % E1PerBlock) == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + // clang-format off + + // hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor + constexpr auto a_e0_e1_k_e2_global_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; + + // hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor + constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = + make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}) + ); + + constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; + + // hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor + constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + // clang-format on + + // GEMM + using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperationEnum::Set, + decltype(a_e0_e1_k_e2_grid_desc), + decltype(b_e0_e1_n_ho_wo_e2_grid_desc), + decltype(c_k_n_hop_wop_grid_desc), + decltype(c_k_n_hop_wop_grid_desc), + E1, + E2, + K2, + KPerBlock, + HoPerBlock, + WoPerBlock, + E1PerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, + ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, + Sequence<2, 3, 0, 1, 4>, + Sequence<0, 1, 2, 3, 4>, + 4, + ABlockTransferSrcScalarPerVector_E2, + ABlockTransferDstScalarPerVector_E2, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2 + 9, + BThreadTransferSrcScalarPerVector_E2, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, H2, W0, W1, W2 + 1, + CThreadTransferDstScalarPerVector_K, + decltype(a_e0_e1_k_e2_global_step_hacks), + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks), + decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks), + decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks), + decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack), + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>; + + const auto a_e0_e1_k0_k1_e2_grid_desc = + GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc); + const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = + GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc); + const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = + GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc); + + using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc); + using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc); + using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; + + const bool has_main_e0_block_loop = E0 > 1; + + std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl; + + const auto cblockid_to_k_n_h_w_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc); + + using CBlockIdToBlockClusterAdaptor_K_N_H_W = + decltype(cblockid_to_k_n_h_w_block_cluster_adaptor); + + float ave_time = 0; + + if(has_main_e0_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + activ_type>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_bias_grid, + p_c_grid, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor); + } + else + { + const auto kernel = + kernel_gemm_dlops_v3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + activ_type>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_bias_grid, + p_c_grid, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor); + } + + return ave_time; + } +}; +#endif diff --git a/library/include/ck/library/obselete_driver_offline/driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp b/library/include/ck/library/obselete_driver_offline/driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp new file mode 100644 index 00000000000..1b8e48e6c1e --- /dev/null +++ b/library/include/ck/library/obselete_driver_offline/driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp @@ -0,0 +1,440 @@ +#ifndef DRIVER_CONVOLUTION_MAXPOOL_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP +#define DRIVER_CONVOLUTION_MAXPOOL_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_dlops_v3.hpp" + +template +struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool +{ + template + __host__ float Run(const ck::TensorDescriptor& wei_k_c0_y_x_c1_global_desc, + const ck::TensorDescriptor& in_n_c0_hi_wi_c1_global_desc, + const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, + const ck::TensorDescriptor& max_n_k0_hx_wx_k1_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatC* __restrict__ p_bias_grid, + FloatC* __restrict__ p_c_grid, + FloatC* __restrict__ p_d_grid, + const int nrepeat) const + { + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0); + const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1); + const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2); + const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3); + // const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4); + + const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); + const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); + const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); + const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); + + const auto Hx = max_n_k0_hx_wx_k1_global_desc.GetLength(I2); + const auto Wx = max_n_k0_hx_wx_k1_global_desc.GetLength(I3); + + const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0); + const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2); + const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + +#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR + const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{}; + const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{}; + + const auto OutRightPadH = Hop - Ho; + const auto OutRightPadW = Wop - Wo; + + const auto OutRightPadHx = Number{}; + const auto OutRightPadWx = Number{}; +#else + const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; + const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; + + const auto OutRightPadH = Hop - Ho; + const auto OutRightPadW = Wop - Wo; + + const auto OutRightPadHx = OutRightPadH / 2; + const auto OutRightPadWx = OutRightPadW / 2; +#endif + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH; + const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW; + + const auto E = C0 * Y * X; + + constexpr auto E1 = Number{}; + constexpr auto E2 = Number{}; + constexpr auto K2 = Number{}; + + const auto E0 = E / E1; + + // weight tensor + const auto a_e_k_e2_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)), + make_tuple(make_pass_through_transform(K), + make_pass_through_transform(C0 * Y * X), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{})); + + const auto a_e0_e1_k_e2_grid_desc = + transform_tensor_descriptor(a_e_k_e2_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(E0, E1)), + make_pass_through_transform(K), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + // input tensor + const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)), + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C0), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor( + in_n_c0_hip_wip_e2_global_desc, + make_tuple( + make_pass_through_transform(N), + make_pass_through_transform(C0), + make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{})); + + const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor( + in_n_c0_y_ho_x_wo_e2_global_desc, + make_tuple(make_merge_transform(make_tuple(C0, Y, X)), + make_pass_through_transform(N), + make_pass_through_transform(Hop), + make_pass_through_transform(Wop), + make_pass_through_transform(E2)), + make_tuple( + Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor( + in_e_n_ho_wo_e2_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(E0, E1)), + make_pass_through_transform(N), + make_pass_through_transform(Hop), + make_pass_through_transform(Wop), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{})); + + // output tensor + const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_pad_transform(Ho, I0, OutRightPadH), + make_pad_transform(Wo, I0, OutRightPadW)), + make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // max tensor + const auto d_k_n_hx_wx_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hx, Wx, K1)), + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_pad_transform(Hx, I0, OutRightPadHx), + make_pad_transform(Wx, I0, OutRightPadWx)), + make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; + + if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && + (E1 % E1PerBlock) == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + // clang-format off + + // hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor + constexpr auto a_e0_e1_k_e2_global_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; + + // hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor + constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = + make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}) + ); + + constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; + + constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + // clang-format on + + // GEMM + using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperationEnum::Set, + decltype(a_e0_e1_k_e2_grid_desc), + decltype(b_e0_e1_n_ho_wo_e2_grid_desc), + decltype(c_k_n_hop_wop_grid_desc), + decltype(d_k_n_hx_wx_grid_desc), + E1, + E2, + K2, + KPerBlock, + HoPerBlock, + WoPerBlock, + E1PerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, + ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, + Sequence<2, 3, 0, 1, 4>, + Sequence<0, 1, 2, 3, 4>, + 4, + ABlockTransferSrcScalarPerVector_E2, + ABlockTransferDstScalarPerVector_E2, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2 + 9, + BThreadTransferSrcScalarPerVector_E2, + false, // don't move back src coordinate after threadwise copy, which will be fused + // with MoveSrcSliceWindow() to save addr computation + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, I2, H2, W0, W1, I2, W2 + 1, + CThreadTransferDstScalarPerVector_K, + decltype(a_e0_e1_k_e2_global_step_hacks), + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks), + decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks), + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks), + decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack), + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>; + + const auto a_e0_e1_k0_k1_e2_grid_desc = + GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc); + const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = + GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc); + const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = + GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc); + const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = + GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(d_k_n_hx_wx_grid_desc); + + using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc); + using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc); + using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx = decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc); + + const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; + + const bool has_main_e0_block_loop = E0 > 1; + + std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl; + + const auto cblockid_to_k_n_h_w_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc); + + using CBlockIdToBlockClusterAdaptor_K_N_H_W = + decltype(cblockid_to_k_n_h_w_block_cluster_adaptor); + + float ave_time = 0; + + if(has_main_e0_block_loop) + { + const auto kernel = kernel_gemm_dlops_v3_maxpool< + GridwiseGemm, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + activ_type>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_bias_grid, + p_c_grid, + p_d_grid, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor); + } + else + { + const auto kernel = kernel_gemm_dlops_v3_maxpool< + GridwiseGemm, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + activ_type>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_bias_grid, + p_c_grid, + p_d_grid, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor); + } + + return ave_time; + } +}; +#endif diff --git a/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp b/library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r2.hpp similarity index 67% rename from host/driver_offline/include/driver_gemm_dlops_v1r2.hpp rename to library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r2.hpp index bf5f7f1c0f5..ce0530b3fd2 100644 --- a/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp +++ b/library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r2.hpp @@ -10,7 +10,7 @@ template , - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - - return ave_time; -#endif } #endif diff --git a/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp b/library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r3.hpp similarity index 65% rename from host/driver_offline/include/driver_gemm_dlops_v1r3.hpp rename to library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r3.hpp index 44709188208..3fd1a1dbbac 100644 --- a/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp +++ b/library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r3.hpp @@ -10,7 +10,7 @@ template , - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space( - a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space( - a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space( - a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space( - a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - - return ave_time; -#endif } #endif diff --git a/library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r3.hpp b/library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r3.hpp new file mode 100644 index 00000000000..5652040250e --- /dev/null +++ b/library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r3.hpp @@ -0,0 +1,220 @@ +#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP +#define DRIVER_GEMM_XDLOPS_V2R3_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" +#include "element_wise_operation.hpp" + +template +__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K& b_grid_desc_k0_n_k1, + const CMNGridDesc& c_grid_desc_m_n, + ck::index_t M01, + ck::index_t N01, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, + ck::index_t nrepeat) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + using ElementwiseOperation = ck::tensor_operation::element_wise::PassThrough; + + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + { + std::cout << "a_grid_desc_k0_m_k1{" << a_grid_desc_k0_m_k1.GetLength(I0) << ", " + << a_grid_desc_k0_m_k1.GetLength(I1) << ", " << a_grid_desc_k0_m_k1.GetLength(I2) + << "}" << std::endl; + + std::cout << "b_grid_desc_k0_n_k1{" << b_grid_desc_k0_n_k1.GetLength(I0) << ", " + << b_grid_desc_k0_n_k1.GetLength(I1) << ", " << b_grid_desc_k0_n_k1.GetLength(I2) + << "}" << std::endl; + + std::cout << "c_grid_desc_m_n{ " << c_grid_desc_m_n.GetLength(I0) << ", " + << c_grid_desc_m_n.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n, M01, N01)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); + + const auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01); + + using Block2CTileMap = decltype(block_2_ctile_map); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + auto element_op_ = ElementwiseOperation{}; + + if(has_main_k0_block_loop) + { + const auto kernel = + kernel_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + ElementwiseOperation, + ElementwiseOperation, + ElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + element_op_, + element_op_, + element_op_, + block_2_ctile_map); + } + else + { + const auto kernel = + kernel_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + ElementwiseOperation, + ElementwiseOperation, + ElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + element_op_, + element_op_, + element_op_, + block_2_ctile_map); + } + return ave_time; +} +#endif diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp b/library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r4.hpp similarity index 75% rename from host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp rename to library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r4.hpp index 30ecb02de13..6e9983b0b50 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp +++ b/library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r4.hpp @@ -10,7 +10,7 @@ template , - remove_reference_t, - remove_reference_t, - remove_reference_t, - true>; - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else - { - const auto kernel = kernel_gemm_xdlops_v2r4, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false>; - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } -#endif return ave_time; } #endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp new file mode 100644 index 00000000000..f4944a28d2e --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp @@ -0,0 +1,135 @@ +#ifndef REFERENCE_BATCHED_GEMM_HPP +#define REFERENCE_BATCHED_GEMM_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceBatchedGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_g_m_k, + const Tensor& b_g_k_n, + Tensor& c_g_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_g_m_k_{a_g_m_k}, + b_g_k_n_{b_g_k_n}, + c_g_m_n_{c_g_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_g_m_k_; + const Tensor& b_g_k_n_; + Tensor& c_g_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceBatchedGemm::Argument; + + float Run(const Argument& arg) + { + auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) { + const int K = arg.a_g_m_k_.mDesc.GetLengths()[2]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + arg.a_element_op_(v_a, static_cast(arg.a_g_m_k_(g, m, k))); + arg.b_element_op_(v_b, static_cast(arg.b_g_k_n_(g, k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_g_m_n_(g, m, n) = v_c; + }; + + make_ParallelTensorFunctor(f_gmk_gkn_gmn, + arg.c_g_m_n_.mDesc.GetLengths()[0], + arg.c_g_m_n_.mDesc.GetLengths()[1], + arg.c_g_m_n_.mDesc.GetLengths()[2])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_g_m_k, + const Tensor& b_g_k_n, + Tensor& c_g_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceBatchedGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp new file mode 100644 index 00000000000..4203085dbc6 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp @@ -0,0 +1,306 @@ +#pragma once + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] +template = 1 && NumDimSpatial <= 3, bool>::type = false> +struct ReferenceConvBwdWeight : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : input_{in_n_c_hi_wi}, + weight_{wei_k_c_y_x}, + output_{out_n_k_ho_wo}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& input_; + Tensor& weight_; + const Tensor& output_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvBwdWeight::Argument; + + float Run(const Argument& arg) + { + if constexpr(NumDimSpatial == 1) + { + constexpr auto I0 = Number<0>{}; + auto f_kcx = [&](auto k, auto c, auto x) { + float v_acc = 0; + for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n) + { + for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[I0]) + + ck::type_convert(x * arg.conv_dilations_[I0]) - + ck::type_convert(arg.in_left_pads_[I0]); + if(wi >= 0 && + ck::type_convert(wi) < arg.input_.mDesc.GetLengths()[2]) + { + float v_out; + float v_in; + + arg.out_element_op_(v_out, + ck::type_convert(arg.output_(n, k, wo))); + arg.in_element_op_(v_in, + ck::type_convert(arg.input_(n, c, wi))); + + v_acc += v_out * v_in; + } + } + } + float v_wei; + + arg.wei_element_op_(v_wei, v_acc); + + arg.weight_(k, c, x) = ck::type_convert(v_wei); + }; + + make_ParallelTensorFunctor(f_kcx, + arg.weight_.mDesc.GetLengths()[0], + arg.weight_.mDesc.GetLengths()[1], + arg.weight_.mDesc.GetLengths()[2])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 2) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + auto f_kcyx = [&](auto k, auto c, auto y, auto x) { + float v_acc = 0; + for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n) + { + for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho) + { + auto hi = + ck::type_convert(ho * arg.conv_strides_[I0]) + + ck::type_convert(y * arg.conv_dilations_[I0]) - + ck::type_convert(arg.in_left_pads_[I0]); + for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[I1]) + + ck::type_convert(x * + arg.conv_dilations_[I1]) - + ck::type_convert(arg.in_left_pads_[I1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.input_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + arg.input_.mDesc.GetLengths()[3]) + { + float v_out; + float v_in; + + arg.out_element_op_( + v_out, ck::type_convert(arg.output_(n, k, ho, wo))); + arg.in_element_op_( + v_in, ck::type_convert(arg.input_(n, c, hi, wi))); + + v_acc += v_out * v_in; + } + } + } + } + float v_wei; + + arg.wei_element_op_(v_wei, v_acc); + + arg.weight_(k, c, y, x) = ck::type_convert(v_wei); + }; + + make_ParallelTensorFunctor(f_kcyx, + arg.weight_.mDesc.GetLengths()[0], + arg.weight_.mDesc.GetLengths()[1], + arg.weight_.mDesc.GetLengths()[2], + arg.weight_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 3) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) { + float v_acc = 0; + for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n) + { + for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_) + { + auto di = + ck::type_convert(do_ * arg.conv_strides_[I0]) + + ck::type_convert(z * arg.conv_dilations_[I0]) - + ck::type_convert(arg.in_left_pads_[I0]); + for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho) + { + auto hi = + ck::type_convert(ho * arg.conv_strides_[I1]) + + ck::type_convert(y * + arg.conv_dilations_[I1]) - + ck::type_convert(arg.in_left_pads_[I1]); + for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4]; + ++wo) + { + auto wi = + ck::type_convert(wo * + arg.conv_strides_[I2]) + + ck::type_convert( + x * arg.conv_dilations_[I2]) - + ck::type_convert(arg.in_left_pads_[I2]); + if(di >= 0 && + ck::type_convert(di) < + arg.input_.mDesc.GetLengths()[2] && + hi >= 0 && + ck::type_convert(hi) < + arg.input_.mDesc.GetLengths()[3] && + wi >= 0 && + ck::type_convert(wi) < + arg.input_.mDesc.GetLengths()[4]) + { + float v_out; + float v_in; + + arg.out_element_op_(v_out, + ck::type_convert( + arg.output_(n, k, do_, ho, wo))); + arg.in_element_op_( + v_in, + ck::type_convert(arg.input_(n, c, di, hi, wi))); + + v_acc += v_out * v_in; + } + } + } + } + } + float v_wei; + + arg.wei_element_op_(v_wei, v_acc); + + arg.weight_(k, c, z, y, x) = ck::type_convert(v_wei); + }; + + make_ParallelTensorFunctor(f_kczyx, + arg.weight_.mDesc.GetLengths()[0], + arg.weight_.mDesc.GetLengths()[1], + arg.weight_.mDesc.GetLengths()[2], + arg.weight_.mDesc.GetLengths()[3], + arg.weight_.mDesc.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvBwdWeight" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp new file mode 100644 index 00000000000..45fc8b85034 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -0,0 +1,355 @@ +#ifndef REFERENCE_CONV_BWD_DATA_HPP +#define REFERENCE_CONV_BWD_DATA_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] +template = 1 && NumDimSpatial <= 3, bool>::type = false> +struct ReferenceConvBwdData : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(Tensor& input, + const Tensor& weight, + const Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : input_{input}, + weight_{weight}, + output_{output}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + Tensor& input_; + const Tensor& weight_; + const Tensor& output_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvBwdData::Argument; + + float Run(const Argument& arg) + { + if constexpr(NumDimSpatial == 1) + { + auto f_ncw = [&](auto n, auto c, auto wi) { + std::size_t K = arg.weight_.mDesc.GetLengths()[0]; + std::size_t X = arg.weight_.mDesc.GetLengths()[2]; + std::size_t Wo = arg.output_.mDesc.GetLengths()[2]; + + AccDataType v_acc = 0; + + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = ck::type_convert(wi) + + ck::type_convert(arg.in_left_pads_[0]) - + ck::type_convert(x * arg.conv_dilations_[0]); + if(w_tmp % arg.conv_strides_[0] == 0) + { + auto wo = ck::type_convert(w_tmp) / + ck::type_convert(arg.conv_strides_[0]); + if(wo >= 0 && ck::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + AccDataType v_out = 0; + AccDataType v_wei = 0; + + arg.out_element_op_( + v_out, + ck::type_convert(arg.output_(n, k, wo))); + arg.wei_element_op_( + v_wei, ck::type_convert(arg.weight_(k, c, x))); + + v_acc += v_out * v_wei; + } + } + } + } + + float v_in; + arg.in_element_op_(v_in, v_acc); + arg.input_(n, c, wi) = ck::type_convert(v_in); + }; + + make_ParallelTensorFunctor(f_ncw, + arg.input_.mDesc.GetLengths()[0], + arg.input_.mDesc.GetLengths()[1], + arg.input_.mDesc.GetLengths()[2])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 2) + { + auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { + std::size_t K = arg.weight_.mDesc.GetLengths()[0]; + std::size_t Y = arg.weight_.mDesc.GetLengths()[2]; + std::size_t X = arg.weight_.mDesc.GetLengths()[3]; + + std::size_t Ho = arg.output_.mDesc.GetLengths()[2]; + std::size_t Wo = arg.output_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(std::size_t y = 0; y < Y; ++y) + { + auto h_tmp = ck::type_convert(hi) + + ck::type_convert(arg.in_left_pads_[0]) - + ck::type_convert(y * arg.conv_dilations_[0]); + if(h_tmp % arg.conv_strides_[0] == 0) + { + auto ho = ck::type_convert(h_tmp) / + ck::type_convert(arg.conv_strides_[0]); + if(ho >= 0 && ck::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = + ck::type_convert(wi) + + ck::type_convert(arg.in_left_pads_[1]) - + ck::type_convert(x * + arg.conv_dilations_[1]); + if(w_tmp % arg.conv_strides_[1] == 0) + { + auto wo = ck::type_convert(w_tmp) / + ck::type_convert( + arg.conv_strides_[1]); + if(wo >= 0 && ck::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + AccDataType v_out = 0; + AccDataType v_wei = 0; + + arg.out_element_op_(v_out, + ck::type_convert( + arg.output_(n, k, ho, wo))); + arg.wei_element_op_(v_wei, + ck::type_convert( + arg.weight_(k, c, y, x))); + + v_acc += v_out * v_wei; + } + } + } + } + } + } + } + + AccDataType v_in; + arg.in_element_op_(v_in, v_acc); + arg.input_(n, c, hi, wi) = ck::type_convert(v_in); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.input_.mDesc.GetLengths()[0], + arg.input_.mDesc.GetLengths()[1], + arg.input_.mDesc.GetLengths()[2], + arg.input_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 3) + { + auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) { + std::size_t K = arg.weight_.mDesc.GetLengths()[0]; + std::size_t Z = arg.weight_.mDesc.GetLengths()[2]; + std::size_t Y = arg.weight_.mDesc.GetLengths()[3]; + std::size_t X = arg.weight_.mDesc.GetLengths()[4]; + + std::size_t Do = arg.output_.mDesc.GetLengths()[2]; + std::size_t Ho = arg.output_.mDesc.GetLengths()[3]; + std::size_t Wo = arg.output_.mDesc.GetLengths()[4]; + + AccDataType v_acc = 0; + + for(std::size_t z = 0; z < Z; ++z) + { + auto d_tmp = ck::type_convert(di) + + ck::type_convert(arg.in_left_pads_[0]) - + ck::type_convert(z * arg.conv_dilations_[0]); + if(d_tmp % arg.conv_strides_[0] == 0) + { + auto do_ = ck::type_convert(d_tmp) / + ck::type_convert(arg.conv_strides_[0]); + if(do_ >= 0 && ck::type_convert(do_) < Do) + { + for(std::size_t y = 0; y < Y; ++y) + { + auto h_tmp = + ck::type_convert(hi) + + ck::type_convert(arg.in_left_pads_[1]) - + ck::type_convert(y * + arg.conv_dilations_[1]); + if(h_tmp % arg.conv_strides_[1] == 0) + { + auto ho = ck::type_convert(h_tmp) / + ck::type_convert( + arg.conv_strides_[1]); + if(ho >= 0 && ck::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = + ck::type_convert(wi) + + ck::type_convert( + arg.in_left_pads_[2]) - + ck::type_convert( + x * arg.conv_dilations_[2]); + if(w_tmp % arg.conv_strides_[2] == 0) + { + auto wo = + ck::type_convert(w_tmp) / + ck::type_convert( + arg.conv_strides_[2]); + if(wo >= 0 && + ck::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + AccDataType v_out = 0; + AccDataType v_wei = 0; + + arg.out_element_op_( + v_out, + ck::type_convert( + arg.output_( + n, k, do_, ho, wo))); + arg.wei_element_op_( + v_wei, + ck::type_convert( + arg.weight_(k, c, z, y, x))); + + v_acc += v_out * v_wei; + } + } + } + } + } + } + } + } + } + } + + AccDataType v_in; + arg.in_element_op_(v_in, v_acc); + arg.input_(n, c, di, hi, wi) = ck::type_convert(v_in); + }; + + make_ParallelTensorFunctor(f_ncdhw, + arg.input_.mDesc.GetLengths()[0], + arg.input_.mDesc.GetLengths()[1], + arg.input_.mDesc.GetLengths()[2], + arg.input_.mDesc.GetLengths()[3], + arg.input_.mDesc.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(Tensor& input, + const Tensor& weight, + const Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{input, + weight, + output, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvBwdData" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp new file mode 100644 index 00000000000..d1afa898e40 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -0,0 +1,315 @@ +#pragma once + +#include +#include +#include + +#include "stream_config.hpp" +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// +// @brief Reference implementation for forward convolution. +// +// @paragraph Supports both NCHW as well as NHWC formats (and their respective +// counterparts for weight and output) as long as tensor descriptor +// lengths is in NCHW. +// +// @tparam InDataType Input tensor data type. +// @tparam WeiDataType Weights tensor data type. +// @tparam OutDataType Output tensor data type. +// @tparam InElementwiseOperation Functor for input tensor elementwise +// operation. +// @tparam WeiElementwiseOperation Functor for weights tensor elementwise +// operation. +// @tparam NumDimSpatial Number of spatial dimensions. +// +template = 1 && NumDimSpatial <= 3, bool>::type = false> +struct ReferenceConvFwd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& input, + const Tensor& weight, + Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : input_{input}, + weight_{weight}, + output_{output}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& input_; + const Tensor& weight_; + Tensor& output_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvFwd::Argument; + + float Run(const Argument& arg) + { + if constexpr(NumDimSpatial == 1) + { + auto f_ncw = [&](auto n, auto k, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) + { + for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[0]) + + ck::type_convert(x * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + if(wi >= 0 && + ck::type_convert(wi) < arg.input_.mDesc.GetLengths()[2]) + { + float v_in; + float v_wei; + + arg.in_element_op_(v_in, + ck::type_convert(arg.input_(n, c, wi))); + arg.wei_element_op_(v_wei, + ck::type_convert(arg.weight_(k, c, x))); + + v_acc += v_in * v_wei; + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc); + arg.output_(n, k, wo) = ck::type_convert(v_out); + }; + + make_ParallelTensorFunctor(f_ncw, + arg.output_.mDesc.GetLengths()[0], + arg.output_.mDesc.GetLengths()[1], + arg.output_.mDesc.GetLengths()[2])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 2) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) + { + for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y) + { + auto hi = + ck::type_convert(ho * arg.conv_strides_[0]) + + ck::type_convert(y * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[1]) + + ck::type_convert(x * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.input_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + arg.input_.mDesc.GetLengths()[3]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, ck::type_convert(arg.input_(n, c, hi, wi))); + arg.wei_element_op_( + v_wei, ck::type_convert(arg.weight_(k, c, y, x))); + v_acc += v_in * v_wei; + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc); + arg.output_(n, k, ho, wo) = ck::type_convert(v_out); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.output_.mDesc.GetLengths()[0], + arg.output_.mDesc.GetLengths()[1], + arg.output_.mDesc.GetLengths()[2], + arg.output_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 3) + { + auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) + { + for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z) + { + auto di = + ck::type_convert(d_o * arg.conv_strides_[0]) + + ck::type_convert(z * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y) + { + auto hi = + ck::type_convert(ho * arg.conv_strides_[1]) + + ck::type_convert(y * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x) + { + auto wi = + ck::type_convert(wo * + arg.conv_strides_[2]) + + ck::type_convert(x * + arg.conv_dilations_[2]) - + ck::type_convert(arg.in_left_pads_[2]); + if(di >= 0 && + ck::type_convert(di) < + arg.input_.mDesc.GetLengths()[2] && + hi >= 0 && + ck::type_convert(hi) < + arg.input_.mDesc.GetLengths()[3] && + wi >= 0 && + ck::type_convert(wi) < + arg.input_.mDesc.GetLengths()[4]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, + ck::type_convert(arg.input_(n, c, di, hi, wi))); + arg.wei_element_op_( + v_wei, + ck::type_convert(arg.weight_(k, c, z, y, x))); + v_acc += v_in * v_wei; + } + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc); + arg.output_(n, k, d_o, ho, wo) = ck::type_convert(v_out); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.output_.mDesc.GetLengths()[0], + arg.output_.mDesc.GetLengths()[1], + arg.output_.mDesc.GetLengths()[2], + arg.output_.mDesc.GetLengths()[3], + arg.output_.mDesc.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& input, + const Tensor& weight, + Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{input, + weight, + output, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp new file mode 100644 index 00000000000..4be6169c150 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp @@ -0,0 +1,190 @@ +#ifndef REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP +#define REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) +template +struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : in_n_c_hi_wi_{in_n_c_hi_wi}, + wei_k_c_y_x_{wei_k_c_y_x}, + out_n_k_ho_wo_{out_n_k_ho_wo}, + bias_k_{bias_k}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& in_n_c_hi_wi_; + const Tensor& wei_k_c_y_x_; + Tensor& out_n_k_ho_wo_; + const Tensor& bias_k_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvFwd_Bias_Activation::Argument; + + float Run(const Argument& arg) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) + { + for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + { + auto hi = ck::type_convert(ho * arg.conv_strides_[0]) + + ck::type_convert(y * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[1]) + + ck::type_convert(x * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, + static_cast(arg.in_n_c_hi_wi_(n, c, hi, wi))); + arg.wei_element_op_( + v_wei, static_cast(arg.wei_k_c_y_x_(k, c, y, x))); + + v_acc += v_in * v_wei; + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc, static_cast(arg.bias_k_(k))); + + arg.out_n_k_ho_wo_(n, k, ho, wo) = v_out; + }; + + make_ParallelTensorFunctor(f_nchw, + arg.out_n_k_ho_wo_.mDesc.GetLengths()[0], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[1], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[2], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + bias_k, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd_Bias_Activation" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp new file mode 100644 index 00000000000..466537c686a --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp @@ -0,0 +1,198 @@ +#ifndef REFERENCE_CONV2D_FWD_BIAS_ACTIVATION_ADD_HPP +#define REFERENCE_CONV2D_FWD_BIAS_ACTIVATION_ADD_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K] +template +struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + const Tensor& resi_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : in_n_c_hi_wi_{in_n_c_hi_wi}, + wei_k_c_y_x_{wei_k_c_y_x}, + out_n_k_ho_wo_{out_n_k_ho_wo}, + bias_k_{bias_k}, + resi_n_k_ho_wo_{resi_n_k_ho_wo}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& in_n_c_hi_wi_; + const Tensor& wei_k_c_y_x_; + Tensor& out_n_k_ho_wo_; + const Tensor& bias_k_; + const Tensor& resi_n_k_ho_wo_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvFwd_Bias_Activation_Add::Argument; + + float Run(const Argument& arg) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) + { + for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + { + auto hi = ck::type_convert(ho * arg.conv_strides_[0]) + + ck::type_convert(y * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[1]) + + ck::type_convert(x * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, + static_cast(arg.in_n_c_hi_wi_(n, c, hi, wi))); + arg.wei_element_op_( + v_wei, static_cast(arg.wei_k_c_y_x_(k, c, y, x))); + + v_acc += v_in * v_wei; + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, + v_acc, + static_cast(arg.bias_k_(k)), + static_cast(arg.resi_n_k_ho_wo_(n, k, ho, wo))); + + arg.out_n_k_ho_wo_(n, k, ho, wo) = v_out; + }; + + make_ParallelTensorFunctor(f_nchw, + arg.out_n_k_ho_wo_.mDesc.GetLengths()[0], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[1], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[2], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + const Tensor& resi_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + bias_k, + resi_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd_Bias_Activation_Add" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp new file mode 100644 index 00000000000..d89c8f5e050 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -0,0 +1,130 @@ +#pragma once +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemm::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); + arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp new file mode 100644 index 00000000000..3e7f220e03d --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp @@ -0,0 +1,134 @@ +#ifndef REFERENCE_GEMM_BIAS_BIAS_2D_HPP +#define REFERENCE_GEMM_BIAS_BIAS_2D_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmBias2D : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& c0_m_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c0_m_n_{c0_m_n}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + const Tensor& c0_m_n_; + Tensor& c_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmBias2D::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + AccDataType a = 0; + AccDataType b = 0; + AccDataType acc = 0; + + for(int k = 0; k < K; ++k) + { + arg.a_element_op_(a, arg.a_m_k_(m, k)); + arg.b_element_op_(b, arg.b_k_n_(k, n)); + acc += a * b; + } + + CDataType cast_acc = static_cast(acc); + arg.c_element_op_(arg.c_m_n_(m, n), cast_acc, arg.c0_m_n_(m, n)); + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& c0_m_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, c0_m_n, c_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmBias2D" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp new file mode 100644 index 00000000000..60f72e9e510 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp @@ -0,0 +1,137 @@ +#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_HPP +#define REFERENCE_GEMM_BIAS_ACTIVATION_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmBiasActivation : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + c0_n_{c0_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + const Tensor& c0_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmBiasActivation::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); + arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + arg.c_element_op_(v_c, v_acc, static_cast(arg.c0_n_(n))); + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, c_m_n, c0_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmBiasActivation" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp new file mode 100644 index 00000000000..5e0ec75e5e8 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp @@ -0,0 +1,145 @@ +#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP +#define REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmBiasActivationAdd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + const Tensor& c1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + c0_n_{c0_n}, + c1_m_n_{c1_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + const Tensor& c0_n_; + const Tensor& c1_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmBiasActivationAdd::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); + arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + arg.c_element_op_(v_c, + v_acc, + static_cast(arg.c0_n_(n)), + static_cast(arg.c1_m_n_(m, n))); + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + const Tensor& c1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{ + a_m_k, b_k_n, c_m_n, c0_n, c1_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmBiasActivationAdd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp new file mode 100644 index 00000000000..120938f0722 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp @@ -0,0 +1,122 @@ +#ifndef NAIVE_CONV_FWD_HPP +#define NAIVE_CONV_FWD_HPP + +namespace ck { +namespace ref { + +/* + * \brief naive implementation of 3D convolution. Layout is (NDHWC, KZYXC, NDHWK). + * + * \param N number of batches + * \param K number of filters + * \param C number of channels of weight + * \param (Di, Hi, Wi) depth, height and width dimension of data + * \param (Z, Y, X) depth, height and width dimensions of weights + * \param (Do, Ho, Wo) depth, height and width dimension of output + * \param (stride_z, stride_y, stride_x) strides + * \param (dilation_z, dilation_y, dilation_x) dilations + * \param (pad_z, pad_y, pad_x) pads + */ +template +__global__ void naive_conv_fwd_ndhwc_kzyxc_ndhwk(const TIn* __restrict__ p_in, + const TWei* __restrict__ p_wei, + TOut* __restrict__ p_out, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x) +{ + const index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const index_t num_threads = blockDim.x * gridDim.x; + const long_index_t output_length = N * Do * Ho * Wo * K; + + const index_t out_strides[] = {Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K}; + const index_t in_strides[] = {Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C}; + const index_t wei_strides[] = {Z * Y * X * C, Y * X * C, X * C, C}; + + constexpr auto in_op = InElementwiseOperation{}; + constexpr auto wei_op = WeiElementwiseOperation{}; + constexpr auto out_op = OutElementwiseOperation{}; + + TIn in_val; + TWei wei_val; + TOut out_val; + + for(long_index_t ii = tid; ii < output_length; ii += num_threads) + { + const index_t n = ii / out_strides[0]; + index_t k = ii - n * out_strides[0]; + const index_t dO = k / out_strides[1]; + k -= dO * out_strides[1]; + const index_t ho = k / out_strides[2]; + k -= ho * out_strides[2]; + const index_t wo = k / out_strides[3]; + k -= wo * out_strides[3]; + + TAcc acc = static_cast(0); + + const TIn* in_n = p_in + static_cast(n) * in_strides[0]; + const TWei* wei_k = p_wei + static_cast(k) * wei_strides[0]; + + for(index_t z = 0; z < Z; ++z) + { + index_t di = stride_z * dO - pad_z + dilation_z * z; + const TIn* in_n_di = in_n + di * in_strides[1]; + const TWei* wei_k_z = wei_k + z * wei_strides[1]; + + for(index_t y = 0; y < Y; ++y) + { + index_t hi = stride_y * ho - pad_y + dilation_y * y; + const TIn* in_n_di_hi = in_n_di + hi * in_strides[2]; + const TWei* wei_k_z_y = wei_k_z + y * wei_strides[2]; + + for(index_t x = 0; x < X; ++x) + { + index_t wi = stride_x * wo - pad_x + dilation_x * x; + const TIn* in_n_di_hi_wi = in_n_di_hi + wi * in_strides[3]; + const TWei* wei_k_z_y_x = wei_k_z_y + x * wei_strides[3]; + + if(di >= 0 && di < Di && hi >= 0 && hi < Hi && wi >= 0 && wi < Wi) + { + for(index_t c = 0; c < C; ++c) + { + in_op(in_val, in_n_di_hi_wi[c]); + wei_op(wei_val, wei_k_z_y_x[c]); + acc += in_val * wei_val; + } + } + } + } + } + + out_op(out_val, static_cast(acc)); + p_out[ii] = out_val; + } +} +} // namespace ref +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp new file mode 100644 index 00000000000..40fd7274ef9 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp @@ -0,0 +1,26 @@ +#ifndef CK_DEVICE_OPERATION_INSTANCE_HPP +#define CK_DEVICE_OPERATION_INSTANCE_HPP + +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +template +void add_device_operation_instances(std::vector>& op_instances, + const NewOpInstances& new_op_instances) +{ + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + const auto new_op_instance = std::get(new_op_instances); + + using NewOpInstance = remove_cvref_t; + + op_instances.push_back(std::make_unique(new_op_instance)); + }); +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp new file mode 100644 index 00000000000..6f0dbe75fff --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp @@ -0,0 +1,26 @@ +#ifndef DEVICE_REDUCE_INSTANTCE_HPP +#define DEVICE_REDUCE_INSTANTCE_HPP + +#include "device_reduce_instance_blockwise_f16_f16_f16.hpp" +#include "device_reduce_instance_blockwise_f16_f32_f16.hpp" +#include "device_reduce_instance_blockwise_f32_f32_f32.hpp" +#include "device_reduce_instance_blockwise_f32_f64_f32.hpp" +#include "device_reduce_instance_blockwise_f64_f64_f64.hpp" +#include "device_reduce_instance_blockwise_i8_i8_i8.hpp" +#include "device_reduce_instance_blockwise_i8_i32_i8.hpp" +#include "device_reduce_instance_blockwise_b16_f32_b16.hpp" +#include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp" +#include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp" +#include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp" +#include "device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp" +#include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp" +#include "device_reduce_instance_threadwise_f16_f16_f16.hpp" +#include "device_reduce_instance_threadwise_f16_f32_f16.hpp" +#include "device_reduce_instance_threadwise_f32_f32_f32.hpp" +#include "device_reduce_instance_threadwise_f32_f64_f32.hpp" +#include "device_reduce_instance_threadwise_f64_f64_f64.hpp" +#include "device_reduce_instance_threadwise_i8_i8_i8.hpp" +#include "device_reduce_instance_threadwise_i8_i32_i8.hpp" +#include "device_reduce_instance_threadwise_b16_f32_b16.hpp" + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp new file mode 100644 index 00000000000..e31d4e769ed --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp @@ -0,0 +1,186 @@ +#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_HPP +#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_HPP + +#include "reduction_operator_mapping.hpp" +#include "device_reduce_instance_impl_common.hpp" +#include "device_reduce_multiblock.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +using reduce_configuration_1_instances_blockwise = std::tuple< + // clang-format off + // BlockSize | MThreadClusterSize | KThreadClusterSize + ReductionConfiguration_1<256, 128, 2>, + ReductionConfiguration_1<256, 64, 4>, + ReductionConfiguration_1<256, 32, 8>, + ReductionConfiguration_1<256, 16, 16>, + ReductionConfiguration_1<256, 8, 32>, + ReductionConfiguration_1<256, 4, 64>, + ReductionConfiguration_1<256, 2, 128>, + ReductionConfiguration_1<256, 1, 256> + // clang-format on + >; + +#ifdef QUICK_REDUCE_TEST +using reduce_configuration_2_instances_blockwise = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 2, 2, 2, 1>, + ReductionConfiguration_2<0, 1, 1, 2, 1>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<1, 1, 1, 1, 3> + // clang-format on + >; +#else +using reduce_configuration_2_instances_blockwise = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 4, 4, 8, 1>, + ReductionConfiguration_2<0, 4, 4, 4, 1>, + ReductionConfiguration_2<0, 2, 2, 2, 1>, + + ReductionConfiguration_2<1, 4, 1, 1, 8>, + ReductionConfiguration_2<1, 4, 1, 1, 4>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + + // special instances + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<0, 1, 1, 5, 1>, + ReductionConfiguration_2<0, 1, 1, 7, 1>, + ReductionConfiguration_2<0, 1, 1, 11, 1>, + + ReductionConfiguration_2<1, 1, 1, 1, 3>, + ReductionConfiguration_2<1, 1, 1, 1, 5>, + ReductionConfiguration_2<1, 1, 1, 1, 7>, + ReductionConfiguration_2<1, 1, 1, 1, 11> + // clang-format on + >; +#endif + +template +using deviceReduceBlockWisePtrType = DeviceReducePtr< + typename reduce_unary_operator::InElementwiseOperation, + typename reduce_unary_operator::AccElementwiseOperation>; + +template +void add_device_reduce_instance_blockwise( + std::vector>& device_op_instances) +{ + using ReduceOperation = typename reduce_binary_operator::opType; + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + using AccElementwiseOperation = + typename reduce_unary_operator:: + AccElementwiseOperation; + + constexpr bool Indexable = + (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || + ReduceOpId == ReduceTensorOp::AMAX); + constexpr bool OutputIndex = Indexable && UseIndex; + + static_for<0, std::tuple_size::value, 1>{}( + [&](auto i) { + using cfg1 = remove_cvref_t(reduce_configuration_1_instances_blockwise{}))>; + + static_for<0, std::tuple_size::value, 1>{}( + [&](auto j) { + using cfg2 = remove_cvref_t(reduce_configuration_2_instances_blockwise{}))>; + + using ReduceOpInstance = + DeviceReduceMultiBlock; + + device_op_instances.push_back( + std::make_unique(ReduceOpInstance{})); + }); + }); +}; + +#define ADD_BLOCKWISE_INST_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + template void add_device_reduce_instance_blockwise( \ + std::vector> & device_op_instances) + +#define ADD_BLOCKWISE_INST_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_BLOCKWISE_INST_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ + NumReduceDim) + +#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + extern template void add_device_reduce_instance_blockwise( \ + std::vector::InElementwiseOperation, \ + typename reduce_unary_operator:: \ + AccElementwiseOperation>> & \ + device_op_instances) + +#define ADD_BLOCKWISE_INST_REF_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ + NumReduceDim) + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp new file mode 100644 index 00000000000..3cad45f2e5d --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp @@ -0,0 +1,59 @@ +#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP +#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP + +#include "data_type.hpp" +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1); + +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp new file mode 100644 index 00000000000..441c1aec3ff --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp @@ -0,0 +1,46 @@ +#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP +#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP + +#include "data_type.hpp" +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp new file mode 100644 index 00000000000..ca8532a458c --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp @@ -0,0 +1,34 @@ +#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP +#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP + +#include "data_type.hpp" +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp new file mode 100644 index 00000000000..64f504c9da5 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp @@ -0,0 +1,57 @@ +#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP +#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP + +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp new file mode 100644 index 00000000000..9e84ee34fb3 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp @@ -0,0 +1,33 @@ +#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP +#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP + +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp new file mode 100644 index 00000000000..a37e3bdeb91 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp @@ -0,0 +1,57 @@ +#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP +#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP + +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp new file mode 100644 index 00000000000..1d8695bbb0f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp @@ -0,0 +1,29 @@ +#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP +#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP + +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp new file mode 100644 index 00000000000..b5c19b72072 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp @@ -0,0 +1,45 @@ +#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP +#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP + +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp new file mode 100644 index 00000000000..721d98a7189 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp @@ -0,0 +1,41 @@ +#ifndef DEVICE_REDUCE_INSTANCE_IMPL_COMMON_HPP +#define DEVICE_REDUCE_INSTANCE_IMPL_COMMON_HPP + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +template +struct ReductionConfiguration_1 +{ + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, "Invalid Configuration!"); + + static constexpr int BlockSize_ = BlockSize; + static constexpr int MThreadClusterSize_ = MThreadClusterSize; + static constexpr int KThreadClusterSize_ = KThreadClusterSize; +}; + +template +struct ReductionConfiguration_2 +{ + static constexpr int InSrcVectorDim_ = InSrcVectorDim; + static constexpr int InSrcVectorSize_ = InSrcVectorSize; + static constexpr int OutDstVectorSize_ = OutDstVectorSize; + static constexpr int MThreadSliceSize_ = MThreadSliceSize; + static constexpr int KThreadSliceSize_ = KThreadSliceSize; +}; + +#define QUICK_REDUCE_TEST 1 + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp new file mode 100644 index 00000000000..605109d0779 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp @@ -0,0 +1,208 @@ +#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_HPP +#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_HPP + +#include "reduction_operator_mapping.hpp" +#include "device_reduce_instance_impl_common.hpp" +#include "device_reduce_multiblock.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +using reduce_configuration_1_instances_multiblock_atomic_add = std::tuple< + // clang-format off + // BlockSize | MThreadClusterSize | KThreadClusterSize + ReductionConfiguration_1<256, 128, 2>, + ReductionConfiguration_1<256, 64, 4>, + ReductionConfiguration_1<256, 32, 8>, + ReductionConfiguration_1<256, 16, 16>, + ReductionConfiguration_1<256, 8, 32>, + ReductionConfiguration_1<256, 4, 64>, + ReductionConfiguration_1<256, 2, 128>, + ReductionConfiguration_1<256, 1, 256> + // clang-format on + >; + +#ifdef QUICK_REDUCE_TEST +using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 2, 2, 2, 1>, + ReductionConfiguration_2<0, 1, 1, 2, 1>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<1, 1, 1, 1, 3> + // clang-format on + >; +#else +using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 4, 4, 8, 1>, + ReductionConfiguration_2<0, 4, 4, 4, 1>, + ReductionConfiguration_2<0, 2, 2, 2, 1>, + + ReductionConfiguration_2<1, 4, 1, 1, 8>, + ReductionConfiguration_2<1, 4, 1, 1, 4>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + + // special instances + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<0, 1, 1, 5, 1>, + ReductionConfiguration_2<0, 1, 1, 7, 1>, + ReductionConfiguration_2<0, 1, 1, 11, 1>, + + ReductionConfiguration_2<1, 1, 1, 1, 3>, + ReductionConfiguration_2<1, 1, 1, 1, 5>, + ReductionConfiguration_2<1, 1, 1, 1, 7>, + ReductionConfiguration_2<1, 1, 1, 1, 11> + // clang-format on + >; +#endif + +template +using deviceReduceMultiBlockAtomicAddPtrType = + DeviceReducePtr:: + InElementwiseOperation, + typename reduce_unary_operator:: + AccElementwiseOperation>; + +template +void add_device_reduce_instance_multiblock_atomic_add( + std::vector>& + device_op_instances) +{ + using ReduceOperation = typename reduce_binary_operator::opType; + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + using AccElementwiseOperation = + typename reduce_unary_operator:: + AccElementwiseOperation; + + constexpr bool Indexable = + (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || + ReduceOpId == ReduceTensorOp::AMAX); + constexpr bool OutputIndex = Indexable && UseIndex; + + static_assert(UseIndex == false, + "AtomicAdd can only be used with reduction operations using no index!"); + + constexpr bool op_acceptable = + (ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::MUL || + ReduceOpId == ReduceTensorOp::AVG || ReduceOpId == ReduceTensorOp::NORM1); + + constexpr bool out_type_acceptable = + (std::is_same::value || std::is_same::value); + + if constexpr(!op_acceptable || !out_type_acceptable) + return; + else + { + static_for<0, + std::tuple_size::value, + 1>{}([&](auto i) { + using cfg1 = remove_cvref_t(reduce_configuration_1_instances_multiblock_atomic_add{}))>; + + static_for< + 0, + std::tuple_size::value, + 1>{}([&](auto j) { + using cfg2 = remove_cvref_t(reduce_configuration_2_instances_multiblock_atomic_add{}))>; + + using ReduceOpInstance = + DeviceReduceMultiBlock; + + device_op_instances.push_back( + std::make_unique(ReduceOpInstance{})); + }); + }); + } +}; + +#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + template void add_device_reduce_instance_multiblock_atomic_add( \ + std::vector> & \ + device_op_instances) + +#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ + NumReduceDim) + +#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + extern template void add_device_reduce_instance_multiblock_atomic_add( \ + std::vector::InElementwiseOperation, \ + typename reduce_unary_operator:: \ + AccElementwiseOperation>> & \ + device_op_instances) + +#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ + NumReduceDim) + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp new file mode 100644 index 00000000000..4e39cf49f6f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp @@ -0,0 +1,30 @@ +#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_B16_F32_F32_HPP +#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_B16_F32_F32_HPP + +#include "data_type.hpp" +#include "device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp new file mode 100644 index 00000000000..73424322ae2 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp @@ -0,0 +1,30 @@ +#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F16_F32_F32_HPP +#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F16_F32_F32_HPP + +#include "data_type.hpp" +#include "device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp new file mode 100644 index 00000000000..ecc9c4ea871 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp @@ -0,0 +1,29 @@ +#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F32_F32_HPP +#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F32_F32_HPP + +#include "device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp new file mode 100644 index 00000000000..41a60d5b70e --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp @@ -0,0 +1,29 @@ +#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F64_F32_HPP +#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F64_F32_HPP + +#include "device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp new file mode 100644 index 00000000000..bdcca274d7f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp @@ -0,0 +1,29 @@ +#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F64_F64_F64_HPP +#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F64_F64_F64_HPP + +#include "device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp new file mode 100644 index 00000000000..a2b4ae22bee --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp @@ -0,0 +1,163 @@ +#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_HPP +#define DEVICE_REDUCE_INSTANCE_THREADWISE_HPP + +#include "reduction_operator_mapping.hpp" +#include "device_reduce_instance_impl_common.hpp" +#include "device_reduce_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +#ifdef QUICK_REDUCE_TEST +using reduce_configuration_2_instances_threadwise = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 2, 2, 2, 1>, + ReductionConfiguration_2<0, 1, 1, 2, 1>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<1, 1, 1, 1, 3> + // clang-format on + >; +#else +using reduce_configuration_2_instances_threadwise = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 4, 4, 8, 1>, + ReductionConfiguration_2<0, 4, 4, 4, 1>, + ReductionConfiguration_2<0, 2, 2, 2, 1>, + + ReductionConfiguration_2<1, 4, 1, 1, 8>, + ReductionConfiguration_2<1, 4, 1, 1, 4>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + + // special instances + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<0, 1, 1, 5, 1>, + ReductionConfiguration_2<0, 1, 1, 7, 1>, + ReductionConfiguration_2<0, 1, 1, 11, 1>, + + ReductionConfiguration_2<1, 1, 1, 1, 3>, + ReductionConfiguration_2<1, 1, 1, 1, 5>, + ReductionConfiguration_2<1, 1, 1, 1, 7>, + ReductionConfiguration_2<1, 1, 1, 1, 11> + // clang-format on + >; +#endif + +template +using deviceReduceThreadWisePtrType = DeviceReducePtr< + typename reduce_unary_operator::InElementwiseOperation, + typename reduce_unary_operator::AccElementwiseOperation>; + +template +void add_device_reduce_instance_threadwise( + std::vector>& device_op_instances) +{ + using ReduceOperation = typename reduce_binary_operator::opType; + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + using AccElementwiseOperation = + typename reduce_unary_operator:: + AccElementwiseOperation; + + constexpr bool Indexable = + (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || + ReduceOpId == ReduceTensorOp::AMAX); + constexpr bool OutputIndex = Indexable && UseIndex; + + using cfg1 = ReductionConfiguration_1<256, 256, 1>; + + static_for<0, std::tuple_size::value, 1>{}( + [&](auto j) { + using cfg2 = remove_cvref_t(reduce_configuration_2_instances_threadwise{}))>; + + using ReduceOpInstance = DeviceReduceThreadWise; + + device_op_instances.push_back(std::make_unique(ReduceOpInstance{})); + }); +}; + +#define ADD_THREADWISE_INST_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + template void add_device_reduce_instance_threadwise( \ + std::vector> & device_op_instances) + +#define ADD_THREADWISE_INST_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_THREADWISE_INST_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ + NumReduceDim) + +#define ADD_THREADWISE_INST_REF_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + extern template void add_device_reduce_instance_threadwise( \ + std::vector::InElementwiseOperation, \ + typename reduce_unary_operator:: \ + AccElementwiseOperation>> & \ + device_op_instances) + +#define ADD_THREADWISE_INST_REF_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_THREADWISE_INST_REF_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ + NumReduceDim) + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp new file mode 100644 index 00000000000..0291f332146 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp @@ -0,0 +1,59 @@ +#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_B16_F32_B16_HPP +#define DEVICE_REDUCE_INSTANCE_THREADWISE_B16_F32_B16_HPP + +#include "data_type.hpp" +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1); + +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp new file mode 100644 index 00000000000..7ab1bebc5f7 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp @@ -0,0 +1,46 @@ +#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F16_F16_HPP +#define DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F16_F16_HPP + +#include "data_type.hpp" +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp new file mode 100644 index 00000000000..39c3d106609 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp @@ -0,0 +1,34 @@ +#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F32_F16_HPP +#define DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F32_F16_HPP + +#include "data_type.hpp" +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp new file mode 100644 index 00000000000..3c47bfd1898 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp @@ -0,0 +1,57 @@ +#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F32_F32_HPP +#define DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F32_F32_HPP + +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp new file mode 100644 index 00000000000..9df9f6f1faf --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp @@ -0,0 +1,33 @@ +#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F64_F32_HPP +#define DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F64_F32_HPP + +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp new file mode 100644 index 00000000000..00ab218f206 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp @@ -0,0 +1,57 @@ +#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F64_F64_F64_HPP +#define DEVICE_REDUCE_INSTANCE_THREADWISE_F64_F64_F64_HPP + +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp new file mode 100644 index 00000000000..de7445b0437 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp @@ -0,0 +1,29 @@ +#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I32_I8_HPP +#define DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I32_I8_HPP + +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp new file mode 100644 index 00000000000..1ea1ee745e7 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp @@ -0,0 +1,45 @@ +#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I8_I8_HPP +#define DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I8_I8_HPP + +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp new file mode 100644 index 00000000000..7cd6cc34c9d --- /dev/null +++ b/library/include/ck/library/utility/check_err.hpp @@ -0,0 +1,195 @@ +#ifndef CHECK_ERR_HPP +#define CHECK_ERR_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "data_type.hpp" + +namespace ck { +namespace utils { + +template +typename std::enable_if::value && !std::is_same::value, + bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-5, + double atol = 3e-6) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + err = std::abs(out[i] - ref[i]); + if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i])) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << out[i] << " != " << ref[i] << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value, bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + // TODO: This is a hack. We should have proper specialization for bhalf_t data type. + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + double o = type_convert(out[i]); + double r = type_convert(ref[i]); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << o << " != " << r << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value || std::is_same::value, + bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + double o = type_convert(out[i]); + double r = type_convert(ref[i]); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << o << " != " << r << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value && !std::is_same::value, bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg = "Error: Incorrect results!", + double = 0, + double = 0) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + for(std::size_t i = 0; i < ref.size(); ++i) + { + if(out[i] != ref[i]) + { + std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast(out[i]) + << " != " << static_cast(ref[i]) << std::endl + << msg << std::endl; + return false; + } + } + return true; +} + +} // namespace utils +} // namespace ck + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); + return os; +} + +#endif diff --git a/library/include/ck/library/utility/conv_util.hpp b/library/include/ck/library/utility/conv_util.hpp new file mode 100644 index 00000000000..c881b897056 --- /dev/null +++ b/library/include/ck/library/utility/conv_util.hpp @@ -0,0 +1,571 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "device_conv_fwd.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "fill.hpp" +#include "host_tensor.hpp" +#include "op_instance_engine.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr; +namespace device_conv1d_fwd_instance { + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector&); +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector&); +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector&); +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector&); + +} // namespace device_conv1d_fwd_instance +namespace device_conv2d_fwd_instance { + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector&); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector&); +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( + std::vector&); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector&); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector&); + +} // namespace device_conv2d_fwd_instance +namespace device_conv3d_fwd_instance { + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector&); +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector&); +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector&); +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector&); + +} // namespace device_conv3d_fwd_instance + +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace utils { +namespace conv { + +using DeviceConvFwdNoOpPtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +/** + * @brief Calculate number of FLOPs for Convolution + * + * @param[in] N Batch size. + * @param[in] C Number of input channels. + * @param[in] K Number of output channels. + * @param[in] filter_spatial_lengths Filter spatial dimensions lengths. + * @param[in] output_spatial_lengths Convolution output spatial dimensions + * lengths. + * + * @return The number of flops. + */ +std::size_t get_flops(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths); + +/** + * @brief Calculate number of bytes read/write by convolution algorithm. + * + * @param[in] N Batch size. + * @param[in] C Number of input channels. + * @param[in] K Number of output channels. + * @param[in] input_spatial_lengths Input spatial dimensions lengths. + * @param[in] filter_spatial_lengths Filter spatial dimensions lengths. + * @param[in] output_spatial_lengths Output spatial dimensions lengths + * + * @tparam InDataType Input tensor data type. + * @tparam WeiDataType Weights tensor data type. + * @tparam OutDataType Output tensor data type. + * + * @return The number of used bytes. + */ +template +std::size_t get_btype(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths) +{ + // sizeof(InDataType) * (N * C * ) + + // sizeof(WeiDataType) * (K * C * ) + + // sizeof(OutDataType) * (N * K * ); + return sizeof(InDataType) * (N * C * + std::accumulate(std::begin(input_spatial_lengths), + std::end(input_spatial_lengths), + static_cast(1), + std::multiplies())) + + sizeof(WeiDataType) * (K * C * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + static_cast(1), + std::multiplies())) + + sizeof(OutDataType) * (N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies())); +} + +struct ConvParams +{ + ConvParams(); + ConvParams(ck::index_t n_dim, + ck::index_t n_batch, + ck::index_t n_out_channels, + ck::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads); + + ck::index_t num_dim_spatial_; + ck::index_t N_; + ck::index_t K_; + ck::index_t C_; + + std::vector filter_spatial_lengths_; + std::vector input_spatial_lengths_; + + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + + std::vector input_left_pads_; + std::vector input_right_pads_; + + std::vector GetOutputSpatialLengths() const; +}; + +ConvParams parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[]); + +/** + * @brief Gets the host tensor descriptor. + * + * @param[in] dims The tensor dimensions lengths. Always in NCHW format. + * @param[in] layout The tensor data layout. + * + * @tparam TensorLayout Layout type. + * + * @return The host tensor descriptor object. + */ +template +HostTensorDescriptor get_host_tensor_descriptor(const std::vector& dims, + const TensorLayout& layout) +{ + std::size_t C = dims[1]; + // 1D + if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor(dims, std::vector{C * dims[2], dims[2], 1}); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor(dims, std::vector{C * dims[2], 1, C}); + } + // 2D + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor( + dims, std::vector{C * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1}); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor( + dims, std::vector{C * dims[2] * dims[3], 1, dims[3] * C, C}); + } + // 3D + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor(dims, + std::vector{C * dims[2] * dims[3] * dims[4], + dims[2] * dims[3] * dims[4], + dims[3] * dims[4], + dims[4], + 1}); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor( + dims, + std::vector{ + C * dims[2] * dims[3] * dims[4], 1, C * dims[3] * dims[4], C * dims[4], C}); + } + + std::stringstream err_msg; + err_msg << "Unsupported data layout provided: " << layout << "!"; + throw std::runtime_error(err_msg.str()); +} + +HostTensorDescriptor get_output_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2); + +HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2); + +HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2); + +template +void run_reference_convolution_forward(const ConvParams& params, + const Tensor& input, + const Tensor& weights, + Tensor& output) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + output, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); +} + +template +struct ConvolutionFwdInstances; + +template <> +struct ConvolutionFwdInstances +{ + template = 1 && NumDimSpatial <= 3, bool>::type = false> + static std::vector Get() + { + std::vector conv_ptrs; + if constexpr(NumDimSpatial == 1) + { + ck::tensor_operation::device::device_conv1d_fwd_instance:: + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); + } + else if constexpr(NumDimSpatial == 2) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + } + else if constexpr(NumDimSpatial == 3) + { + ck::tensor_operation::device::device_conv3d_fwd_instance:: + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionFwdInstances +{ + template = 1 && NumDimSpatial <= 3, bool>::type = false> + static std::vector Get() + { + std::vector conv_ptrs; + if constexpr(NumDimSpatial == 1) + { + ck::tensor_operation::device::device_conv1d_fwd_instance:: + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); + return conv_ptrs; + } + else if constexpr(NumDimSpatial == 2) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + } + else if constexpr(NumDimSpatial == 3) + { + ck::tensor_operation::device::device_conv3d_fwd_instance:: + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionFwdInstances +{ + template = 1 && NumDimSpatial <= 3, bool>::type = false> + static std::vector Get() + { + std::vector conv_ptrs; + if constexpr(NumDimSpatial == 1) + { + ck::tensor_operation::device::device_conv1d_fwd_instance:: + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); + } + else if constexpr(NumDimSpatial == 2) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + } + else if constexpr(NumDimSpatial == 3) + { + ck::tensor_operation::device::device_conv3d_fwd_instance:: + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionFwdInstances +{ + template = 1 && NumDimSpatial <= 3, bool>::type = false> + static std::vector Get() + { + std::vector conv_ptrs; + if constexpr(NumDimSpatial == 1) + { + ck::tensor_operation::device::device_conv1d_fwd_instance:: + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); + } + else if constexpr(NumDimSpatial == 2) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + } + else if constexpr(NumDimSpatial == 3) + { + ck::tensor_operation::device::device_conv3d_fwd_instance:: + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template , + typename WeightsInitFun = FillUniform> +class ConvFwdOpInstance : public ck::utils::OpInstance +{ + using DeviceConvFwdOp = tensor_operation::device:: + DeviceConvFwd; + using DeviceMemPtr = std::unique_ptr; + using DeviceBuffers = std::vector; + using BaseType = ck::utils::OpInstance; + template + using TensorPtr = std::unique_ptr>; + using InTensorsTuple = std::tuple, TensorPtr>; + + public: + ConvFwdOpInstance() = delete; + ConvFwdOpInstance(const ConvFwdOpInstance&) = default; + ConvFwdOpInstance& operator=(const ConvFwdOpInstance&) = default; + + ConvFwdOpInstance(const ConvParams& params, + bool do_init = true, + const InputInitFun& input_init_f = InputInitFun{}, + const WeightsInitFun& weights_init_f = WeightsInitFun{}) + : BaseType(), + params_{params}, + output_spatial_lengths_{params.GetOutputSpatialLengths()}, + do_init_{do_init}, + input_init_f_{input_init_f}, + weights_init_f_{weights_init_f} + { + } + + virtual ~ConvFwdOpInstance() override{}; + + virtual InTensorsTuple GetInputTensors() const override + { + std::vector input_dims{static_cast(params_.N_), + static_cast(params_.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params_.input_spatial_lengths_), + std::end(params_.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params_.K_), + static_cast(params_.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params_.filter_spatial_lengths_), + std::end(params_.filter_spatial_lengths_)); + + auto input = std::make_unique>( + get_host_tensor_descriptor(input_dims, InLayout{})); + auto weights = std::make_unique>( + get_host_tensor_descriptor(filter_dims, WeiLayout{})); + + if(do_init_) + { + input_init_f_(input->begin(), input->end()); + weights_init_f_(weights->begin(), weights->end()); + } + + return std::make_tuple(std::move(input), std::move(weights)); + } + + virtual TensorPtr GetOutputTensor() const override + { + std::vector output_dims{static_cast(params_.N_), + static_cast(params_.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths_), + std::end(output_spatial_lengths_)); + auto output = std::make_unique>( + get_host_tensor_descriptor(output_dims, OutLayout{})); + + if(do_init_) + { + std::fill(output->begin(), output->end(), OutDataType(0.f)); + } + return output; + } + + virtual std::unique_ptr + MakeInvokerPointer(tensor_operation::device::BaseOperator* op_ptr) const override + { + static_assert( + std::is_same_v); + static_assert( + std::is_same_v); + static_assert( + std::is_same_v); + + auto conv_ptr = dynamic_cast(op_ptr); + if(!conv_ptr) + { + throw std::runtime_error( + "[ConvFwdOpInstance]: couldn't cast op_ptr to DeviceConvFwdNoOpPtr type!"); + } + return conv_ptr->MakeInvokerPointer(); + } + + virtual std::unique_ptr + MakeArgumentPointer(tensor_operation::device::BaseOperator* op_ptr, + const DeviceBuffers& in_device_buffers, + const DeviceMemPtr& out_device_buffer) const override + { + static_assert( + std::is_same_v); + static_assert( + std::is_same_v); + static_assert( + std::is_same_v); + + auto conv_ptr = dynamic_cast(op_ptr); + if(!conv_ptr) + { + throw std::runtime_error( + "[ConvFwdOpInstance]: couldn't cast op_ptr to DeviceConvFwdNoOpPtr type!"); + } + + return conv_ptr->MakeArgumentPointer( + static_cast(in_device_buffers[0]->GetDeviceBuffer()), + static_cast(in_device_buffers[1]->GetDeviceBuffer()), + static_cast(out_device_buffer->GetDeviceBuffer()), + params_.N_, + params_.K_, + params_.C_, + params_.input_spatial_lengths_, + params_.filter_spatial_lengths_, + output_spatial_lengths_, + params_.conv_filter_strides_, + params_.conv_filter_dilations_, + params_.input_left_pads_, + params_.input_right_pads_, + InElementwiseOp{}, + WeiElementwiseOp{}, + OutElementwiseOp{}); + } + + virtual std::size_t GetFlops() const override + { + return get_flops(params_.N_, + params_.C_, + params_.K_, + params_.filter_spatial_lengths_, + output_spatial_lengths_); + } + + virtual std::size_t GetBtype() const override + { + return get_btype(params_.N_, + params_.C_, + params_.K_, + params_.input_spatial_lengths_, + params_.filter_spatial_lengths_, + output_spatial_lengths_); + } + + private: + const ConvParams& params_; + const std::vector output_spatial_lengths_; + const bool do_init_; + const InputInitFun& input_init_f_; + const WeightsInitFun& weights_init_f_; +}; + +} // namespace conv +} // namespace utils +} // namespace ck + +std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParams& p); diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp new file mode 100644 index 00000000000..f44aec969d3 --- /dev/null +++ b/library/include/ck/library/utility/fill.hpp @@ -0,0 +1,81 @@ +#pragma once + +#include +#include + +#include "data_type.hpp" + +namespace ck { +namespace utils { + +// template +// struct FillUniform; + +// TODO: what's wrong with this specialization??? +// err: segmentation fault in mt19937 - infinite loop like. +// template +// struct FillUniform::value && +// !std::is_same::value>::type> +// { +// int a_{0}; +// int b_{5}; +// // T a_ = T{0}; +// // T b_ = T{5}; + +// template +// void operator()(ForwardIter first, ForwardIter last) const +// { +// std::mt19937 gen{11939}; +// std::uniform_int_distribution dis(a_, b_); +// std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); +// } +// }; + +// struct FillUniform::value || +// std::is_same::value>::type> +template +struct FillUniform +{ + float a_{0}; + float b_{5}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen{11939}; + std::uniform_real_distribution<> dis(a_, b_); + std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); + } +}; + +template +struct FillMonotonicSeq +{ + T init_value_{0}; + T step_{1}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::generate(first, last, [=, n = init_value_]() mutable { + auto tmp = n; + n += step_; + return tmp; + }); + } +}; + +template +struct FillConstant +{ + T value_{0}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::fill(first, last, value_); + } +}; + +} // namespace utils +} // namespace ck diff --git a/library/include/ck/library/utility/op_instance_engine.hpp b/library/include/ck/library/utility/op_instance_engine.hpp new file mode 100644 index 00000000000..5429f66d3ed --- /dev/null +++ b/library/include/ck/library/utility/op_instance_engine.hpp @@ -0,0 +1,231 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "device_base.hpp" +#include "functional2.hpp" + +namespace ck { +namespace utils { + +struct ProfileBestConfig +{ + std::string best_op_name; + float best_avg_time = std::numeric_limits::max(); + float best_tflops = std::numeric_limits::max(); + float best_gb_per_sec = std::numeric_limits::max(); +}; + +/** + * @brief This class describes an operation instance(s). + * + * Op instance defines a particular specializations of operator + * template. Thanks to this specific input/output data types, data + * layouts and modifying elementwise operations it is able to create + * it's input/output tensors, provide pointers to instances which + * can execute it and all operation specific parameters. + */ +template +class OpInstance +{ + public: + template + using TensorPtr = std::unique_ptr>; + using InTensorsTuple = std::tuple...>; + using DeviceMemPtr = std::unique_ptr; + using DeviceBuffers = std::vector; + + OpInstance() = default; + OpInstance(const OpInstance&) = default; + OpInstance& operator=(const OpInstance&) = default; + virtual ~OpInstance(){}; + + virtual InTensorsTuple GetInputTensors() const = 0; + virtual TensorPtr GetOutputTensor() const = 0; + virtual std::unique_ptr + MakeInvokerPointer(tensor_operation::device::BaseOperator*) const = 0; + virtual std::unique_ptr + MakeArgumentPointer(tensor_operation::device::BaseOperator*, + const DeviceBuffers&, + const DeviceMemPtr&) const = 0; + virtual std::size_t GetFlops() const = 0; + virtual std::size_t GetBtype() const = 0; +}; + +/** + * @brief A generic operation instance run engine. + */ +template +class OpInstanceRunEngine +{ + public: + using OpInstanceT = OpInstance; + template + using TensorPtr = std::unique_ptr>; + using DeviceMemPtr = std::unique_ptr; + using InTensorsTuple = std::tuple...>; + using DeviceBuffers = std::vector; + using InArgsTypesTuple = std::tuple; + + OpInstanceRunEngine() = delete; + + template > + OpInstanceRunEngine(const OpInstanceT& op_instance, + const ReferenceOp& reference_op = ReferenceOp{}) + : op_instance_{op_instance} + { + in_tensors_ = op_instance_.GetInputTensors(); + out_tensor_ = op_instance_.GetOutputTensor(); + + if constexpr(std::is_invocable_v&..., + Tensor&>) + { + ref_output_ = op_instance_.GetOutputTensor(); + CallRefOpUnpackArgs(reference_op, std::make_index_sequence{}); + } + AllocateDeviceInputTensors(std::make_index_sequence{}); + out_device_buffer_ = + std::make_unique(sizeof(OutDataType) * out_tensor_->mDesc.GetElementSpace()); + out_device_buffer_->SetZero(); + } + + virtual ~OpInstanceRunEngine(){}; + + template + bool Test(const std::vector& op_ptrs) + { + bool res{true}; + for(auto& op_ptr : op_ptrs) + { + auto invoker = op_instance_.MakeInvokerPointer(op_ptr.get()); + auto argument = op_instance_.MakeArgumentPointer( + op_ptr.get(), in_device_buffers_, out_device_buffer_); + if(op_ptr->IsSupportedArgument(argument.get())) + { + invoker->Run(argument.get()); + out_device_buffer_->FromDevice(out_tensor_->mData.data()); + if(!ref_output_) + { + throw std::runtime_error( + "OpInstanceRunEngine::Test: Reference value not availabe." + " You have to provide reference function."); + } + // TODO: enable flexible use of custom check_error functions + res = res && check_err(out_tensor_->mData, ref_output_->mData); + out_device_buffer_->SetZero(); + } + } + return res; + } + + template + ProfileBestConfig Profile(const std::vector& op_ptrs, + bool time_kernel = false, + bool do_verification = false, + bool do_log = false) + { + bool res{true}; + ProfileBestConfig best_config; + + for(auto& op_ptr : op_ptrs) + { + auto invoker = op_instance_.MakeInvokerPointer(op_ptr.get()); + auto argument = op_instance_.MakeArgumentPointer( + op_ptr.get(), in_device_buffers_, out_device_buffer_); + if(op_ptr->IsSupportedArgument(argument.get())) + { + std::string op_name = op_ptr->GetTypeString(); + float avg_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flops = op_instance_.GetFlops(); + std::size_t num_btype = op_instance_.GetBtype(); + float tflops = static_cast(flops) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(tflops < best_config.best_tflops) + { + best_config.best_op_name = op_name; + best_config.best_tflops = tflops; + best_config.best_gb_per_sec = gb_per_sec; + best_config.best_avg_time = avg_time; + } + + if(do_verification) + { + out_device_buffer_->FromDevice(out_tensor_->mData.data()); + if(!ref_output_) + { + throw std::runtime_error( + "OpInstanceRunEngine::Profile: Reference value not availabe." + " You have to provide reference function."); + } + // TODO: enable flexible use of custom check_error functions + res = res && CheckErr(out_tensor_->mData, ref_output_->mData); + + if(do_log) {} + } + out_device_buffer_->SetZero(); + } + } + return best_config; + } + + void SetAtol(double a) { atol_ = a; } + void SetRtol(double r) { rtol_ = r; } + + private: + template + void CallRefOpUnpackArgs(const F& f, std::index_sequence) const + { + f(*std::get(in_tensors_)..., *ref_output_); + } + + template + void AllocateDeviceInputTensors(std::index_sequence) + { + (AllocateDeviceInputTensorsImpl(), ...); + } + + template + void AllocateDeviceInputTensorsImpl() + { + const auto& ts = std::get(in_tensors_); + in_device_buffers_ + .emplace_back( + std::make_unique(sizeof(std::tuple_element_t) * + ts->mDesc.GetElementSpace())) + ->ToDevice(ts->mData.data()); + } + + static constexpr std::size_t kNInArgs_ = std::tuple_size_v; + const OpInstanceT& op_instance_; + double rtol_{1e-5}; + double atol_{1e-8}; + + InTensorsTuple in_tensors_; + TensorPtr out_tensor_; + TensorPtr ref_output_; + + DeviceBuffers in_device_buffers_; + DeviceMemPtr out_device_buffer_; + + template + bool CheckErr(const std::vector& dev_out, const std::vector& ref_out) const + { + return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", atol_, rtol_); + } +}; + +} // namespace utils +} // namespace ck diff --git a/library/src/host_tensor/CMakeLists.txt b/library/src/host_tensor/CMakeLists.txt new file mode 100644 index 00000000000..2a020b763dc --- /dev/null +++ b/library/src/host_tensor/CMakeLists.txt @@ -0,0 +1,40 @@ +## host_tensor +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/include/ck + ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor +) + +set(HOST_TENSOR_SOURCE + device.cpp + host_tensor.cpp +) + +add_library(host_tensor STATIC ${HOST_TENSOR_SOURCE}) +add_library(composable_kernel::host_tensor ALIAS host_tensor) + +target_compile_features(host_tensor PUBLIC) +set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(host_tensor SYSTEM PUBLIC $) + +target_include_directories(host_tensor PUBLIC + "$" + "$" + "$" +) + +install(TARGETS host_tensor + EXPORT host_tensorTargets + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} +) + +install(EXPORT host_tensorTargets + FILE composable_kernelhost_tensorTargets.cmake + NAMESPACE composable_kernel:: + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel +) + +clang_tidy_check(host_tensor) diff --git a/library/src/host_tensor/device.cpp b/library/src/host_tensor/device.cpp new file mode 100644 index 00000000000..9f0d982dbc1 --- /dev/null +++ b/library/src/host_tensor/device.cpp @@ -0,0 +1,70 @@ +#include "device.hpp" + +DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) +{ + hip_check_error(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); +} + +void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; } + +std::size_t DeviceMem::GetBufferSize() { return mMemSize; } + +void DeviceMem::ToDevice(const void* p) +{ + hip_check_error(hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); +} + +void DeviceMem::FromDevice(void* p) +{ + hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); +} + +void DeviceMem::SetZero() { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); } + +DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); } + +struct KernelTimerImpl +{ + KernelTimerImpl() + { + hip_check_error(hipEventCreate(&mStart)); + hip_check_error(hipEventCreate(&mEnd)); + } + + ~KernelTimerImpl() + { + hip_check_error(hipEventDestroy(mStart)); + hip_check_error(hipEventDestroy(mEnd)); + } + + void Start() + { + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(mStart, nullptr)); + } + + void End() + { + hip_check_error(hipEventRecord(mEnd, nullptr)); + hip_check_error(hipEventSynchronize(mEnd)); + } + + float GetElapsedTime() const + { + float time; + hip_check_error(hipEventElapsedTime(&time, mStart, mEnd)); + return time; + } + + hipEvent_t mStart, mEnd; +}; + +KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {} + +KernelTimer::~KernelTimer() {} + +void KernelTimer::Start() { impl->Start(); } + +void KernelTimer::End() { impl->End(); } + +float KernelTimer::GetElapsedTime() const { return impl->GetElapsedTime(); } diff --git a/host/host_tensor/src/host_tensor.cpp b/library/src/host_tensor/host_tensor.cpp similarity index 63% rename from host/host_tensor/src/host_tensor.cpp rename to library/src/host_tensor/host_tensor.cpp index e840baf7f5f..138e3fc2549 100644 --- a/host/host_tensor/src/host_tensor.cpp +++ b/library/src/host_tensor/host_tensor.cpp @@ -1,6 +1,4 @@ -#include #include - #include "host_tensor.hpp" void HostTensorDescriptor::CalculateStrides() @@ -26,14 +24,33 @@ std::size_t HostTensorDescriptor::GetElementSize() const std::size_t HostTensorDescriptor::GetElementSpace() const { - auto ls = mLens | boost::adaptors::transformed([](std::size_t v) { return v - 1; }); - return std::inner_product(ls.begin(), ls.end(), mStrides.begin(), std::size_t{0}) + 1; + std::size_t space = 1; + for(std::size_t i = 0; i < mLens.size(); ++i) + { + space += (mLens[i] - 1) * mStrides[i]; + } + return space; } const std::vector& HostTensorDescriptor::GetLengths() const { return mLens; } const std::vector& HostTensorDescriptor::GetStrides() const { return mStrides; } +std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc) +{ + os << "dim " << desc.GetNumOfDimension() << ", "; + + os << "lengths {"; + LogRange(os, desc.GetLengths(), ", "); + os << "}, "; + + os << "strides {"; + LogRange(os, desc.GetStrides(), ", "); + os << "}"; + + return os; +} + void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os) { os << "dim " << desc.GetNumOfDimension() << ", "; @@ -46,3 +63,12 @@ void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream LogRange(os, desc.GetStrides(), ", "); os << "}" << std::endl; } + +#if 1 +// FIXME: remove +void bf16_to_f32_(const Tensor& src, Tensor& dst) +{ + for(std::size_t i = 0; i < src.mData.size(); ++i) + dst.mData[i] = ck::type_convert(src.mData[i]); +} +#endif diff --git a/host/driver_offline/CMakeLists.txt b/library/src/obselete_driver_offline/CMakeLists.txt similarity index 62% rename from host/driver_offline/CMakeLists.txt rename to library/src/obselete_driver_offline/CMakeLists.txt index a3b3613293e..54b13953279 100644 --- a/host/driver_offline/CMakeLists.txt +++ b/library/src/obselete_driver_offline/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(BEFORE include ${PROJECT_SOURCE_DIR}/host/host_tensor/include + ${PROJECT_SOURCE_DIR}/host/device/include ${PROJECT_SOURCE_DIR}/host/solver/include ${PROJECT_SOURCE_DIR}/composable_kernel/include ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility @@ -12,16 +13,25 @@ include_directories(BEFORE ) set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) +set(CONV_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_fwd_driver_offline_nchwc.cpp) +set(CONV_ADD_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_add_fwd_driver_offline_nchwc.cpp) +set(CONV_MAXPOOL_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_maxpool_fwd_driver_offline_nchwc.cpp) set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp) set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp) add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) +add_executable(conv_fwd_driver_offline_nchwc ${CONV_FWD_DRIVER_OFFLINE_NCHWC_SOURCE}) +add_executable(conv_add_fwd_driver_offline_nchwc ${CONV_ADD_FWD_DRIVER_OFFLINE_NCHWC_SOURCE}) +add_executable(conv_maxpool_fwd_driver_offline_nchwc ${CONV_MAXPOOL_FWD_DRIVER_OFFLINE_NCHWC_SOURCE}) add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE}) add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE}) target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) +target_link_libraries(conv_fwd_driver_offline_nchwc PRIVATE host_tensor) +target_link_libraries(conv_add_fwd_driver_offline_nchwc PRIVATE host_tensor) +target_link_libraries(conv_maxpool_fwd_driver_offline_nchwc PRIVATE host_tensor) target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor) target_link_libraries(gemm_driver_offline PRIVATE host_tensor) diff --git a/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp new file mode 100644 index 00000000000..a7541f03de8 --- /dev/null +++ b/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp @@ -0,0 +1,416 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "debug.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "device_tensor.hpp" +#include "device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" + +#define USE_DYNAMIC_MODE 0 +#define USE_CONV_FWD_V5R1_NCHWC 1 + +enum ConvForwardAlgo +{ + V5R1NCHWC // 0 +}; + +template +void host_direct_convolution_add_nchwc(const Tensor& in, + const Tensor& wei, + const Tensor& add, + const Tensor& bias, + Tensor& add_host, + Tensor& out_host, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&, + const ck::ActivTypeEnum activ_type) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + auto f_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) { + double v = 0; + auto k = k0 * out_host.mDesc.GetLengths()[4] + k1; + + for(int c0 = 0; c0 < wei.mDesc.GetLengths()[1]; ++c0) + { + for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + + for(int c1 = 0; c1 < wei.mDesc.GetLengths()[4]; ++c1) + { + v += static_cast(in(n, c0, hi, wi, c1)) * + static_cast(wei(k, c0, y, x, c1)); + } + } + } + } + } + + v += bias(k0, k1); + v = activ(v, activ_type); + + const int hox2 = ho * 2; + const int wox2 = wo * 2; + + out_host(n, k0, ho, wo, k1) = v; + + add_host(n, k0, hox2, wox2, k1) = v + add(n, k0, hox2, wox2, k1); + add_host(n, k0, hox2, wox2 + 1, k1) = v + add(n, k0, hox2, wox2 + 1, k1); + add_host(n, k0, hox2 + 1, wox2, k1) = v + add(n, k0, hox2 + 1, wox2, k1); + add_host(n, k0, hox2 + 1, wox2 + 1, k1) = v + add(n, k0, hox2 + 1, wox2 + 1, k1); + }; + + make_ParallelTensorFunctor(f_nchw, + out_host.mDesc.GetLengths()[0], + out_host.mDesc.GetLengths()[1], + out_host.mDesc.GetLengths()[2], + out_host.mDesc.GetLengths()[3], + out_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency()); +} + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + +#if USE_DYNAMIC_MODE + // dynamic mode + if(argc != 23) + { + printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K0, K1, C0, C1, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; + + const ConvForwardAlgo algo = static_cast(std::stoi(argv[1])); + const bool do_verification = std::stoi(argv[2]); + const int init_method = std::stoi(argv[3]); + const bool do_log = std::stoi(argv[4]); + const int nrepeat = std::stoi(argv[5]); + + const index_t N = std::stoi(argv[6]); + const index_t K0 = std::stoi(argv[7]); + const index_t K1 = std::stoi(argv[8]); + const index_t C0 = std::stoi(argv[9]); + const index_t C1 = std::stoi(argv[10]); + const index_t Y = std::stoi(argv[11]); + const index_t X = std::stoi(argv[12]); + const index_t Hi = std::stoi(argv[13]); + const index_t Wi = std::stoi(argv[14]); + + const index_t conv_stride_h = std::stoi(argv[15]); + const index_t conv_stride_w = std::stoi(argv[16]); + const index_t conv_dilation_h = std::stoi(argv[17]); + const index_t conv_dilation_w = std::stoi(argv[18]); + const index_t in_left_pad_h = std::stoi(argv[19]); + const index_t in_left_pad_w = std::stoi(argv[20]); + const index_t in_right_pad_h = std::stoi(argv[21]); + const index_t in_right_pad_w = std::stoi(argv[22]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + const auto Hox2 = Ho * 2; + const auto Wox2 = Wo * 2; +#else + // static mode + if(argc < 6) + { + printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvForwardAlgo algo = static_cast(std::stoi(argv[1])); + + const bool do_verification = std::stoi(argv[2]); + const int init_method = std::stoi(argv[3]); + const bool do_log = std::stoi(argv[4]); + const int nrepeat = std::stoi(argv[5]); + + constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; + +#if 0 + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<1080>{}; + constexpr auto Wi = Number<1920>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K1 = Number<8>{}; + constexpr auto K0 = Number<8>{}; +#elif 0 + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<540>{}; + constexpr auto Wi = Number<960>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; +#elif 0 + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<270>{}; + constexpr auto Wi = Number<480>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; +#elif 1 + constexpr auto N = Number<128>{}; + constexpr auto Hi = Number<135>{}; + constexpr auto Wi = Number<240>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; +#elif 1 + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<32>{}; + constexpr auto Wi = Number<32>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K1 = Number<8>{}; + constexpr auto K0 = Number<8>{}; +#endif + + constexpr auto conv_stride_h = I1; + constexpr auto conv_stride_w = I1; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; + + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; + + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; + + constexpr auto Hox2 = Number{}; + constexpr auto Wox2 = Number{}; + +#endif + +#if 0 + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#elif 1 + using in_data_t = int8_t; + using acc_data_t = int32_t; + using out_data_t = int8_t; +#endif + + std::vector in_lengths_host(5), wei_lengths_host(5), out_lengths_host(5), + add_lengths_host(5), bias_lengths_host(2); + + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C0); + in_lengths_host[2] = static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + in_lengths_host[4] = static_cast(C1); + + wei_lengths_host[0] = static_cast(K0 * K1); + wei_lengths_host[1] = static_cast(C0); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + wei_lengths_host[4] = static_cast(C1); + + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(K0); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + out_lengths_host[4] = static_cast(K1); + + add_lengths_host[0] = static_cast(N); + add_lengths_host[1] = static_cast(K0); + add_lengths_host[2] = static_cast(Hox2); + add_lengths_host[3] = static_cast(Wox2); + add_lengths_host[4] = static_cast(K1); + + bias_lengths_host[0] = static_cast(K0); + bias_lengths_host[1] = static_cast(K1); + + Tensor in(in_lengths_host); + Tensor wei(wei_lengths_host); + Tensor add(add_lengths_host); + Tensor add_device(add_lengths_host); + Tensor add_host(add_lengths_host); + Tensor bias(bias_lengths_host); + Tensor out_host(out_lengths_host); + + ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); + ostream_HostTensorDescriptor(add.mDesc, std::cout << "add: "); + + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = 1; + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 5: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_wei = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + }; + wei.GenerateTensorValue(gen_wei, num_thread); + } + + bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + add.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + + auto f_make_for_device_nchwc = [&]() { + const auto in_lengths_dev = make_tuple(N, C0, Hi, Wi, C1); + const auto wei_lengths_dev = make_tuple(K0 * K1, C0, Y, X, C1); + const auto add_lengths_dev = make_tuple(N, K0, Hox2, Wox2, K1); + const auto out_lengths_dev = make_tuple(N, K0, Ho, Wo, K1); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + add_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + +#if USE_CONV_FWD_V5R1_NCHWC + if(algo == ConvForwardAlgo::V5R1NCHWC) + { + const auto tmp = f_make_for_device_nchwc(); + + device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1( + tmp[I0], // in_lengths_dev + tmp[I1], // wei_lengths_dev + tmp[I2], // add_lengths_dev + tmp[I3], // out_lengths_dev + tmp[I4], // conv_strides_dev + tmp[I5], // conv_dilations_dev + tmp[I6], // in_left_pads_dev + tmp[I7], // in_right_pads_dev + in, + wei, + bias, + add, + add_device, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution_add_nchwc(in, + wei, + add, + bias, + add_host, + out_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + activ_type); + + ck::utils::check_err(add_device.mData, add_host.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; + LogRangeAsType(std::cout << "add_host: ", add_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "add_device: ", add_device.mData, ",") << std::endl; + } + } +} diff --git a/host/driver_offline/src/conv_bwd_driver_offline.cpp b/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp similarity index 65% rename from host/driver_offline/src/conv_bwd_driver_offline.cpp rename to library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp index 366b5dffbce..c4dcb7c0853 100644 --- a/host/driver_offline/src/conv_bwd_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "debug.hpp" #include "print.hpp" @@ -11,7 +13,6 @@ #include "host_tensor.hpp" #include "host_tensor_generator.hpp" #include "conv_common.hpp" -#include "host_conv_bwd_data.hpp" #include "device_tensor.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" @@ -21,12 +22,153 @@ #define USE_CONV_BWD_V4R1_XDL_NHWC 0 #define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 +enum ConvTensorLayout +{ + NCHW, + NHWC, + CHWN, + NCHWc, + NHWCc +}; + enum ConvBackwardDataAlgo { V4R1XDLNHWC, // 0 V4R1R2XDLNHWC, // 1 }; +template +void host_convolution_backward_data(Tensor& in, + const Tensor& wei, + const Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& /* in_right_pads */, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { + std::size_t K = wei.mDesc.GetLengths()[I0]; + std::size_t Y = wei.mDesc.GetLengths()[I2]; + std::size_t X = wei.mDesc.GetLengths()[I3]; + + std::size_t Ho = out.mDesc.GetLengths()[I2]; + std::size_t Wo = out.mDesc.GetLengths()[I3]; + + double v = 0; + + for(int y = 0; y < Y; ++y) + { + int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; + + if(h_tmp % conv_strides[I0] == 0) + { + int ho = h_tmp / conv_strides[I0]; + + if(ho >= 0 && ho < Ho) + { + for(int x = 0; x < X; ++x) + { + int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; + + if(w_tmp % conv_strides[I1] == 0) + { + int wo = w_tmp / conv_strides[I1]; + + if(wo >= 0 && wo < Wo) + { + for(int k = 0; k < K; ++k) + { + v += out(n, k, ho, wo) * wei(k, c, y, x); + } + } + } + } + } + } + } + + in(n, c, hi, wi) = v; + }; + + auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) { + std::size_t K = wei.mDesc.GetLengths()[I0]; + std::size_t Y = wei.mDesc.GetLengths()[I1]; + std::size_t X = wei.mDesc.GetLengths()[I2]; + + std::size_t Ho = out.mDesc.GetLengths()[I1]; + std::size_t Wo = out.mDesc.GetLengths()[I2]; + + double v = 0; + + for(int y = 0; y < Y; ++y) + { + int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; + + if(h_tmp % conv_strides[I0] == 0) + { + int ho = h_tmp / conv_strides[I0]; + + if(ho >= 0 && ho < Ho) + { + for(int x = 0; x < X; ++x) + { + int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; + + if(w_tmp % conv_strides[I1] == 0) + { + int wo = w_tmp / conv_strides[I1]; + + if(wo >= 0 && wo < Wo) + { + for(int k = 0; k < K; ++k) + { + v += out(n, ho, wo, k) * wei(k, y, x, c); + } + } + } + } + } + } + } + + in(n, hi, wi, c) = v; + }; + + if(layout == ConvTensorLayout::NCHW) + { + make_ParallelTensorFunctor(f_nchw, + in.mDesc.GetLengths()[0], + in.mDesc.GetLengths()[1], + in.mDesc.GetLengths()[2], + in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else if(layout == ConvTensorLayout::NHWC) + { + make_ParallelTensorFunctor(f_nhwc, + in.mDesc.GetLengths()[0], + in.mDesc.GetLengths()[1], + in.mDesc.GetLengths()[2], + in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! not supported layout"); + } +} int main(int argc, char* argv[]) { using namespace ck; @@ -177,7 +319,7 @@ int main(int argc, char* argv[]) print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); - std::size_t num_thread = std::thread::hardware_concurrency(); + std::size_t num_thread = 1; switch(init_method) { @@ -185,30 +327,30 @@ int main(int argc, char* argv[]) // no initialization break; case 1: - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 2: - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 3: - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 4: - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 5: - out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); break; default: - out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); }; wei.GenerateTensorValue(gen_wei, num_thread); } @@ -324,16 +466,16 @@ int main(int argc, char* argv[]) if(do_verification) { - host_direct_convolution_backward_data(in_host, - wei, - out, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); - - check_error(in_host, in_device); + host_convolution_backward_data(in_host, + wei, + out, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + ck::utils::check_err(in_device.mData, in_host.mData); if(do_log) { diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp similarity index 71% rename from host/driver_offline/src/conv_fwd_driver_offline.cpp rename to library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp index 48eba2b3725..ab8beec87bf 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "debug.hpp" #include "print.hpp" @@ -11,12 +13,10 @@ #include "host_tensor.hpp" #include "host_tensor_generator.hpp" #include "conv_common.hpp" -#include "host_conv.hpp" #include "device_tensor.hpp" #include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" -#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" @@ -24,20 +24,145 @@ #define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V6R1_NCHW 0 -#define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 +enum ConvTensorLayout +{ + NCHW, + NHWC, + CHWN, + NCHWc, + NHWCc +}; + enum ConvForwardAlgo { V4R4NCHW, // 0 V4R4R2NHWC, // 1 V6R1NCHW, // 2 - V5R1NCHW, // 3 - V4R4R2XDLNCHW, // 4 - V4R4R4XDLNHWC // 5 + V4R4R2XDLNCHW, // 3 + V4R4R4XDLNHWC // 4 }; +template +void host_convolution_forward(const Tensor& in, + const Tensor& wei, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + double v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + if constexpr(is_same::value) + { + v += ck::type_convert(in(n, c, hi, wi)) * + ck::type_convert(wei(k, c, y, x)); + } + else + { + v += static_cast(in(n, c, hi, wi)) * + static_cast(wei(k, c, y, x)); + } + } + } + } + } + + if constexpr(is_same::value) + { + out(n, k, ho, wo) = ck::type_convert(static_cast(v)); + } + else + { + out(n, k, ho, wo) = v; + } + }; + + auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) { + double v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && + wi < in.mDesc.GetLengths()[2]) + { + if constexpr(is_same::value) + { + v += ck::type_convert(in(n, hi, wi, c)) * + ck::type_convert(wei(k, y, x, c)); + } + else + { + v += static_cast(in(n, hi, wi, c)) * + static_cast(wei(k, y, x, c)); + } + } + } + } + } + if constexpr(is_same::value) + { + out(n, ho, wo, k) = ck::type_convert(static_cast(v)); + } + else + { + out(n, ho, wo, k) = v; + } + }; + + if(layout == ConvTensorLayout::NCHW) + { + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else if(layout == ConvTensorLayout::NHWC) + { + make_ParallelTensorFunctor(f_nhwc, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! not supported layout"); + } +} + int main(int argc, char* argv[]) { using namespace ck; @@ -111,8 +236,8 @@ int main(int argc, char* argv[]) constexpr auto Y = Number<3>{}; constexpr auto X = Number<3>{}; - constexpr auto conv_stride_h = I2; - constexpr auto conv_stride_w = I2; + constexpr auto conv_stride_h = I1; + constexpr auto conv_stride_w = I1; constexpr auto conv_dilation_h = I1; constexpr auto conv_dilation_w = I1; constexpr auto in_left_pad_h = I1; @@ -127,7 +252,7 @@ int main(int argc, char* argv[]) constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; #endif -#if 0 +#if 1 using in_data_t = float; using acc_data_t = float; using out_data_t = float; @@ -135,6 +260,10 @@ int main(int argc, char* argv[]) using in_data_t = half_t; using acc_data_t = float; using out_data_t = half_t; +#elif 0 + using in_data_t = bhalf_t; + using acc_data_t = float; + using out_data_t = bhalf_t; #elif 1 using in_data_t = int8_t; using acc_data_t = int32_t; @@ -192,7 +321,7 @@ int main(int argc, char* argv[]) print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); - std::size_t num_thread = std::thread::hardware_concurrency(); + std::size_t num_thread = 1; switch(init_method) { @@ -200,30 +329,30 @@ int main(int argc, char* argv[]) // no initialization break; case 1: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 2: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 3: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 4: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 5: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); break; default: - in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); }; wei.GenerateTensorValue(gen_wei, num_thread); } @@ -342,33 +471,6 @@ int main(int argc, char* argv[]) } #endif -#if USE_CONV_FWD_V5R1_NCHW - if(algo == ConvForwardAlgo::V5R1NCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - #if USE_CONV_FWD_V4R4R2_XDL_NCHW if(algo == ConvForwardAlgo::V4R4R2XDLNCHW) { @@ -425,16 +527,16 @@ int main(int argc, char* argv[]) if(do_verification) { - host_direct_convolution(in, - wei, - out_host, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); - - check_error(out_host, out_device); + host_convolution_forward(in, + wei, + out_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + ck::utils::check_err(out_device.mData, out_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp new file mode 100644 index 00000000000..6fb8b4c2aa3 --- /dev/null +++ b/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp @@ -0,0 +1,393 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "debug.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "device_tensor.hpp" +#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" + +#define USE_DYNAMIC_MODE 0 +#define USE_CONV_FWD_V5R1_NCHWC 1 + +enum ConvForwardAlgo +{ + V5R1NCHWC // 0 +}; + +template +void host_direct_convolution_nchwc(const Tensor& in, + const Tensor& wei, + const Tensor& bias, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&, + const ck::ActivTypeEnum activ_type) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + auto f_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) { + double v = 0; + const int k = k0 * out.mDesc.GetLengths()[4] + k1; + + for(int c0 = 0; c0 < wei.mDesc.GetLengths()[1]; ++c0) + { + for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + for(int c1 = 0; c1 < wei.mDesc.GetLengths()[4]; ++c1) + { + v += static_cast(in(n, c0, hi, wi, c1)) * + static_cast(wei(k, c0, y, x, c1)); + } + } + } + } + } + v += bias(k0, k1); + out(n, k0, ho, wo, k1) = activ(v, activ_type); + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3], + out.mDesc.GetLengths()[4])(std::thread::hardware_concurrency()); +} + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + +#if USE_DYNAMIC_MODE + // dynamic mode + if(argc != 23) + { + printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K0, K1, C0, C1, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; + + const ConvForwardAlgo algo = static_cast(std::stoi(argv[1])); + const bool do_verification = std::stoi(argv[2]); + const int init_method = std::stoi(argv[3]); + const bool do_log = std::stoi(argv[4]); + const int nrepeat = std::stoi(argv[5]); + + const index_t N = std::stoi(argv[6]); + const index_t K0 = std::stoi(argv[7]); + const index_t K1 = std::stoi(argv[8]); + const index_t C0 = std::stoi(argv[9]); + const index_t C1 = std::stoi(argv[10]); + const index_t Y = std::stoi(argv[11]); + const index_t X = std::stoi(argv[12]); + const index_t Hi = std::stoi(argv[13]); + const index_t Wi = std::stoi(argv[14]); + + const index_t conv_stride_h = std::stoi(argv[15]); + const index_t conv_stride_w = std::stoi(argv[16]); + const index_t conv_dilation_h = std::stoi(argv[17]); + const index_t conv_dilation_w = std::stoi(argv[18]); + const index_t in_left_pad_h = std::stoi(argv[19]); + const index_t in_left_pad_w = std::stoi(argv[20]); + const index_t in_right_pad_h = std::stoi(argv[21]); + const index_t in_right_pad_w = std::stoi(argv[22]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#else + // static mode + if(argc < 6) + { + printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvForwardAlgo algo = static_cast(std::stoi(argv[1])); + + const bool do_verification = std::stoi(argv[2]); + const int init_method = std::stoi(argv[3]); + const bool do_log = std::stoi(argv[4]); + const int nrepeat = std::stoi(argv[5]); + + // constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::Sigmoid; + constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; + +#if 0 + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<1080>{}; + constexpr auto Wi = Number<1920>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<1>{}; + constexpr auto K1 = Number<4>{}; +#elif 1 + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<1080>{}; + constexpr auto Wi = Number<1920>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; +#elif 0 + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<1080>{}; + constexpr auto Wi = Number<1920>{}; + constexpr auto Y = Number<1>{}; + constexpr auto X = Number<1>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; +#elif 0 + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<540>{}; + constexpr auto Wi = Number<960>{}; + constexpr auto Y = Number<1>{}; + constexpr auto X = Number<1>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; +#elif 0 + constexpr auto N = Number<128>{}; + constexpr auto Hi = Number<270>{}; + constexpr auto Wi = Number<480>{}; + constexpr auto Y = Number<1>{}; + constexpr auto X = Number<1>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; +#endif + + constexpr auto conv_stride_h = I1; + constexpr auto conv_stride_w = I1; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + +#if 1 + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; +#else + constexpr auto in_left_pad_h = I0; + constexpr auto in_left_pad_w = I0; + constexpr auto in_right_pad_h = I0; + constexpr auto in_right_pad_w = I0; +#endif + + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; + + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; +#endif + +#if 0 + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#elif 1 + using in_data_t = int8_t; + using acc_data_t = int32_t; + using out_data_t = int8_t; +#endif + + std::vector in_lengths_host(5), wei_lengths_host(5), out_lengths_host(5), + bias_lengths_host(2); + + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C0); + in_lengths_host[2] = static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + in_lengths_host[4] = static_cast(C1); + + wei_lengths_host[0] = static_cast(K0 * K1); + wei_lengths_host[1] = static_cast(C0); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + wei_lengths_host[4] = static_cast(C1); + + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(K0); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + out_lengths_host[4] = static_cast(K1); + + bias_lengths_host[0] = static_cast(K0); + bias_lengths_host[1] = static_cast(K1); + + Tensor in(in_lengths_host); + Tensor wei(wei_lengths_host); + Tensor bias(bias_lengths_host); + Tensor out_host(out_lengths_host); + Tensor out_device(out_lengths_host); + + ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); + ostream_HostTensorDescriptor(bias.mDesc, std::cout << "bias: "); + ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: "); + + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = 1; + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 5: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + bias.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_wei = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + }; + wei.GenerateTensorValue(gen_wei, num_thread); + } + + auto f_make_for_device_nchwc = [&]() { + const auto in_lengths_dev = make_tuple(N, C0, Hi, Wi, C1); + const auto wei_lengths_dev = make_tuple(K0 * K1, C0, Y, X, C1); + const auto out_lengths_dev = make_tuple(N, K0, Ho, Wo, K1); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + +#if USE_CONV_FWD_V5R1_NCHWC + if(algo == ConvForwardAlgo::V5R1NCHWC) + { + const auto tmp = f_make_for_device_nchwc(); + + device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + bias, + out_device, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution_nchwc(in, + wei, + bias, + out_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + activ_type); + + ck::utils::check_err(out_device.mData, out_host.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; + LogRangeAsType(std::cout << "bias: ", bias.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_host : ", out_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << std::endl; + } + } +} diff --git a/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp new file mode 100644 index 00000000000..fb7e8e975b9 --- /dev/null +++ b/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp @@ -0,0 +1,415 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "debug.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "device_tensor.hpp" +#include "device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" + +#define USE_DYNAMIC_MODE 0 +#define USE_CONV_FWD_V5R1_NCHWC 1 + +enum ConvForwardAlgo +{ + V5R1NCHWC // 0 +}; + +template +void host_direct_convolution_maxpool_nchwc(const Tensor& in, + const Tensor& wei, + const Tensor& bias, + Tensor& out_host, + Tensor& max_host, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&, + const ck::ActivTypeEnum activ_type) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + auto f_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) { + double v = 0; + auto k = k0 * out_host.mDesc.GetLengths()[4] + k1; + + for(int c0 = 0; c0 < wei.mDesc.GetLengths()[1]; ++c0) + { + for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + for(int c1 = 0; c1 < wei.mDesc.GetLengths()[4]; ++c1) + { + v += static_cast(in(n, c0, hi, wi, c1)) * + static_cast(wei(k, c0, y, x, c1)); + } + } + } + } + } + + v += bias(k0, k1); + v = activ(v, activ_type); + + out_host(n, k0, ho, wo, k1) = v; + }; + + make_ParallelTensorFunctor(f_nchw, + out_host.mDesc.GetLengths()[0], + out_host.mDesc.GetLengths()[1], + out_host.mDesc.GetLengths()[2], + out_host.mDesc.GetLengths()[3], + out_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency()); + + auto maxpool_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) { + auto hx = ho * 2; + auto wx = wo * 2; + + auto v0 = out_host(n, k0, hx, wx, k1); + auto v1 = out_host(n, k0, hx, wx + 1, k1); + auto v2 = out_host(n, k0, hx + 1, wx, k1); + auto v3 = out_host(n, k0, hx + 1, wx + 1, k1); + + max_host(n, k0, ho, wo, k1) = std::max({v0, v1, v2, v3}); + }; + + make_ParallelTensorFunctor(maxpool_nchw, + max_host.mDesc.GetLengths()[0], + max_host.mDesc.GetLengths()[1], + max_host.mDesc.GetLengths()[2], + max_host.mDesc.GetLengths()[3], + max_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency()); +} + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + +#if USE_DYNAMIC_MODE + // dynamic mode + if(argc != 23) + { + printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K0, K1, C0, C1, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; + + const ConvForwardAlgo algo = static_cast(std::stoi(argv[1])); + const bool do_verification = std::stoi(argv[2]); + const int init_method = std::stoi(argv[3]); + const bool do_log = std::stoi(argv[4]); + const int nrepeat = std::stoi(argv[5]); + + const index_t N = std::stoi(argv[6]); + const index_t K0 = std::stoi(argv[7]); + const index_t K1 = std::stoi(argv[8]); + const index_t C0 = std::stoi(argv[9]); + const index_t C1 = std::stoi(argv[10]); + const index_t Y = std::stoi(argv[11]); + const index_t X = std::stoi(argv[12]); + const index_t Hi = std::stoi(argv[13]); + const index_t Wi = std::stoi(argv[14]); + + const index_t conv_stride_h = std::stoi(argv[15]); + const index_t conv_stride_w = std::stoi(argv[16]); + const index_t conv_dilation_h = std::stoi(argv[17]); + const index_t conv_dilation_w = std::stoi(argv[18]); + const index_t in_left_pad_h = std::stoi(argv[19]); + const index_t in_left_pad_w = std::stoi(argv[20]); + const index_t in_right_pad_h = std::stoi(argv[21]); + const index_t in_right_pad_w = std::stoi(argv[22]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + const index_t Ho_2 = Ho / 2; + const index_t Wo_2 = Wo / 2; +#else + // static mode + if(argc < 6) + { + printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvForwardAlgo algo = static_cast(std::stoi(argv[1])); + + const bool do_verification = std::stoi(argv[2]); + const int init_method = std::stoi(argv[3]); + const bool do_log = std::stoi(argv[4]); + const int nrepeat = std::stoi(argv[5]); + + constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; + +#if 1 + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<1080>{}; + constexpr auto Wi = Number<1920>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; +#elif 0 + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<1080>{}; + constexpr auto Wi = Number<1920>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<3>{}; + constexpr auto C1 = Number<4>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; +#elif 0 + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<540>{}; + constexpr auto Wi = Number<960>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; +#elif 0 + constexpr auto N = Number<128>{}; + constexpr auto Hi = Number<270>{}; + constexpr auto Wi = Number<480>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; +#endif + + constexpr auto conv_stride_h = I1; + constexpr auto conv_stride_w = I1; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; + + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; + + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; + + constexpr auto Ho_2 = Number{}; + constexpr auto Wo_2 = Number{}; + +#endif + +#if 0 + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#elif 1 + using in_data_t = int8_t; + using acc_data_t = int32_t; + using out_data_t = int8_t; +#endif + + std::vector in_lengths_host(5), wei_lengths_host(5), out_lengths_host(5), + max_lengths_host(5), bias_lengths_host(2); + + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C0); + in_lengths_host[2] = static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + in_lengths_host[4] = static_cast(C1); + + wei_lengths_host[0] = static_cast(K0 * K1); + wei_lengths_host[1] = static_cast(C0); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + wei_lengths_host[4] = static_cast(C1); + + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(K0); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + out_lengths_host[4] = static_cast(K1); + + max_lengths_host[0] = static_cast(N); + max_lengths_host[1] = static_cast(K0); + max_lengths_host[2] = static_cast(Ho_2); + max_lengths_host[3] = static_cast(Wo_2); + max_lengths_host[4] = static_cast(K1); + + bias_lengths_host[0] = static_cast(K0); + bias_lengths_host[1] = static_cast(K1); + + Tensor in(in_lengths_host); + Tensor wei(wei_lengths_host); + Tensor bias(bias_lengths_host); + Tensor out_device(out_lengths_host); + Tensor out_host(out_lengths_host); + Tensor max_device(max_lengths_host); + Tensor max_host(max_lengths_host); + + ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); + + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = 1; + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 5: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_wei = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + }; + wei.GenerateTensorValue(gen_wei, num_thread); + } + + bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + + auto f_make_for_device_nchwc = [&]() { + const auto in_lengths_dev = make_tuple(N, C0, Hi, Wi, C1); + const auto wei_lengths_dev = make_tuple(K0 * K1, C0, Y, X, C1); + const auto max_lengths_dev = make_tuple(N, K0, Ho_2, Wo_2, K1); + const auto out_lengths_dev = make_tuple(N, K0, Ho, Wo, K1); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + max_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + +#if USE_CONV_FWD_V5R1_NCHWC + if(algo == ConvForwardAlgo::V5R1NCHWC) + { + const auto tmp = f_make_for_device_nchwc(); + + device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1< + in_data_t, + acc_data_t, + out_data_t, + activ_type>(tmp[I0], // in_lengths_dev + tmp[I1], // wei_lengths_dev + tmp[I2], // max_lengths_dev + tmp[I3], // out_lengths_dev + tmp[I4], // conv_strides_dev + tmp[I5], // conv_dilations_dev + tmp[I6], // in_left_pads_dev + tmp[I7], // in_right_pads_dev + in, + wei, + bias, + out_device, + max_device, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution_maxpool_nchwc(in, + wei, + bias, + out_host, + max_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + activ_type); + + ck::utils::check_err(out_device.mData, out_host.mData); + ck::utils::check_err(max_device.mData, max_host.mData); + + if(do_log) + { + // LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + // LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; + // LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << + // std::endl; + LogRangeAsType(std::cout << "max_host: ", max_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "max_device: ", max_device.mData, ",") << std::endl; + } + } +} diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp similarity index 73% rename from host/driver_offline/src/conv_wrw_driver_offline.cpp rename to library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp index 50f4d6a9b34..1ac974202ca 100644 --- a/host/driver_offline/src/conv_wrw_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "debug.hpp" #include "print.hpp" @@ -11,7 +13,6 @@ #include "host_tensor.hpp" #include "host_tensor_generator.hpp" #include "conv_common.hpp" -#include "host_conv_bwd_weight.hpp" #include "device_tensor.hpp" #include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" @@ -19,6 +20,15 @@ #include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp" #include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp" +enum ConvTensorLayout +{ + NCHW, + NHWC, + CHWN, + NCHWc, + NHWCc +}; + #define USE_DYNAMIC_MODE 1 #define USE_CONV_WRW_V4R4R2_XDL_NCHW 0 #define USE_CONV_WRW_V4R4R4_XDL_NHWC 0 @@ -35,6 +45,92 @@ enum ConvBackwardWeightAlgo V4R4R5XDLATOMICNHWC, // 4 }; +template +void host_convolution_backward_weight(const Tensor& out, + const Tensor& in, + Tensor& wei, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + auto f_kcyx = [&](auto k, auto c, auto y, auto x) { + double v = 0; + for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) + { + for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + v += static_cast(in(n, c, hi, wi)) * + static_cast(out(n, k, ho, wo)); + } + } + } + } + wei(k, c, y, x) = v; + }; + + auto f_kyxc = [&](auto k, auto y, auto x, auto c) { + double v = 0; + for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) + { + for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && + wi < in.mDesc.GetLengths()[2]) + { + v += static_cast(in(n, hi, wi, c)) * + static_cast(out(n, ho, wo, k)); + } + } + } + } + wei(k, y, x, c) = v; + }; + + if(layout == ConvTensorLayout::NCHW) + { + make_ParallelTensorFunctor(f_kcyx, + wei.mDesc.GetLengths()[0], + wei.mDesc.GetLengths()[1], + wei.mDesc.GetLengths()[2], + wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else if(layout == ConvTensorLayout::NHWC) + { + make_ParallelTensorFunctor(f_kyxc, + wei.mDesc.GetLengths()[0], + wei.mDesc.GetLengths()[1], + wei.mDesc.GetLengths()[2], + wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! not supported layout"); + } +} + int main(int argc, char* argv[]) { using namespace ck; @@ -195,7 +291,7 @@ int main(int argc, char* argv[]) print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); - std::size_t num_thread = std::thread::hardware_concurrency(); + std::size_t num_thread = 1; switch(init_method) { @@ -203,30 +299,30 @@ int main(int argc, char* argv[]) // no initialization break; case 1: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 2: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 3: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 4: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 5: - in.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); - out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); + in.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); + out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); break; default: - in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); auto gen_out = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); }; out.GenerateTensorValue(gen_out, num_thread); } @@ -414,16 +510,16 @@ int main(int argc, char* argv[]) if(do_verification) { - host_direct_convolution_backward_weights(out, - in, - wei_host, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); - - check_error(wei_host, wei_device); + host_convolution_backward_weight(out, + in, + wei_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + ck::utils::check_err(wei_device.mData, wei_host.mData); if(do_log) { diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/library/src/obselete_driver_offline/gemm_driver_offline.cpp similarity index 59% rename from host/driver_offline/src/gemm_driver_offline.cpp rename to library/src/obselete_driver_offline/gemm_driver_offline.cpp index e60b4905ae7..a09cb932d61 100644 --- a/host/driver_offline/src/gemm_driver_offline.cpp +++ b/library/src/obselete_driver_offline/gemm_driver_offline.cpp @@ -4,13 +4,14 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "debug.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" -#include "gemm_common.hpp" #include "host_gemm.hpp" #include "device_tensor.hpp" #include "device_gemm_xdlops_mk_kn_mn.hpp" @@ -31,7 +32,19 @@ #define USE_GEMM_XDL_KM_KN_NM 0 #define USE_GEMM_XDL_KM_NK_NM 0 -enum GemmAlgo +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM // 7 +}; + +enum struct GemmAlgo { Xdl_MK_KN_MN, // 0 Xdl_MK_NK_MN, // 1 @@ -43,6 +56,161 @@ enum GemmAlgo Xdl_KM_NK_NM, // 7 }; +template +void host_gemm(const Tensor& a, + const Tensor& b, + Tensor& c, + const GemmMatrixLayout layout) +{ + if(layout == GemmMatrixLayout::MK_KN_MN) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(k, n)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_NK_MN) + { + auto f_mk_nk_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(n, k)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_KN_MN) + { + auto f_km_kn_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(k, n)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_NK_MN) + { + auto f_km_nk_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(n, k)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_KN_NM) + { + auto f_mk_kn_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(k, n)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_NK_NM) + { + auto f_mk_nk_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(n, k)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_KN_NM) + { + auto f_km_kn_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(k, n)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_NK_NM) + { + auto f_km_nk_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(n, k)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! not supported layout"); + } +} int main(int argc, char* argv[]) { using namespace ck; @@ -147,7 +315,7 @@ int main(int argc, char* argv[]) ostream_HostTensorDescriptor(b.mDesc, std::cout << "b: "); ostream_HostTensorDescriptor(c_host.mDesc, std::cout << "c: "); - std::size_t num_thread = std::thread::hardware_concurrency(); + std::size_t num_thread = 1; switch(init_method) { @@ -155,24 +323,24 @@ int main(int argc, char* argv[]) // no initialization break; case 1: - a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 2: - a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 3: - a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 4: - a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; default: - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); } #if USE_GEMM_XDL_MK_KN_MN @@ -275,7 +443,7 @@ int main(int argc, char* argv[]) { host_gemm(a, b, c_host, layout); - check_error(c_host, c_device); + ck::utils::check_err(c_device.mData, c_host.mData); if(do_log) { diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt new file mode 100644 index 00000000000..b20a4b57e58 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -0,0 +1,115 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/include/ck + ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility + ${PROJECT_SOURCE_DIR}/include/ck/tensor_description + ${PROJECT_SOURCE_DIR}/include/ck/tensor + ${PROJECT_SOURCE_DIR}/include/ck/problem_transform + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/device + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/grid + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/block + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/warp + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/thread + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element + ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor + ${PROJECT_SOURCE_DIR}/library/include/ck/library/host + ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance + ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance/gpu/reduce + ${PROJECT_SOURCE_DIR}/external/include/half +) + +function(add_instance_library INSTANCE_NAME) + message("adding instance ${INSTANCE_NAME}") + add_library(${INSTANCE_NAME} OBJECT ${ARGN}) + target_compile_features(${INSTANCE_NAME} PUBLIC) + set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) +endfunction(add_instance_library INSTANCE_NAME) + +add_subdirectory(gemm) +add_subdirectory(gemm_bias2d) +add_subdirectory(gemm_bias_relu) +add_subdirectory(gemm_bias_relu_add) +add_subdirectory(gemm_reduce) +add_subdirectory(batched_gemm) +add_subdirectory(conv1d_fwd) +add_subdirectory(conv2d_fwd) +add_subdirectory(conv3d_fwd) +add_subdirectory(conv2d_fwd_bias_relu) +add_subdirectory(conv2d_fwd_bias_relu_add) +add_subdirectory(conv2d_fwd_bias_relu_atomic_add) +add_subdirectory(conv2d_bwd_data) +add_subdirectory(reduce) +add_subdirectory(convnd_bwd_data) +add_subdirectory(grouped_gemm) +add_subdirectory(conv2d_bwd_weight) +add_subdirectory(batched_gemm_reduce) + +add_library(device_operations STATIC + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + device_conv2d.cpp +) +add_library(composablekernels::device_operations ALIAS device_operations) + + +set(DEV_OPS_INC_DIRS + ${PROJECT_SOURCE_DIR}/include/ck/ + ${PROJECT_SOURCE_DIR}/library/include/ck/ + ${PROJECT_SOURCE_DIR}/external/include/ +) +target_compile_features(device_operations PUBLIC) +set_target_properties(device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(device_operations PUBLIC + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ +) + +#once new arches are enabled make this an option on the main cmake file +# and pass down here to be exported + +target_compile_options(device_operations +PRIVATE --offload-arch=gfx908 +) +# install(TARGETS device_operations LIBRARY DESTINATION lib) +install(TARGETS device_operations + EXPORT device_operationsTargets + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} +) +install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck) +install(EXPORT device_operationsTargets + FILE composable_kerneldevice_operationsTargets.cmake + NAMESPACE composable_kernel:: + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel +) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt new file mode 100644 index 00000000000..016c85f6732 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt @@ -0,0 +1,26 @@ +#device_batched_gemm_instance +set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE + device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp; + device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp; + device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp; + device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp; + device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp; + device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp; + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp; + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp; + device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp; + device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp; + device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp; + device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp; + device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp; + device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp; + device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp; + device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp; +) + +add_library(device_batched_gemm_instance OBJECT ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE}) +# target_compile_features(device_batched_gemm_instance PUBLIC) +set_target_properties(device_batched_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +# install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib) + +clang_tidy_check(device_batched_gemm_instance) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp new file mode 100644 index 00000000000..9641e3cf72d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp new file mode 100644 index 00000000000..c93c77dccce --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp new file mode 100644 index 00000000000..8da334071a6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,55 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp new file mode 100644 index 00000000000..9566d5ecd4c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,56 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp new file mode 100644 index 00000000000..3be80837134 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp new file mode 100644 index 00000000000..21daf0b1931 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp new file mode 100644 index 00000000000..9606b1f0cc7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,60 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp new file mode 100644 index 00000000000..3d3e35e8e45 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,56 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp new file mode 100644 index 00000000000..c6d6a1ba6a3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp new file mode 100644 index 00000000000..157bf413ac3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp new file mode 100644 index 00000000000..5a8988722e2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp new file mode 100644 index 00000000000..2e892d97f51 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,56 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp new file mode 100644 index 00000000000..1f3951c938f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,66 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using AData = int8_t; +using BData = int8_t; +using CData = int8_t; +using AccData = int32_t; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp new file mode 100644 index 00000000000..d6faa5a9cb3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,66 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using AData = int8_t; +using BData = int8_t; +using CData = int8_t; +using AccData = int32_t; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp new file mode 100644 index 00000000000..b5bc2786f23 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,66 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using AData = int8_t; +using BData = int8_t; +using CData = int8_t; +using AccData = int32_t; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp new file mode 100644 index 00000000000..6858903ff48 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using AData = int8_t; +using BData = int8_t; +using CData = int8_t; +using AccData = int32_t; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances{}); +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000000..0606df01f14 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -0,0 +1,12 @@ +set(DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp +) + +add_instance_library(device_batched_gemm_reduce_instance OBJECT ${DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE}) +target_compile_features(device_batched_gemm_reduce_instance PUBLIC) +set_target_properties(device_batched_gemm_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +clang_tidy_check(device_batched_gemm_reduce_instance) + diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp new file mode 100644 index 00000000000..322b0ddaf54 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,80 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::UnaryIdentic; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp new file mode 100644 index 00000000000..bdc5aebe1a3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,80 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::UnaryIdentic; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp new file mode 100644 index 00000000000..df51cb617bb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,80 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::UnaryIdentic; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp new file mode 100644 index 00000000000..10afddb5c6a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,77 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::UnaryIdentic; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt new file mode 100644 index 00000000000..77aa6198f59 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt @@ -0,0 +1,14 @@ +# device_conv1d_fwd_instance +set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE + device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp; + device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp; + device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp; + device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp; +) + +add_library(device_conv1d_fwd_instance OBJECT ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE}) +# target_compile_features(device_conv1d_fwd_instance PUBLIC) +set_target_properties(device_conv1d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +# install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib) + +clang_tidy_check(device_conv1d_fwd_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp new file mode 100644 index 00000000000..9288e40e566 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -0,0 +1,112 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv1d_fwd_instance { + +using F32 = float; +using BF16 = bhalf_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = std::tuple< +// clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !CK_WORKAROUND_GITHUB_135 + // FIXME: this instance causes numerical errors. + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, +#endif + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_bf16_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace device_conv1d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp new file mode 100644 index 00000000000..669dca617a0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp @@ -0,0 +1,109 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv1d_fwd_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f16_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv1d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp new file mode 100644 index 00000000000..0abd47142ba --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp @@ -0,0 +1,112 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv1d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +//------------------------------------------------------------------------------ +// Conv1D +//------------------------------------------------------------------------------ + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f32_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv1d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp new file mode 100644 index 00000000000..53e0f775502 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp @@ -0,0 +1,111 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv1d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_int8_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace device_conv1d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt new file mode 100644 index 00000000000..d7882a7d8b0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_conv2d_bwd_data_instance +set(DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; +) + +add_library(device_conv2d_bwd_data_instance OBJECT ${DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE}) +set_target_properties(device_conv2d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_conv2d_bwd_data_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 00000000000..b5814aa17fc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,83 @@ +#include +#include "config.hpp" +#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000000..53498aff344 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,85 @@ +#include +#include "config.hpp" +#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#if !CK_WORKAROUND_SWDEV_325164 + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#endif + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 00000000000..fbe279e0333 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,82 @@ +#include +#include "config.hpp" +#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp new file mode 100644 index 00000000000..7fd51bbfbfb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -0,0 +1,83 @@ +#include +#include "config.hpp" +#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using DataType = int8_t; +using AccType = int32_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //#####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt new file mode 100644 index 00000000000..7c384a882b7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt @@ -0,0 +1,11 @@ +# device_conv2d_bwd_weight_instance +set(DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE + device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; +) +add_library(device_conv2d_bwd_weight_instance OBJECT ${DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE}) +target_compile_features(device_conv2d_bwd_weight_instance PUBLIC) +set_target_properties(device_conv2d_bwd_weight_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +install(TARGETS device_conv2d_bwd_weight_instance LIBRARY DESTINATION lib) + +clang_tidy_check(device_conv2d_bwd_weight_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000000..d915db67587 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_weight_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances{}); +} + +} // namespace device_conv2d_bwd_weight_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 00000000000..e9f6636518d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_weight_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances{}); +} + +} // namespace device_conv2d_bwd_weight_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt new file mode 100644 index 00000000000..857e36d6f57 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_conv2d_fwd_instance +set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; + device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; +) +add_library(device_conv2d_fwd_instance OBJECT ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) +set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_conv2d_fwd_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000000..b2f6f9335eb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,144 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +// arbitrary conv +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 00000000000..47405ea1bfb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,110 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000000..a4060f8bf20 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,109 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 00000000000..3c46c2f7e98 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,108 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp new file mode 100644 index 00000000000..0db59ca394c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -0,0 +1,109 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt new file mode 100644 index 00000000000..ad66c73bf84 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt @@ -0,0 +1,8 @@ +# device_conv2d_fwd_bias_relu_instance +set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE + device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; +) +add_library(device_conv2d_fwd_bias_relu_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) +set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_conv2d_fwd_bias_relu_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000000..9c3f0a4b964 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,149 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_bias_activation_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +static constexpr auto MemorySet = ck::InMemoryDataOperationEnum::Set; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +// arbitrary conv +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// Odd C +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances{}); +} + +} // namespace device_conv2d_fwd_bias_activation_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt new file mode 100644 index 00000000000..36b1f6c1535 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt @@ -0,0 +1,8 @@ +# device_conv2d_fwd_bias_relu_add_instance +set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp; +) +add_library(device_conv2d_fwd_bias_relu_add_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) +set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_conv2d_fwd_bias_relu_add_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000000..b9f46e26119 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,149 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_bias_activation_add_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +// arbitrary conv +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// Odd C +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances{}); +} + +} // namespace device_conv2d_fwd_bias_activation_add_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt new file mode 100644 index 00000000000..5906c7c5ac7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt @@ -0,0 +1,9 @@ +# device_conv2d_fwd_bias_relu_atomic_add_instance +set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE + device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; +) + +add_library(device_conv2d_fwd_bias_relu_atomic_add_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) +set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_conv2d_fwd_bias_relu_atomic_add_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000000..c56ad270aa4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,69 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_bias_activation_atomic_add_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +static constexpr auto InMemoryAtomicAdd = ck::InMemoryDataOperationEnum::AtomicAdd; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances( + std::vector>& + instance_container) +{ + using Instances = + device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances; + + const auto instances = Instances{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Instance = remove_cvref_t(instances))>; + + auto instance = Instance{}; + + instance_container.push_back(std::make_unique(instance)); + }); +} + +} // namespace device_conv2d_fwd_bias_activation_atomic_add_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt new file mode 100644 index 00000000000..91a299c7422 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_conv3d_fwd_instance +set(DEVICE_CONV3D_FWD_INSTANCE_SOURCE + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp; + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp; + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp; + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp; +) +add_library(device_conv3d_fwd_instance OBJECT ${DEVICE_CONV3D_FWD_INSTANCE_SOURCE}) +target_compile_features(device_conv3d_fwd_instance PUBLIC) +set_target_properties(device_conv3d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_conv3d_fwd_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp new file mode 100644 index 00000000000..745d26904aa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -0,0 +1,113 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv3d_fwd_instance { + +using F32 = float; +using BF16 = bhalf_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = std::tuple< +// clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !CK_WORKAROUND_GITHUB_135 + // FIXME: this instance causes numerical errors. + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, +#endif + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_bf16_instances{}); + add_device_operation_instances( + instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace device_conv3d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp new file mode 100644 index 00000000000..4d51180e725 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp @@ -0,0 +1,110 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv3d_fwd_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv3d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp new file mode 100644 index 00000000000..9a8ff8d7143 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp @@ -0,0 +1,109 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv3d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_f32_instances{}); + add_device_operation_instances( + instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv3d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp new file mode 100644 index 00000000000..7f54b66f9b5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp @@ -0,0 +1,112 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv3d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances{}); + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_int8_instances{}); + add_device_operation_instances( + instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace device_conv3d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt new file mode 100644 index 00000000000..037f8608086 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,22 @@ +# device_convnd_bwd_data_instance +set(DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp; + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp; + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp; + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp; + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp; + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp; + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp; +) + +add_library(device_convnd_bwd_data_instance OBJECT ${DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE}) +target_compile_features(device_convnd_bwd_data_instance PUBLIC) +set_target_properties(device_convnd_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +install(TARGETS device_convnd_bwd_data_instance LIBRARY DESTINATION lib) + +clang_tidy_check(device_convnd_bwd_data_instance) diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp new file mode 100644 index 00000000000..5c915dcc426 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -0,0 +1,84 @@ +#include +#include "config.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp new file mode 100644 index 00000000000..e8f7d4f11ad --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp @@ -0,0 +1,86 @@ +#include +#include "config.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#if 1 + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#endif + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp new file mode 100644 index 00000000000..b4c65ab66ab --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp @@ -0,0 +1,83 @@ +#include +#include "config.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp new file mode 100644 index 00000000000..e3958ef6891 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp @@ -0,0 +1,86 @@ +#include +#include "config.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using DataType = int8_t; +using AccType = int32_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + #if 1 + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + #endif + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //##############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 00000000000..2e4cd5cf312 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,84 @@ +#include +#include "config.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000000..7170decc439 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,86 @@ +#include +#include "config.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#if 1 + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, +#endif + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 00000000000..5a727b1113a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,83 @@ +#include +#include "config.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp new file mode 100644 index 00000000000..3c53644ddc5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -0,0 +1,88 @@ +#include +#include "config.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using DataType = int8_t; +using AccType = int32_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + #if 1 + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + #endif + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //##############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + #if 1 + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + #endif + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp new file mode 100644 index 00000000000..edbb7a14d9e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -0,0 +1,84 @@ +#include +#include "config.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | ./ | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp new file mode 100644 index 00000000000..5d00fa8f081 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp @@ -0,0 +1,86 @@ +#include +#include "config.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#if 1 + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, +#endif + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp new file mode 100644 index 00000000000..d5cd04de6b9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp @@ -0,0 +1,83 @@ +#include +#include "config.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp new file mode 100644 index 00000000000..d5519706061 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp @@ -0,0 +1,86 @@ +#include +#include "config.hpp" +#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using DataType = int8_t; +using AccType = int32_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, +#if 1 + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, +#endif + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1> + // clang-format on + >; + +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //##############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/device_conv2d.cpp b/library/src/tensor_operation_instance/gpu/device_conv2d.cpp new file mode 100644 index 00000000000..6b99433ffa2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/device_conv2d.cpp @@ -0,0 +1,201 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" +#include "host_interface.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>& instances); + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl +{ + std::unique_ptr + MakeArgumentPointer(void* in_ptr, + void* wei_ptr, + void* out_ptr, + size_t N, + size_t K, + size_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) const + { + return el->MakeArgumentPointer(in_ptr, + wei_ptr, + out_ptr, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + } + std::unique_ptr MakeInvokerPointer() const + { + return el->MakeInvokerPointer(); + } + + std::string GetTypeString() { return el->GetTypeString(); } + bool IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg) + { + return el->IsSupportedArgument(arg); + } + + ck::tensor_operation::device::DeviceConvFwdPtr el; +}; + +DeviceConvFwdPtr_t::DeviceConvFwdPtr_t() : pImpl(nullptr) {} +DeviceConvFwdPtr_t::~DeviceConvFwdPtr_t() = default; +DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&) = default; +DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& other) + : pImpl(std::make_unique(std::move(other))) +{ +} + +std::unique_ptr +DeviceConvFwdPtr_t::MakeArgumentPointer(void* in_ptr, + void* wei_ptr, + void* out_ptr, + size_t N, + size_t K, + size_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) const +{ + return pImpl->MakeArgumentPointer(in_ptr, + wei_ptr, + out_ptr, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); +} + +std::unique_ptr DeviceConvFwdPtr_t::MakeInvokerPointer() const +{ + return pImpl->MakeInvokerPointer(); +} + +std::string DeviceConvFwdPtr_t::GetTypeString() { return pImpl->GetTypeString(); } +bool DeviceConvFwdPtr_t::IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg_ptr) +{ + return pImpl->IsSupportedArgument(arg_ptr); +} + +using namespace ck::tensor_operation::device::device_conv2d_fwd_instance; +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t( + std::vector& instances) +{ + std::vector< + ck::tensor_operation::device::DeviceConvFwdPtr> + local_instances; + add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(local_instances); + for(auto& kinder : local_instances) + { + DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; + instances.emplace_back(tmp); + } + return; +} + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t( + std::vector& instances) +{ + std::vector< + ck::tensor_operation::device::DeviceConvFwdPtr> + local_instances; + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances); + for(auto& kinder : local_instances) + { + DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; + instances.emplace_back(tmp); // Perhaps we can do better + } + return; +} + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t( + std::vector& instances) +{ + std::vector< + ck::tensor_operation::device::DeviceConvFwdPtr> + local_instances; + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(local_instances); + for(auto& kinder : local_instances) + { + DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; + instances.emplace_back(tmp); // Perhaps we can do better + } + return; +} + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t( + std::vector& instances) +{ + std::vector< + ck::tensor_operation::device::DeviceConvFwdPtr> + local_instances; + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(local_instances); + for(auto& kinder : local_instances) + { + DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; + instances.emplace_back(tmp); // Perhaps we can do better + } + return; +} + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t( + std::vector& instances) +{ + std::vector< + ck::tensor_operation::device::DeviceConvFwdPtr> + local_instances; + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(local_instances); + for(auto& kinder : local_instances) + { + DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; + instances.emplace_back(tmp); + } + return; +} diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt new file mode 100644 index 00000000000..da769a56269 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -0,0 +1,52 @@ +set(DEVICE_GEMM_INSTANCE_SOURCE + device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp; +) + +add_library(device_gemm_instance OBJECT ${DEVICE_GEMM_INSTANCE_SOURCE}) + +target_compile_features(device_gemm_instance PUBLIC) +set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..db7f6af04b4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..c4253bcc4cd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..d19d11f1f8a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..cd86e5ceaed --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..3fcc5fdfdcb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..8cd32128b55 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f32_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..4c4bfc440d6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f32_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..c6077341b1c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f32_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..91b68d4bf23 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,42 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..13b185fd936 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,42 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..ff4a89beb4d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,42 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..e32158a292d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,42 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..de97b60a62a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..5e99c67b3f7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..321b97fd30e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp @@ -0,0 +1,61 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..1d69a23dd72 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..8ffa2b8b867 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..09adf1678d2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..121b5857b2e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,61 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..2073d5f50ec --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..e177ee60ec9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..ff830d41619 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,60 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..79bca77aad1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,60 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..fac4e8d96ee --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,60 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..ffcd957913e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,57 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..2185b55aac0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 4, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 4, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..90966349b21 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,61 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 4, 16, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 4, 16, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 4, 16, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 16, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..aa5a13001c0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..82eec1164af --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..08047c7e52b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..05cb080cbfd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..4de989caf0c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,62 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..633e2aac2e4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,74 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = + std::tuple< + // clang-format off + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +// irregular tile size +using device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..8284311102d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..235c4771f9e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_f32_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..b7000bddf87 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..1b4f23141b3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..26ec965bb50 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_splitk_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..45e3f9f9400 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_splitk_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..042ac2b8cae --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_splitk_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..21fdb7cd9df --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,94 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_splitk_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +// using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< +// // clang-format off +// //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| +// B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| +// ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| +// BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| +// CBlockTransferClusterLengths| CBlockTransfer| +// //#########################| Type| Type| Type| Type| | | | +// Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | +// XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| +// SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| +// SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| +// _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +// //#########################| | | | | | | | +// Operation| Operation| Operation| | | | | | | | +// | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| +// PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | +// PerVector| PerVector_K1| | PerShuffle| PerShuffle| +// _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +// //#########################| | | | | | | | | | +// | | | | | | | | | | | | +// | | | | | | | | | | | | +// | | | | | +// DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, +// PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, +// 16, 2, 9, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, +// true, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, +// true, 1, 9, S<1, 2, 1, 72>, 2> +// // clang-format on +// >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{}); + + // FIXME - IsSupportedArgument() is false, need to check validity + // add_device_operation_instances( + // instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..971bdcad583 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_splitk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..3b7bdb87be0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_splitk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..8366616246e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_splitk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + //###################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM|Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //###################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 96, 128, 4, 8, 16, 16, 3, 4, S<1, 4, 32, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 32, 256, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 16, 256, 4, 4, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 16, 128, 4, 4, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..396de62cfb2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_splitk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt new file mode 100644 index 00000000000..e2b0abb1d10 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt @@ -0,0 +1,16 @@ +# device_gemm_bias2d_instance +set(DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE + device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp; +) + +add_library(device_gemm_bias2d_instance OBJECT ${DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE}) +set_target_properties(device_gemm_bias2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_gemm_bias2d_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..bd16850ee4f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..12740ce256f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..56db0475efe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..b20ee8db69a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,57 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..11984c36db5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..bd0a9880594 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..440ea1582e5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..fab885969f7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias2d/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,56 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt new file mode 100644 index 00000000000..e2e7d4badd2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_gemm_bias_relu_instance +set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE + device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp; +) + +add_library(device_gemm_bias_relu_instance OBJECT ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) +set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_gemm_bias_relu_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..4927a05ca4e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +// c[m, n] = ReLU(a[k, m] * b[k, n] + c0[n]) +using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..f712f9de118 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +// c[m, n] = ReLU(a[k, m] * b[n, k] + c0[n]) +using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..26af05bbde4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +// c[m, n] = ReLU(a[m, k] * b[k, n] + c0[n]) +using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..901b7a5d644 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,57 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +// c[m, n] = ReLU(a[m, k] * b[n, k] + c0[n]) +using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt new file mode 100644 index 00000000000..a10dbb555dc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_gemm_bias_relu_add_instance +set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE + device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp; +) + +add_library(device_gemm_bias_relu_add_instance OBJECT ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) +set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_gemm_bias_relu_add_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..c26f66a9ed5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +// c[m, n] = ReLU(a[k, m] * b[k, n] + c0[n]) + c1[m, n] +using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..c0950666b17 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +// c[m, n] = ReLU(a[k, m] * b[n, k] + c0[n]) + c1[m, n] +using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..42c1f72d6e6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +// c[m, n] = ReLU(a[m, k] * b[k, n] + c0[n]) + c1[m, n] +using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..3961def81d3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,57 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +// c[m, n] = ReLU(a[m, k] * b[n, k] + c0[n]) + c1[m, n] +using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000000..5bc6d17a93a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt @@ -0,0 +1,10 @@ +set(DEVICE_GEMM_REDUCE_INSTANCE_SOURCE + device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp + device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp + device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp + device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +) + +add_instance_library(device_gemm_reduce_instance ${DEVICE_GEMM_REDUCE_INSTANCE_SOURCE}) +install(TARGETS device_gemm_reduce_instance LIBRARY DESTINATION lib) +clang_tidy_check(device_gemm_reduce_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..33660c04818 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,78 @@ +#include +#include "config.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::UnaryIdentic; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[k, m] * b[k, n] +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..bd8766a617c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,78 @@ +#include +#include "config.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::UnaryIdentic; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[k, m] * b[n, k] +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..c04431c1e02 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,78 @@ +#include +#include "config.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::UnaryIdentic; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..ebd89e5975f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,75 @@ +#include +#include "config.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::UnaryIdentic; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt new file mode 100644 index 00000000000..6c5e31fddd3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -0,0 +1,15 @@ +# device_grouped_gemm_instance +set(DEVICE_GROUPED_GEMM_INSTANCE_SOURCE + device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp; + device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; + device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; + device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; +) + +add_library(device_grouped_gemm_instance OBJECT ${DEVICE_GROUPED_GEMM_INSTANCE_SOURCE}) + +target_compile_features(device_grouped_gemm_instance PUBLIC) +set_target_properties(device_grouped_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +install(TARGETS device_grouped_gemm_instance LIBRARY DESTINATION lib) + +clang_tidy_check(device_grouped_gemm_instance) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..19f1011c3f1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_grouped_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_grouped_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_grouped_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..59e0d240555 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_grouped_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_grouped_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_grouped_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..35052ae8a93 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,62 @@ +#include +#include "config.hpp" +#include "device_grouped_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_grouped_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_grouped_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..cb41d2724c4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,73 @@ +#include +#include "config.hpp" +#include "device_grouped_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_grouped_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +// irregular tile size +using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances( + instances, device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace device_grouped_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt new file mode 100644 index 00000000000..d566796c13a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt @@ -0,0 +1,29 @@ +# device_reduce_instance +set(DEVICE_REDUCE_INSTANCE_SOURCE + device_reduce_instance_blockwise_f16_f16_f16.cpp; + device_reduce_instance_blockwise_f16_f32_f16.cpp; + device_reduce_instance_blockwise_f32_f32_f32.cpp; + device_reduce_instance_blockwise_f32_f64_f32.cpp; + device_reduce_instance_blockwise_f64_f64_f64.cpp; + device_reduce_instance_blockwise_i8_i32_i8.cpp; + device_reduce_instance_blockwise_i8_i8_i8.cpp; + device_reduce_instance_blockwise_b16_f32_b16.cpp; + device_reduce_instance_threadwise_f16_f16_f16.cpp; + device_reduce_instance_threadwise_f16_f32_f16.cpp; + device_reduce_instance_threadwise_f32_f32_f32.cpp; + device_reduce_instance_threadwise_f32_f64_f32.cpp; + device_reduce_instance_threadwise_f64_f64_f64.cpp; + device_reduce_instance_threadwise_i8_i32_i8.cpp; + device_reduce_instance_threadwise_i8_i8_i8.cpp; + device_reduce_instance_threadwise_b16_f32_b16.cpp; + device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp; + device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp; + device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp; + device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp; + device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp; +) + +add_library(device_reduce_instance OBJECT ${DEVICE_REDUCE_INSTANCE_SOURCE}) +set_target_properties(device_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_reduce_instance) diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.cpp new file mode 100644 index 00000000000..0274d89fc9e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.cpp @@ -0,0 +1,53 @@ +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1); + +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.cpp new file mode 100644 index 00000000000..8a43d860ea7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.cpp @@ -0,0 +1,40 @@ +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.cpp new file mode 100644 index 00000000000..3e0b8ba59c7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.cpp @@ -0,0 +1,28 @@ +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.cpp new file mode 100644 index 00000000000..ee96311f8ce --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.cpp @@ -0,0 +1,52 @@ +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.cpp new file mode 100644 index 00000000000..b0ae95e82d9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.cpp @@ -0,0 +1,28 @@ +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.cpp new file mode 100644 index 00000000000..9cca2dbbeb9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.cpp @@ -0,0 +1,52 @@ +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.cpp new file mode 100644 index 00000000000..05cd1921ee7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.cpp @@ -0,0 +1,24 @@ +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.cpp new file mode 100644 index 00000000000..66ef0178643 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.cpp @@ -0,0 +1,40 @@ +#include "device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp new file mode 100644 index 00000000000..9b2b7f5d8c1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp @@ -0,0 +1,24 @@ +#include "device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp new file mode 100644 index 00000000000..fc956aa04b6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp @@ -0,0 +1,24 @@ +#include "device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp new file mode 100644 index 00000000000..e5ffd9f976d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp @@ -0,0 +1,24 @@ +#include "device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp new file mode 100644 index 00000000000..229829b8897 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp @@ -0,0 +1,24 @@ +#include "device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp new file mode 100644 index 00000000000..497f2695be0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp @@ -0,0 +1,24 @@ +#include "device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.cpp new file mode 100644 index 00000000000..02fc4b4c01a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.cpp @@ -0,0 +1,53 @@ +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1); + +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.cpp new file mode 100644 index 00000000000..0984cdc46b9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.cpp @@ -0,0 +1,40 @@ +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.cpp new file mode 100644 index 00000000000..64f14bd4e72 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.cpp @@ -0,0 +1,28 @@ +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.cpp new file mode 100644 index 00000000000..69ed303b177 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.cpp @@ -0,0 +1,52 @@ +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.cpp new file mode 100644 index 00000000000..5d791cec410 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.cpp @@ -0,0 +1,28 @@ +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.cpp new file mode 100644 index 00000000000..16c0409134a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.cpp @@ -0,0 +1,52 @@ +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp new file mode 100644 index 00000000000..7af7bc03f28 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp @@ -0,0 +1,25 @@ +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); +// clang-format on +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp new file mode 100644 index 00000000000..9580aae057d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp @@ -0,0 +1,40 @@ +#include "device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/utility/CMakeLists.txt b/library/src/utility/CMakeLists.txt new file mode 100644 index 00000000000..0914855d59f --- /dev/null +++ b/library/src/utility/CMakeLists.txt @@ -0,0 +1,21 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/include/ck + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/device + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element + ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor + ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu + ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility +) + +set(CONV_UTIL_SOURCE + conv_util.cpp +) + +add_library(conv_util SHARED ${CONV_UTIL_SOURCE}) +target_link_libraries(conv_util PRIVATE host_tensor) +target_compile_features(conv_util PUBLIC) +set_target_properties(conv_util PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(conv_util SYSTEM PUBLIC $) + +clang_tidy_check(conv_util) diff --git a/library/src/utility/conv_util.cpp b/library/src/utility/conv_util.cpp new file mode 100644 index 00000000000..a60d1a34952 --- /dev/null +++ b/library/src/utility/conv_util.cpp @@ -0,0 +1,240 @@ + +#include "conv_util.hpp" + +namespace ck { +namespace utils { +namespace conv { + +/** + * @brief Calculate number of FLOPs for Convolution + * + * @param[in] N Batch size. + * @param[in] C Number of input channels. + * @param[in] K Number of output channels. + * @param[in] filter_spatial_lengths Filter spatial dimensions lengths. + * @param[in] output_spatial_lengths Convolution output spatial dimensions + * lengths. + * + * @return The number of flops. + */ +std::size_t get_flops(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths) +{ + // 2 * N * K * * C * + return static_cast(2) * N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies()) * + C * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + static_cast(1), + std::multiplies()); +} + +ConvParams::ConvParams() + : num_dim_spatial_(2), + N_(128), + K_(256), + C_(192), + filter_spatial_lengths_(2, 3), + input_spatial_lengths_(2, 71), + conv_filter_strides_(2, 2), + conv_filter_dilations_(2, 1), + input_left_pads_(2, 1), + input_right_pads_(2, 1) +{ +} + +ConvParams::ConvParams(ck::index_t n_dim, + ck::index_t n_batch, + ck::index_t n_out_channels, + ck::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads) + : num_dim_spatial_(n_dim), + N_(n_batch), + K_(n_out_channels), + C_(n_in_channels), + filter_spatial_lengths_(filters_len), + input_spatial_lengths_(input_len), + conv_filter_strides_(strides), + conv_filter_dilations_(dilations), + input_left_pads_(left_pads), + input_right_pads_(right_pads) +{ + if(ck::type_convert(filter_spatial_lengths_.size()) != num_dim_spatial_ || + ck::type_convert(input_spatial_lengths_.size()) != num_dim_spatial_ || + ck::type_convert(conv_filter_strides_.size()) != num_dim_spatial_ || + ck::type_convert(conv_filter_dilations_.size()) != num_dim_spatial_ || + ck::type_convert(input_left_pads_.size()) != num_dim_spatial_ || + ck::type_convert(input_right_pads_.size()) != num_dim_spatial_) + { + throw( + std::runtime_error("ConvParams::GetOutputSpatialLengths: " + "parameter size is different from number of declared dimensions!")); + } +} + +std::vector ConvParams::GetOutputSpatialLengths() const +{ + if(ck::type_convert(filter_spatial_lengths_.size()) != num_dim_spatial_ || + ck::type_convert(input_spatial_lengths_.size()) != num_dim_spatial_ || + ck::type_convert(conv_filter_strides_.size()) != num_dim_spatial_ || + ck::type_convert(conv_filter_dilations_.size()) != num_dim_spatial_ || + ck::type_convert(input_left_pads_.size()) != num_dim_spatial_ || + ck::type_convert(input_right_pads_.size()) != num_dim_spatial_) + { + throw( + std::runtime_error("ConvParams::GetOutputSpatialLengths: " + "parameter size is different from number of declared dimensions!")); + } + + std::vector out_spatial_len(num_dim_spatial_, 0); + for(ck::index_t i = 0; i < num_dim_spatial_; ++i) + { + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::index_t idx_eff = + (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; + out_spatial_len[i] = + (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - idx_eff) / + conv_filter_strides_[i] + + 1; + } + return out_spatial_len; +} + +ConvParams parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[]) +{ + ck::utils::conv::ConvParams params; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +HostTensorDescriptor get_output_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::NDHWK{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::NHWK{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::NWK{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::KZYXC{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::KYXC{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::KXC{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::NDHWC{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::NHWC{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::NWC{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +} // namespace conv +} // namespace utils +} // namespace ck + +std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParams& p) +{ + os << "ConvParams {" + << "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nN: " << p.N_ << "\nK: " << p.K_ + << "\nC: " << p.C_ << "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_ + << "\ninput_spatial_lengths: " << p.input_spatial_lengths_ + << "\nconv_filter_strides: " << p.conv_filter_strides_ + << "\nconv_filter_dilations: " << p.conv_filter_dilations_ + << "\ninput_left_pads: " << p.input_left_pads_ + << "\ninput_right_pads: " << p.input_right_pads_; + return os; +} diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt new file mode 100644 index 00000000000..ee0050d2005 --- /dev/null +++ b/profiler/CMakeLists.txt @@ -0,0 +1,64 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/include/ck + ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility + ${PROJECT_SOURCE_DIR}/include/ck/tensor_description + ${PROJECT_SOURCE_DIR}/include/ck/tensor + ${PROJECT_SOURCE_DIR}/include/ck/problem_transform + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/device + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/grid + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/block + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/warp + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/thread + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element + ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor + ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance + ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance/gpu/reduce + ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu + ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu + ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility + ${PROJECT_SOURCE_DIR}/profiler/include + ${PROJECT_SOURCE_DIR}/external/include/half +) + +# ck_profiler +set(PROFILER_SOURCE + src/profiler.cpp + src/profile_gemm.cpp + src/profile_gemm_bias_2d.cpp + src/profile_gemm_bias_relu.cpp + src/profile_gemm_bias_relu_add.cpp + src/profile_gemm_reduce.cpp + src/profile_batched_gemm.cpp + src/profile_conv_fwd_bias_relu.cpp + src/profile_conv_fwd_bias_relu_add.cpp + src/profile_conv_fwd_bias_relu_atomic_add.cpp + src/profile_convnd_fwd.cpp + src/profile_convnd_bwd_data.cpp + src/profile_reduce.cpp + src/profile_grouped_gemm.cpp + src/profile_conv_bwd_weight.cpp + src/profile_batched_gemm_reduce.cpp +) + +add_executable(ckProfiler ${PROFILER_SOURCE}) + +target_link_libraries(ckProfiler PRIVATE host_tensor) +target_link_libraries(ckProfiler PRIVATE conv_util) +target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) +target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) +target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) +target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance) +target_link_libraries(ckProfiler PRIVATE device_reduce_instance) +target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance) +target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance) diff --git a/profiler/README.md b/profiler/README.md new file mode 100644 index 00000000000..bfd6a3a53be --- /dev/null +++ b/profiler/README.md @@ -0,0 +1,48 @@ +## Profile GEMM kernels +```bash +#arg1: tensor operation (gemm=GEMM) +#arg2: data type (0=fp32, 1=fp16) +#arg3: matrix layout (0=NN, 1=NT, 2=TN, 3=TT) +#arg4: verification (0=no, 1=yes) +#arg5: initialization (0=no init, 1=integer value, 2=decimal value) +#arg6: print matrix value (0=no, 1=yes) +#arg7: run kernel # of times (>1) +#arg8 to 13: M, N, K, StrideA, StrideB, StrideC + +################ op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC +./bin/ckProfiler gemm 1 1 1 1 0 5 3840 4096 4096 4096 4096 4096 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +```bash +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +.... +Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s +``` + +## Profile 2d forward convolution kernels +```bash +#arg1: tensor operation (conv=Convolution) +#arg2: data type (0=fp32, 1=fp16) +#arg3: input tensor layout (0=NCHW, 1=NHWC) +#arg4: weight tensor layout (0=KCYX, 1=KYXC) +#arg5: output tensor layout (0=NKHW, 1=NHWK) +#arg6: verification (0=no, 1=yes) +#arg7: initialization (0=no init, 1=integer value, 2=decimal value) +#arg8: print matrix value (0=no, 1=yes) +#arg9: run kernel # of times (>1) +#arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx + ################ op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads + ./bin/ckProfiler conv2d_fwd 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +.... +Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s +``` diff --git a/profiler/include/profile_batched_gemm_impl.hpp b/profiler/include/profile_batched_gemm_impl.hpp new file mode 100644 index 00000000000..3393110c33e --- /dev/null +++ b/profiler/include/profile_batched_gemm_impl.hpp @@ -0,0 +1,428 @@ +#pragma once + +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "element_wise_operation.hpp" +#include "tensor_layout.hpp" +#include "device.hpp" +#include "host_tensor_generator.hpp" +#include "device_gemm.hpp" +#include "reference_batched_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(std::vector&); +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(std::vector&); +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(std::vector&); +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(std::vector&); +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(std::vector&); +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(std::vector&); +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(std::vector&); +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(std::vector&); +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( + std::vector&); + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +bool profile_batched_gemm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int BatchCount) +{ + bool pass = true; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({row * stride, stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({col * stride, 1, stride})); + } + }; + + Tensor a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); + Tensor b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + std::unique_ptr> c_f32_g_m_n_host_result = nullptr; + std::unique_ptr> c_f32_g_m_n_device_result = nullptr; + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + // set zero to c_device_buf + c_g_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + if(do_verification) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + Tensor a_f32_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); + Tensor b_f32_g_k_n( + f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); + c_f32_g_m_n_host_result = std::make_unique>( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + c_f32_g_m_n_device_result = std::make_unique>( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + + bf16_to_f32_(a_g_m_k, a_f32_g_m_k); + bf16_to_f32_(b_g_k_n, b_f32_g_k_n); + + using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: + ReferenceBatchedGemm; + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = ref_batched_gemm.MakeArgument(a_f32_g_m_k, + b_f32_g_k_n, + *c_f32_g_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + } + else + { + + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + c_device_buf.ToDevice(c_g_m_n_device_result.mData.data()); + + // add device GEMM instances + std::vector + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + BatchCount); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * BatchCount * M * N * K; + + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + + bf16_to_f32_(c_g_m_n_device_result, *c_f32_g_m_n_device_result); + float err = check_error(*c_f32_g_m_n_host_result, *c_f32_g_m_n_device_result); + pass = pass && (err < 1E-6); + } + else + { + float err = check_error(c_g_m_n_host_result, c_g_m_n_device_result); + pass = pass && (err < 1E-6); + } + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "this device GEMM instance does not support this GEMM problem" + << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_batched_gemm_reduce_impl.hpp b/profiler/include/profile_batched_gemm_reduce_impl.hpp new file mode 100644 index 00000000000..56ca2cbebe4 --- /dev/null +++ b/profiler/include/profile_batched_gemm_reduce_impl.hpp @@ -0,0 +1,359 @@ +#pragma once + +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_gemm_reduce.hpp" +#include "reference_batched_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; +using F16 = ck::half_t; +using DPtrsGlobal = ck::Tuple; +using Identity = ck::tensor_operation::element_wise::UnaryIdentic; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr< + DPtrsGlobal, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + DInElementOps, + DOutElementOps>; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +bool profile_batched_gemm_reduce_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int BatchCount) +{ + bool pass = true; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({row * stride, stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({col * stride, 1, stride})); + } + }; + + Tensor a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); + Tensor b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); + + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_host_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + Tensor d1_g_m_host_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_device_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + Tensor d1_g_m_device_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + std::cout << "d0_g_m: " << d0_g_m_host_result.mDesc << std::endl; + std::cout << "d1_g_m: " << d1_g_m_host_result.mDesc << std::endl; + + std::size_t num_thread = std::thread::hardware_concurrency(); + switch(init_method) + { + case 0: break; + case 1: + std::srand(0); + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + std::srand(0); + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using D0ReduceOp = ck::reduce::Add; + using D1ReduceOp = ck::reduce::Add; + using UnaryIdenticElementOp = + ck::tensor_operation::element_wise::UnaryIdentic; + using UnarySquareElementOp = + ck::tensor_operation::element_wise::UnarySquare; + using DxsInElementOps = ck::Tuple; + using DxsOutElementOps = ck::Tuple; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + const auto dxs_in_element_op = DxsInElementOps{}; + const auto dxs_out_element_op = DxsOutElementOps{}; + const auto d0_reduce_op = D0ReduceOp{}; + const auto d1_reduce_op = D1ReduceOp{}; + + if(do_verification) + { + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + for(int batch = 0; batch < BatchCount; ++batch) + { + for(int m = 0; m < M; ++m) + { + float d0_acc = d0_reduce_op.GetReductionZeroVal(); + float d1_acc = d1_reduce_op.GetReductionZeroVal(); + + for(int n = 0; n < N; ++n) + { + float d0_val = ck::type_convert(c_g_m_n_host_result(batch, m, n)); + float d1_val; + + UnarySquareElementOp{}(d1_val, d0_val); + d0_reduce_op(d0_acc, d0_val); + d1_reduce_op(d1_acc, d1_val); + } + + d0_g_m_host_result(batch, m) = ck::type_convert(d0_acc); + d1_g_m_host_result(batch, m) = ck::type_convert(d1_acc); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace()); + DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace()); + + auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer())); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + + // add device GEMM instances + std::vector + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + dxs_global, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + dxs_in_element_op, + dxs_out_element_op, + BatchCount); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // init DO, D1 to 0 + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::size_t flop = std::size_t(2) * BatchCount * M * N * K; + std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K + + sizeof(BDataType) * BatchCount * K * N + + sizeof(CDataType) * BatchCount * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + d0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); + d1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); + + float c_error = check_error(c_g_m_n_host_result, c_g_m_n_device_result); + float d0_error = check_error(d0_g_m_host_result, d0_g_m_device_result); + float d1_error = check_error(d1_g_m_host_result, d1_g_m_device_result); + + pass = pass && (c_error < 1E-6); + pass = pass && (d0_error < 1E-6); + pass = pass && (d1_error < 1E-6); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d0_host: ", d0_g_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_device: ", d0_g_m_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d1_host: ", d1_g_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_device: ", d1_g_m_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_conv_bwd_weight_impl.hpp b/profiler/include/profile_conv_bwd_weight_impl.hpp new file mode 100644 index 00000000000..8e3a4074b08 --- /dev/null +++ b/profiler/include/profile_conv_bwd_weight_impl.hpp @@ -0,0 +1,281 @@ +#pragma once + +#include "stream_config.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "device_conv_backward_weight.hpp" +#include "element_wise_operation.hpp" +#include "reference_conv_backward_weight.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_weight_instance { + +using DeviceConvBwdWeightNoOpPtr = + DeviceConvBwdWeightPtr; + +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector&); + +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector&); + +} // namespace device_conv2d_bwd_weight_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +bool profile_conv_bwd_weight_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t split_k) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x_host_result(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor wei_k_c_y_x_device_result( + f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); + } + + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(do_verification) + { + using ReferenceConvBwdWeightInstance = + ck::tensor_operation::host::ReferenceConvBwdWeight; + + auto ref_conv = ReferenceConvBwdWeightInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x_host_result, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * + wei_k_c_y_x_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using DeviceConvBwdWeightNoOpPtr = + ck::tensor_operation::device::DeviceConvBwdWeightPtr; + + // add device Conv instances + std::vector conv_ptrs; + + if constexpr(ck::is_same_v, float> && + ck::is_same_v, float> && + ck::is_same_v, float>) + { + ck::tensor_operation::device::device_conv2d_bwd_weight_instance:: + add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::device_conv2d_bwd_weight_instance:: + add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + } + + if(conv_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + bool pass = true; + + for(auto& conv_ptr : conv_ptrs) + { + // using atomic, so need to reset input + if(split_k > 1) + { + wei_device_buf.SetZero(); + } + + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + + if(conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = conv_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); + + float max_error = check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result); + + if(max_error > 8) + { + pass = false; + std::cout << "Fail info:" << conv_ptr->GetTypeString() << std::endl; + } + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out_n_k_ho_wo.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "wei_device: ", wei_k_c_y_x_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp new file mode 100644 index 00000000000..5ea35cd72f1 --- /dev/null +++ b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp @@ -0,0 +1,276 @@ +#pragma once + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_conv_fwd_bias_activation_add.hpp" +#include "reference_conv_fwd_bias_activation_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_bias_activation_add_instance { + +using DeviceConvFwdBiasReluAddPtr = + DeviceConvFwdBiasActivationAddPtr; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances( + std::vector&); + +} // namespace device_conv2d_fwd_bias_activation_add_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_conv_fwd_bias_relu_add_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + // bias: assume contiguous 1d vector + Tensor bias_k( + HostTensorDescriptor(std::vector({static_cast(K)}))); + + // residual: assume same layout as output tensor + Tensor resi_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + std::cout << "bias_k: " << bias_k.mDesc << std::endl; + std::cout << "resi_n_k_ho_wo: " << resi_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(do_verification) + { + using ReferenceConvFwdInstance = + ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation_Add; + + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + resi_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); + DeviceMem resi_device_buf(sizeof(OutDataType) * resi_n_k_ho_wo.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + bias_device_buf.ToDevice(bias_k.mData.data()); + resi_device_buf.ToDevice(resi_n_k_ho_wo.mData.data()); + + using DeviceConvFwdBiasReluAddPtr = ck::tensor_operation::device:: + DeviceConvFwdBiasActivationAddPtr; + + // add device operator instances + std::vector op_ptrs; + + if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::device_conv2d_fwd_bias_activation_add_instance:: + add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + } + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + static_cast(resi_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + ck::utils::check_err(out_n_k_ho_wo_device_result.mData, + out_n_k_ho_wo_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp new file mode 100644 index 00000000000..f1c2fd300ac --- /dev/null +++ b/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp @@ -0,0 +1,331 @@ +#pragma once +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "device_conv_fwd_bias_activation.hpp" +#include "element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_bias_activation_atomic_add_instance { + +using DeviceConvFwdBiasReluPtr = + DeviceConvFwdBiasActivationPtr; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances( + std::vector&); + +} // namespace device_conv2d_fwd_bias_activation_atomic_add_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +void cpu_conv_bias_relu_atomic_add(ck::half_t* in_ptr, + ck::half_t* weight_ptr, + ck::half_t* output_ptr, + ck::half_t* bias_ptr, + const ck::index_t N, + const ck::index_t K, + const ck::index_t C, + const ck::index_t Y, + const ck::index_t X, + const ck::index_t Hi, + const ck::index_t Wi, + const ck::index_t Ho, + const ck::index_t Wo, + const ck::index_t Stride, + const ck::index_t Dilation, + const ck::index_t Pad) +{ + + const auto in_desc = + HostTensorDescriptor(std::vector{static_cast(N), + static_cast(Hi), + static_cast(Wi), + static_cast(C)}); + const auto wei_desc = + HostTensorDescriptor(std::vector{static_cast(K), + static_cast(Y), + static_cast(X), + static_cast(C)}); + const auto out_desc = + HostTensorDescriptor(std::vector{static_cast(N), + static_cast(Ho), + static_cast(Wo), + static_cast(K)}); + const auto bias_desc = + HostTensorDescriptor(std::vector{static_cast(K)}); + + auto f_k = [&](auto k) { + for(int n = 0; n < N; ++n) + { + for(int ho = 0; ho < Ho; ++ho) + { + for(int wo = 0; wo < Wo; ++wo) + { + double v = 0; + for(int c = 0; c < C; ++c) + { + for(int y = 0; y < Y; ++y) + { + int hi = ho * Stride + y * Dilation - Pad; + for(int x = 0; x < X; ++x) + { + int wi = wo * Stride + x * Dilation - Pad; + if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi) + { + double in = + in_ptr[in_desc.GetOffsetFromMultiIndex(n, hi, wi, c)]; + double wei = + weight_ptr[wei_desc.GetOffsetFromMultiIndex(k, y, x, c)]; + + v += in * wei; + } + } + } + } + + v += bias_ptr[bias_desc.GetOffsetFromMultiIndex(k)]; + + v = v > 0 ? v : 0; + + output_ptr[out_desc.GetOffsetFromMultiIndex(n, ho, wo, k)] = v; + } + } + } + }; + + make_ParallelTensorFunctor(f_k, K)(std::thread::hardware_concurrency()); +} + +template +void profile_conv_fwd_bias_relu_atomic_add_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + // bias: assume contiguous 1d vector + Tensor bias_k( + HostTensorDescriptor(std::vector({static_cast(K)}))); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + std::cout << "bias_k: " << bias_k.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + + if(do_verification) + { + cpu_conv_bias_relu_atomic_add(in_n_c_hi_wi.mData.data(), + wei_k_c_y_x.mData.data(), + out_n_k_ho_wo_host_result.mData.data(), + bias_k.mData.data(), + N, + K, + C, + Y, + X, + Hi, + Wi, + Ho, + Wo, + conv_filter_strides[0], + conv_filter_dilations[0], + input_left_pads[0]); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + bias_device_buf.ToDevice(bias_k.mData.data()); + + using DeviceConvFwdBiasReluPtr = ck::tensor_operation::device:: + DeviceConvFwdBiasActivationPtr; + + // add device operator instances + std::vector op_ptrs; + + if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::device_conv2d_fwd_bias_activation_atomic_add_instance:: + add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances( + op_ptrs); + } + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + ck::utils::check_err(out_n_k_ho_wo_device_result.mData, + out_n_k_ho_wo_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_conv_fwd_bias_relu_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp new file mode 100644 index 00000000000..eeb2b93e4ee --- /dev/null +++ b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp @@ -0,0 +1,263 @@ +#pragma once +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_conv_fwd_bias_activation.hpp" +#include "reference_conv_fwd_bias_activation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_bias_activation_instance { + +using DeviceConvFwdBiasReluPtr = + DeviceConvFwdBiasActivationPtr; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( + std::vector&); + +} // namespace device_conv2d_fwd_bias_activation_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_conv_fwd_bias_relu_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + // bias: assume contiguous 1d vector + Tensor bias_k( + HostTensorDescriptor(std::vector({static_cast(K)}))); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + std::cout << "bias_k: " << bias_k.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(do_verification) + { + using ReferenceConvFwdInstance = + ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation; + + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + ref_invoker.Run(ref_argument); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + bias_device_buf.ToDevice(bias_k.mData.data()); + + using DeviceConvFwdBiasReluPtr = ck::tensor_operation::device:: + DeviceConvFwdBiasActivationPtr; + + // add device operator instances + std::vector op_ptrs; + + if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance:: + add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + } + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + ck::utils::check_err(out_n_k_ho_wo_device_result.mData, + out_n_k_ho_wo_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_convnd_bwd_data_impl.hpp b/profiler/include/profile_convnd_bwd_data_impl.hpp new file mode 100644 index 00000000000..291bf2abc08 --- /dev/null +++ b/profiler/include/profile_convnd_bwd_data_impl.hpp @@ -0,0 +1,481 @@ +#pragma once +#include "config.hpp" +#include "device.hpp" +#include "conv_util.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "device_conv_bwd_data.hpp" +#include "element_wise_operation.hpp" +#include "reference_conv_bwd_data.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ck::bhalf_t; +using INT8 = int8_t; +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using DeviceConvBwdDataNoOpPtr = + DeviceConvBwdDataPtr; +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances( + std::vector&); +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances( + std::vector&); +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances( + std::vector&); +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( + std::vector&); + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector&); + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector&); +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector&); +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector&); +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( + std::vector&); +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { +using DeviceConvBwdDataNoOpPtr = + ck::tensor_operation::device::device_conv2d_bwd_data_instance::DeviceConvBwdDataNoOpPtr; + +template +HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} +template +HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} +template +HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} +template +void get_device_conv_bwd_data_op_ptr( + InDataType, WeiDataType, OutDataType, std::vector&, int) +{ + std::cout << "can not find device conv bwd data" << std::endl; + exit(1); +} +template <> +void get_device_conv_bwd_data_op_ptr( + F32, F32, F32, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); + break; + default: break; + } +} +template <> +void get_device_conv_bwd_data_op_ptr( + F16, F16, F16, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); + break; + default: break; + } +} +template <> +void get_device_conv_bwd_data_op_ptr( + BF16, BF16, BF16, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); + break; + default: break; + } +} +template <> +void get_device_conv_bwd_data_op_ptr( + INT8, INT8, INT8, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); + break; + default: break; + } +} + +template +static bool check_out(const Tensor& ref, const Tensor& result) +{ + float max_diff = 1e-6; + + for(std::size_t i = 0; i < ref.mData.size(); ++i) + { + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + return false; + } + } + return true; +} +template +void show_data_nhwc_layout(Tensor& nhwc) +{ + std::cout << "["; + for(int n = 0; n < ck::type_convert(nhwc.mDesc.GetLengths()[0]); n++) + { + std::cout << "["; + for(int hi = 0; hi < ck::type_convert(nhwc.mDesc.GetLengths()[2]); hi++) + { + std::cout << "["; + for(int wi = 0; wi < ck::type_convert(nhwc.mDesc.GetLengths()[3]); wi++) + { + std::cout << "["; + for(int c = 0; c < ck::type_convert(nhwc.mDesc.GetLengths()[1]); c++) + { + std::cout << static_cast(nhwc(n, c, hi, wi)) << " "; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; +} + +template +bool profile_convnd_bwd_data_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + std::vector input_dims{static_cast(N), static_cast(C)}; + input_dims.insert( + std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths)); + + std::vector filter_dims{static_cast(K), static_cast(C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths)); + + std::vector output_dims{static_cast(N), static_cast(K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input_host_result( + get_input_host_tensor_descriptor(input_dims, NDimSpatial)); + Tensor input_device_result( + get_input_host_tensor_descriptor(input_dims, NDimSpatial)); + Tensor weights( + get_filters_host_tensor_descriptor(filter_dims, NDimSpatial)); + Tensor output( + get_output_host_ensor_descriptor(output_dims, NDimSpatial)); + + std::cout << "input: " << input_host_result.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + output.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + output.GenerateTensorValue(GeneratorTensor_1{1}); + weights.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(output.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + + // reset input to zero + in_device_buf.SetZero(); + + if(do_verification) + { + auto RunReference = [&](auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(input_host_result, + weights, + output, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); + }; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData(); + RunReference(ref_conv); + } + + // add device Conv instances + std::vector conv_ptrs; + get_device_conv_bwd_data_op_ptr( + InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial); + + if(conv_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + bool success = true; + for(auto& conv_ptr : conv_ptrs) + { + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + + if(conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = conv_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = + ck::utils::conv::get_btype( + N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + in_device_buf.FromDevice(input_device_result.mData.data()); + + if(!check_out(input_host_result, input_device_result)) + { + std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; + + success = false; + } + else + { + std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; + } + + check_error(input_host_result, input_device_result); + + if(do_log) + { + std::cout << "in : "; + show_data_nhwc_layout(output); + std::cout << std::endl; + + std::cout << "wei: "; + show_data_nhwc_layout(weights); + std::cout << std::endl; + + std::cout << "out_host : "; + show_data_nhwc_layout(input_host_result); + std::cout << std::endl; + + std::cout << "out_device: "; + show_data_nhwc_layout(input_device_result); + std::cout << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; + return success; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_convnd_fwd.hpp b/profiler/include/profile_convnd_fwd.hpp new file mode 100644 index 00000000000..a3b55a79d1f --- /dev/null +++ b/profiler/include/profile_convnd_fwd.hpp @@ -0,0 +1,9 @@ +#pragma once + +namespace ck { +namespace profiler { + +int profile_convnd_fwd(int argc, char* argv[]); + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_bias_2d_impl.hpp b/profiler/include/profile_gemm_bias_2d_impl.hpp new file mode 100644 index 00000000000..8565f9637c3 --- /dev/null +++ b/profiler/include/profile_gemm_bias_2d_impl.hpp @@ -0,0 +1,314 @@ +#pragma once + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm_bias.hpp" +#include "reference_gemm_bias_2d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmAlphaBetaPtr = ck::tensor_operation::device::DeviceGemmBiasPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AlphaBetaAdd>; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances( + std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_gemm_bias_2d_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + float alpha, + float beta) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c0_m_n: " << c0_m_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + c0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + c0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + // set zero to c_device_buf + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::AlphaBetaAdd; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{alpha, beta}; + + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c0_m_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c0_device_buf(sizeof(C0DataType) * c0_m_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c0_device_buf.ToDevice(c0_m_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + // add device GEMM instances + std::vector + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c0_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c0 : ", c0_m_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_bias_relu_add_impl.hpp b/profiler/include/profile_gemm_bias_relu_add_impl.hpp new file mode 100644 index 00000000000..6fec17c1993 --- /dev/null +++ b/profiler/include/profile_gemm_bias_relu_add_impl.hpp @@ -0,0 +1,289 @@ +#pragma once + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm_bias_activation_add.hpp" +#include "reference_gemm_bias_activation_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmBiasReluAddPtr = ck::tensor_operation::device::DeviceGemmBiasActivationAddPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddReluAdd>; + +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( + std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_gemm_bias_relu_add_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int StrideC1, + int KBatch = 1) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + // c0_n[n] + Tensor c0_n(HostTensorDescriptor( + std::vector({static_cast(N)}), std::vector({1}))); + + // c1_m_n[m ,n] + Tensor c1_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "c0_n: " << c0_n.mDesc << std::endl; + std::cout << "c1_m_n: " << c1_m_n.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + c1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + // set zero to c_device_buf + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}); + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::AddReluAdd; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + if(do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmBiasActivationAdd; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + b_k_n, + c_m_n_host_result, + c0_n, + c1_m_n, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); + DeviceMem c1_m_n_device_buf(sizeof(CDataType) * c1_m_n.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + c0_n_device_buf.ToDevice(c0_n.mData.data()); + c1_m_n_device_buf.ToDevice(c1_m_n.mData.data()); + + // add device GEMM instances + std::vector + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(c0_n_device_buf.GetDeviceBuffer()), + static_cast(c1_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideC1, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + + sizeof(CDataType) * M * N + sizeof(CDataType) * N + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a: ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c0: ", c0_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c1: ", c1_m_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_bias_relu_impl.hpp b/profiler/include/profile_gemm_bias_relu_impl.hpp new file mode 100644 index 00000000000..69010becc5b --- /dev/null +++ b/profiler/include/profile_gemm_bias_relu_impl.hpp @@ -0,0 +1,267 @@ +#pragma once + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm_bias_activation.hpp" +#include "reference_gemm_bias_activation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmBiasReluPtr = ck::tensor_operation::device::DeviceGemmBiasActivationPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddRelu>; + +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances( + std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_gemm_bias_relu_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int KBatch = 1) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + // c0_n[n] + Tensor c0_n(HostTensorDescriptor( + std::vector({static_cast(N)}), std::vector({1}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "c0_n: " << c0_n.mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + // set zero to c_device_buf + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::AddRelu; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + if(do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmBiasActivation; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, c0_n, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + c0_n_device_buf.ToDevice(c0_n.mData.data()); + + // add device GEMM instances + std::vector + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(c0_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + + sizeof(CDataType) * M * N + sizeof(CDataType) * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c0 : ", c0_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp new file mode 100644 index 00000000000..ff6f8ad6f7d --- /dev/null +++ b/profiler/include/profile_gemm_impl.hpp @@ -0,0 +1,626 @@ +#pragma once +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm.hpp" +#include "reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector&); + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector&); + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector&); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_gemm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int KBatch) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + // case 0: break; + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + // set zero to c_device_buf + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + // add device GEMM instances + std::vector gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + } + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + } + } + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + KBatch); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init C to zero before profiling next kernel + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + Tensor a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_f32_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + bf16_to_f32_(a_m_k, a_f32_m_k); + bf16_to_f32_(b_k_n, b_f32_k_n); + bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result); + + using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_f32_m_k, + b_f32_k_n, + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData); + + if(do_log) + { + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + } + } + else + { + Tensor c_m_n_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + + if(do_log) + { + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + } + } + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << gemm_ptr->GetTypeString() << " does not support this GEMM problem" + << std::endl; + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_gemm_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_reduce_impl.hpp b/profiler/include/profile_gemm_reduce_impl.hpp new file mode 100644 index 00000000000..97d0f2523b3 --- /dev/null +++ b/profiler/include/profile_gemm_reduce_impl.hpp @@ -0,0 +1,340 @@ +#pragma once +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_gemm_reduce.hpp" +#include "reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; +using F16 = ck::half_t; +using DPtrsGlobal = ck::Tuple; +using Identity = ck::tensor_operation::element_wise::UnaryIdentic; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr< + DPtrsGlobal, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + DInElementOps, + DOutElementOps>; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +bool profile_gemm_reduce_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d0_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d1_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d0_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d1_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl; + std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using D0ReduceOp = ck::reduce::Add; + using D1ReduceOp = ck::reduce::Add; + using UnaryIdenticElementOp = + ck::tensor_operation::element_wise::UnaryIdentic; + using UnarySquareElementOp = + ck::tensor_operation::element_wise::UnarySquare; + using DxsInElementOps = ck::Tuple; + using DxsOutElementOps = ck::Tuple; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + const auto dxs_in_element_op = DxsInElementOps{}; + const auto dxs_out_element_op = DxsOutElementOps{}; + const auto d0_reduce_op = D0ReduceOp{}; + const auto d1_reduce_op = D1ReduceOp{}; + + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + float d0_acc = d0_reduce_op.GetReductionZeroVal(); + float d1_acc = d1_reduce_op.GetReductionZeroVal(); + + for(int n = 0; n < N; ++n) + { + float d0_val = ck::type_convert(c_m_n_host_result(m, n)); + float d1_val; + + UnarySquareElementOp{}(d1_val, d0_val); + d0_reduce_op(d0_acc, d0_val); + d1_reduce_op(d1_acc, d1_val); + } + + d0_m_host_result(m) = ck::type_convert(d0_acc); + d1_m_host_result(m) = ck::type_convert(d1_acc); + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); + DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); + + auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer())); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + // add device GEMM instances + std::vector + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + dxs_global, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + dxs_in_element_op, + dxs_out_element_op); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // init DO, D1 to 0 + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(CDataType) * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + d0_device_buf.FromDevice(d0_m_device_result.mData.data()); + d1_device_buf.FromDevice(d1_m_device_result.mData.data()); + + float c_error = check_error(c_m_n_host_result, c_m_n_device_result); + float d0_error = check_error(d0_m_host_result, d0_m_device_result); + float d1_error = check_error(d1_m_host_result, d1_m_device_result); + + pass = pass && (c_error < 1E-6); + pass = pass && (d0_error < 1E-6); + pass = pass && (d1_error < 1E-6); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d0_host: ", d0_m_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d0_device: ", d0_m_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d1_host: ", d1_m_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d1_device: ", d1_m_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_grouped_gemm_impl.hpp b/profiler/include/profile_grouped_gemm_impl.hpp new file mode 100644 index 00000000000..96d34c7e429 --- /dev/null +++ b/profiler/include/profile_grouped_gemm_impl.hpp @@ -0,0 +1,317 @@ +#pragma once +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm.hpp" +#include "reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_grouped_gemm_instance { + +using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector&); +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector&); +void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector&); + +} // namespace device_grouped_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_grouped_gemm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + std::size_t group_count = Ms.size(); + + if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && + group_count == StrideBs.size() && group_count == StrideCs.size())) + { + throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n"); + } + + std::vector> a_m_k; + std::vector> b_k_n; + std::vector> c_m_n_device_results; + + for(std::size_t i = 0; i < group_count; i++) + { + a_m_k.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); + b_k_n.push_back( + Tensor(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); + + c_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); + + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i + << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + a_m_k[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a_m_k[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0{}, num_thread); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + // if(do_verification) + // { + + // } + + using DeviceMemPtr = std::unique_ptr; + std::vector a_device_buf, b_device_buf, c_device_buf; + + a_device_buf.reserve(group_count); + b_device_buf.reserve(group_count); + c_device_buf.reserve(group_count); + + std::vector p_a, p_b; + std::vector p_c; + + p_a.reserve(group_count); + p_b.reserve(group_count); + p_c.reserve(group_count); + + std::vector gemm_shapes; + + gemm_shapes.reserve(group_count); + + for(std::size_t i = 0; i < group_count; i++) + { + a_device_buf.emplace_back( + std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace())); + b_device_buf.emplace_back( + std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace())); + + c_device_buf.emplace_back(std::make_unique( + sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace())); + + a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); + b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); + c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data()); + + gemm_shapes.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i]}); + + p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); + p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); + p_c.push_back(c_device_buf[i]->GetDeviceBuffer()); + } + + // add device GEMM instances + std::vector< + ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr> + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_grouped_gemm_instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_grouped_gemm_instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_grouped_gemm_instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_grouped_gemm_instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(p_a, + p_b, + p_c, + gemm_shapes, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + + c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + + Tensor c_m_n_host_result( + f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], + b_k_n[i], + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + ck::utils::check_err(c_m_n_device_results[i].mData, c_m_n_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + } + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} // namespace profiler + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_reduce_impl.hpp b/profiler/include/profile_reduce_impl.hpp new file mode 100644 index 00000000000..a87694754e4 --- /dev/null +++ b/profiler/include/profile_reduce_impl.hpp @@ -0,0 +1,492 @@ +#pragma once + +#include "check_err.hpp" +#include "device_reduce.hpp" +#include "device_reduce_instance.hpp" +#include "reduction_enums.hpp" +#include "host_reduction.hpp" +#include "host_common_util.hpp" +#include "host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +template +struct ReduceDescription +{ + static constexpr int Rank_ = Rank; + static constexpr int NumReduceDim_ = NumReduceDim; + static constexpr int ReduceOpId_ = ReduceOpId; + static constexpr int PropagateNan_ = PropagateNan; + static constexpr int UseIndex_ = UseIndex; +}; + +using reduce_description_instances = + std::tuple, // for ADD + ReduceDescription<4, 4, 0, false, false>, + ReduceDescription<4, 1, 0, false, false>, + ReduceDescription<2, 1, 0, false, false>, + + ReduceDescription<4, 3, 5, false, false>, // for AVG + ReduceDescription<4, 4, 5, false, false>, + ReduceDescription<4, 1, 5, false, false>, + ReduceDescription<2, 1, 5, false, false>, + + ReduceDescription<4, 3, 7, false, false>, // for NORM2 + ReduceDescription<4, 4, 7, false, false>, + ReduceDescription<4, 1, 7, false, false>, + ReduceDescription<2, 1, 7, false, false>, + + ReduceDescription<4, 3, 2, false, false>, // for MIN + ReduceDescription<4, 4, 2, false, false>, + ReduceDescription<4, 1, 2, false, false>, + ReduceDescription<2, 1, 2, false, false>, + ReduceDescription<4, 3, 3, false, false>, // for MAX + ReduceDescription<4, 4, 3, false, false>, + ReduceDescription<4, 1, 3, false, false>, + ReduceDescription<2, 1, 3, false, false>, + ReduceDescription<4, 3, 4, false, false>, // for AMAX + ReduceDescription<4, 4, 4, false, false>, + ReduceDescription<4, 1, 4, false, false>, + ReduceDescription<2, 1, 4, false, false>, + + ReduceDescription<4, 3, 2, false, true>, // for MIN + ReduceDescription<4, 4, 2, false, true>, + ReduceDescription<4, 1, 2, false, true>, + ReduceDescription<2, 1, 2, false, true>, + ReduceDescription<4, 3, 3, false, true>, // for MAX + ReduceDescription<4, 4, 3, false, true>, + ReduceDescription<4, 1, 3, false, true>, + ReduceDescription<2, 1, 3, false, true>, + ReduceDescription<4, 3, 4, false, true>, // for AMAX + ReduceDescription<4, 4, 4, false, true>, + ReduceDescription<4, 1, 4, false, true>, + ReduceDescription<2, 1, 4, false, true>>; + +template +bool description_match(const DescriptionType& description, + int Rank, + const std::vector& reduceDims, + ReduceTensorOp ReduceOpId, + bool PropagateNan, + bool UseIndex) +{ + if(description.Rank_ != Rank || description.ReduceOpId_ != static_cast(ReduceOpId) || + description.PropagateNan_ != static_cast(PropagateNan) || + description.UseIndex_ != static_cast(UseIndex)) + return (false); + + if(DescriptionType::NumReduceDim_ != reduceDims.size()) + return (false); + + bool result = true; + + return (result); +}; + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +static inline std::vector get_invariant_dims(const std::vector& reduceDims) +{ + assert(NumReduceDim == reduceDims.size()); + + int reduceFlag = 0; + + // flag the bits for the reduceDims + for(int i = 0; i < NumReduceDim; i++) + { + reduceFlag |= 1 << reduceDims[i]; + }; + + std::vector invariantDims; + + // collect invariant dimensions + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) == 0) + { + invariantDims.push_back(i); + }; + + return invariantDims; +}; + +template +bool profile_reduce_impl_impl(bool do_verification, + int init_method, + bool do_dumpout, + bool time_kernel, + const std::vector& inLengths, + const std::vector& reduceDims, + float alpha, + float beta) +{ + using namespace ck::tensor_operation::device; + using namespace ck::tensor_operation::device::device_reduce_instance; + using namespace ck::host_reduce; + using ck::host_common::dumpBufferToFile; + + constexpr bool op_support_indices = + (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || + ReduceOpId == ReduceTensorOp::AMAX); + + constexpr bool OutputIndex = (op_support_indices && UseIndex); + + constexpr bool out_support_atomic_add = std::is_same::value; + constexpr bool op_support_atomic_add = + !op_support_indices && ReduceOpId != ReduceTensorOp::NORM2; + constexpr bool use_atomic_add = (out_support_atomic_add && op_support_atomic_add); + + // 1) If InDataType is half_t, must use half_t as AccDataType for indexable reduction operations + // 2) If InDataType is half_t, must use float as AccDataType for non-indexable reduction + // operations + constexpr bool invalid_reduce_1 = + std::is_same::value && + ((!op_support_indices && !std::is_same::value) || + (op_support_indices && !std::is_same::value)); + + // 1) If InDataType is float, must use float as AccDataType for indexable reduction operations + constexpr bool invalid_reduce_2 = + std::is_same::value && + (op_support_indices && !std::is_same::value); + + // 1) The indices can only be used when the reduction operation is indexable + constexpr bool invalid_reduce_3 = (!op_support_indices && UseIndex); + + // 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations + // 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction + // operations + constexpr bool invalid_reduce_4 = + std::is_same::value && + ((!op_support_indices && !std::is_same::value) || + (op_support_indices && !std::is_same::value)); + + // 1) If InDataType is int8_t, the supported operation must be either indexable operations or + // ADD/AVG + constexpr bool invalid_reduce_5 = std::is_same::value && + (!op_support_indices && ReduceOpId != ReduceTensorOp::ADD && + ReduceOpId != ReduceTensorOp::AVG); + + // 1) If InDataType is bhalf_t, must use float as AccDataType for all reduction operations + constexpr bool invalid_reduce_6 = + std::is_same::value && !std::is_same::value; + + constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 || + invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6); + + bool pass = true; + + if constexpr(!invalid_reduce) + { + Tensor in(inLengths); + + std::vector outLengths; + + const auto invariantDims = get_invariant_dims(reduceDims); + + if(reduceDims.size() == Rank) + outLengths.push_back(1); + else + for(auto dim : invariantDims) + outLengths.push_back(inLengths[dim]); + + Tensor out_ref(outLengths); + Tensor out(outLengths); + Tensor out_indices_ref(outLengths); + Tensor out_indices(outLengths); + + auto inStrides = in.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + size_t invariant_total_length = out.mDesc.GetElementSize(); + size_t reduce_total_length = in.mDesc.GetElementSize() / invariant_total_length; + + std::size_t num_thread = 1; + + if(do_verification) + { + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, + num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) + out.mData[i] = out_ref.mData[i]; + }; + + // these buffers are usually provided by the user application + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); + + in_dev.ToDevice(in.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int) : 0; + + DeviceMem out_indices_dev(indicesSizeInBytes); + + float best_avg_time = 0; + float best_gb_per_sec = 0; + + using InElementwiseOperation_0 = + typename reduce_unary_operator:: + InElementwiseOperation; + using AccElementwiseOperation_0 = + typename reduce_unary_operator:: + AccElementwiseOperation; + + using DeviceReduceInstPtr0 = + DeviceReducePtr; + + std::vector reduce0_ptrs; + + add_device_reduce_instance_threadwise(reduce0_ptrs); + + add_device_reduce_instance_blockwise(reduce0_ptrs); + + if constexpr(use_atomic_add) + { + add_device_reduce_instance_multiblock_atomic_add(reduce0_ptrs); + } + + if(reduce0_ptrs.empty()) + { + throw std::runtime_error("Wrong! No device REDUCE instance found"); + }; + + if(do_verification) + { + ReductionHost + hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); + + hostReduce.Run( + alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data()); + }; + + std::vector i_inLengths; + std::vector i_inStrides; + std::vector i_outLengths; + std::vector i_outStrides; + + i_inLengths.assign(inLengths.begin(), inLengths.end()); + i_inStrides.assign(inStrides.begin(), inStrides.end()); + i_outLengths.assign(outLengths.begin(), outLengths.end()); + i_outStrides.assign(outStrides.begin(), outStrides.end()); + + for(auto& reduce_ptr : reduce0_ptrs) + { + + InElementwiseOperation_0 in_elementwise_op_0(static_cast(reduce_total_length)); + AccElementwiseOperation_0 acc_elementwise_op_0( + static_cast(reduce_total_length)); + + auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths, + i_inStrides, + i_outLengths, + i_outStrides, + reduceDims, + alpha, + beta, + in_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + out_indices_dev.GetDeviceBuffer(), + in_elementwise_op_0, + acc_elementwise_op_0); + + if(!reduce_ptr->IsSupportedArgument(argument_ptr.get())) + continue; + + std::string reduce_name = reduce_ptr->GetTypeString(); + + auto invoker_ptr = reduce_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = + invariant_total_length * reduce_total_length * sizeof(InDataType) + + invariant_total_length * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " + << reduce_name << std::endl; + + if(gb_per_sec > best_gb_per_sec) + { + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + bool single_pass; + + out_dev.FromDevice(out.mData.data()); + single_pass = ck::utils::check_err(out.mData, out_ref.mData); + + if(OutputIndex) + { + out_indices_dev.FromDevice(out_indices.mData.data()); + single_pass = single_pass && + ck::utils::check_err(out_indices.mData, out_indices_ref.mData); + }; + + if(!single_pass) + { + std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << std::endl; + } + + pass = pass && single_pass; + }; + + if(do_dumpout) + { + dumpBufferToFile("dump_in.bin", in.mData.data(), in.mDesc.GetElementSize()); + dumpBufferToFile("dump_out.bin", out.mData.data(), out.mDesc.GetElementSize()); + dumpBufferToFile( + "dump_out_host.bin", out_ref.mData.data(), out_ref.mDesc.GetElementSize()); + if(OutputIndex) + { + dumpBufferToFile("dump_indices.bin", + out_indices.mData.data(), + out_indices.mDesc.GetElementSize()); + dumpBufferToFile("dump_indices_host.bin", + out_indices_ref.mData.data(), + out_indices_ref.mDesc.GetElementSize()); + }; + }; + }; + + if(time_kernel) + std::cout << "Best Perf: " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s" + << std::endl; + } + else + { + std::cout << "The requested reduction operation is not supported, please check !!!" + << std::endl; + }; + + return pass; +}; + +template +bool profile_reduce_impl(bool do_verification, + int init_method, + bool do_dumpout, + bool time_kernel, + const std::vector& inLengths, + const std::vector& reduceDims, + ReduceTensorOp ReduceOpId, + bool PropagateNan, + bool UseIndex, + float alpha, + float beta) +{ + bool matched = false; + bool pass = true; + + using tuple_of_description_instances = + tensor_operation::device::device_reduce_instance::reduce_description_instances; + + const auto tuple_object = tuple_of_description_instances{}; + + static_for<0, std::tuple_size::value, 1>{}([&](auto i) { + if(matched) + return; + + using descType = remove_cvref_t(tuple_object))>; + + if(!description_match( + descType{}, inLengths.size(), reduceDims, ReduceOpId, PropagateNan, UseIndex)) + return; + + pass = pass && + profile_reduce_impl_impl(descType::ReduceOpId_), + static_cast(descType::PropagateNan_), + static_cast(descType::UseIndex_)>(do_verification, + init_method, + do_dumpout, + time_kernel, + inLengths, + reduceDims, + alpha, + beta); + + matched = true; + }); + + return pass; +}; + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/profile_batched_gemm.cpp b/profiler/src/profile_batched_gemm.cpp new file mode 100644 index 00000000000..fbdc07c3da1 --- /dev/null +++ b/profiler/src/profile_batched_gemm.cpp @@ -0,0 +1,400 @@ +#include +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_base.hpp" +#include "device_batched_gemm_xdl.hpp" +#include "profile_batched_gemm_impl.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +int profile_batched_gemm(int argc, char* argv[]) +{ + if(!(argc == 15)) + { + printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n"); + printf("arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n"); + printf("arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];\n"); + printf(" 1: A[g, m, k] * B[g, n, k] = C[g, m, n];\n"); + printf(" 2: A[g, k, m] * B[g, k, n] = C[g, m, n];\n"); + printf(" 3: A[g, k, m] * B[g, n, k] = C[g, m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + const int BatchCount = std::stoi(argv[14]); + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else + { + throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_batched_gemm_reduce.cpp b/profiler/src/profile_batched_gemm_reduce.cpp new file mode 100644 index 00000000000..594fc6bedb6 --- /dev/null +++ b/profiler/src/profile_batched_gemm_reduce.cpp @@ -0,0 +1,153 @@ +#include +#include +#include +#include +#include +#include + +#include "profile_batched_gemm_reduce_impl.hpp" + +int profile_batched_gemm_reduce(int argc, char* argv[]) +{ + enum struct GemmMatrixLayout + { + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + }; + + enum struct GemmReduceDataType + { + F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F32_F32, // 1 + }; + + if(!(argc == 15 || argc == 16)) + { + printf("arg1: tensor operation (batched_gemm: BatchedGEMM+Reduce)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n"); + printf("arg15: split k into mulitiple batch\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + const int BatchCount = std::stoi(argv[14]); + + if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_conv_bwd_weight.cpp b/profiler/src/profile_conv_bwd_weight.cpp new file mode 100644 index 00000000000..80413322b30 --- /dev/null +++ b/profiler/src/profile_conv_bwd_weight.cpp @@ -0,0 +1,146 @@ +#include +#include +#include +#include +#include +#include +#include "profile_conv_bwd_weight_impl.hpp" + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +int profile_conv_bwd_weight(int argc, char* argv[]) +{ + if(argc != 26) + { + printf("arg1: tensor operation (conv_fwd: ForwardConvolution)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: run kernel # of times (>1)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + printf("arg25: split k (>=1)\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + ck::index_t split_k = std::stoi(argv[25]); + split_k = std::max(1, split_k); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_bwd_weight_impl<2, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}, + split_k); + } + else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_bwd_weight_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}, + split_k); + } + else + { + throw std::runtime_error("wrong! this Conv data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_conv_fwd_bias_relu.cpp b/profiler/src/profile_conv_fwd_bias_relu.cpp new file mode 100644 index 00000000000..ca7dc1935ae --- /dev/null +++ b/profiler/src/profile_conv_fwd_bias_relu.cpp @@ -0,0 +1,114 @@ +#include +#include +#include +#include +#include +#include +#include "profile_conv_fwd_bias_relu_impl.hpp" + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +int profile_conv_fwd_bias_relu(int argc, char* argv[]) +{ + if(argc != 25) + { + printf("arg1: tensor operation (conv_fwd_bias_relu: ForwardConvolution+Bias+ReLu)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_fwd_bias_relu_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else + { + throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_conv_fwd_bias_relu_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_add.cpp new file mode 100644 index 00000000000..5d75f5a2943 --- /dev/null +++ b/profiler/src/profile_conv_fwd_bias_relu_add.cpp @@ -0,0 +1,115 @@ +#include +#include +#include +#include +#include +#include +#include "profile_conv_fwd_bias_relu_add_impl.hpp" + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) +{ + if(argc != 25) + { + printf( + "arg1: tensor operation (conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLu+Add)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_fwd_bias_relu_add_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else + { + throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp new file mode 100644 index 00000000000..96d3b10ddfa --- /dev/null +++ b/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp @@ -0,0 +1,116 @@ +#include +#include +#include +#include +#include +#include +#include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp" + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[]) +{ + if(argc != 25) + { + printf("arg1: tensor operation (conv_fwd_bias_relu_atomic_add: " + "ForwardConvolution+Bias+ReLu+AtomicAdd)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_fwd_bias_relu_atomic_add_impl< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else + { + throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_convnd_bwd_data.cpp b/profiler/src/profile_convnd_bwd_data.cpp new file mode 100644 index 00000000000..5d0e6a34c7b --- /dev/null +++ b/profiler/src/profile_convnd_bwd_data.cpp @@ -0,0 +1,228 @@ +#include +#include +#include +#include +#include +#include + +#include "profile_convnd_bwd_data_impl.hpp" + +namespace { + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[], int arg_idx) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + ck::utils::conv::ConvParams params; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // namespace + +int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) +{ + const int preParams = 10; + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + preParams; + if(cmdline_nargs != argc) + { + printf("arg1: tensor operation (conv[1|2|3]d_bwd_data: BackwardConvolution)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + ck::utils::conv::ConvParams params = parse_conv_params(num_dim_spatial, argv, preParams); + + auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) { + using InDataType = decltype(input_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + using AccDataType = decltype(acc_type); + + switch(num_dim_spatial) + { + case 1: + ck::profiler::profile_convnd_bwd_data_impl<1, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + do_verification, + init_method, + do_log, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_); + break; + + case 2: + ck::profiler::profile_convnd_bwd_data_impl<2, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_); + break; + + case 3: + ck::profiler::profile_convnd_bwd_data_impl<3, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + do_verification, + init_method, + do_log, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_); + break; + + default: break; + } + }; + if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(float{}, float{}, float{}, float{}); + } + else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(ck::half_t{}, ck::half_t{}, ck::half_t{}, float{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, float{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(int8_t{}, int8_t{}, int8_t{}, int32_t{}); + } + else + { + std::cout << "wrong! this Conv data_type & layout is not implemented" << std::endl; + return 1; + } + + return 0; +} diff --git a/profiler/src/profile_convnd_fwd.cpp b/profiler/src/profile_convnd_fwd.cpp new file mode 100644 index 00000000000..87778a04a53 --- /dev/null +++ b/profiler/src/profile_convnd_fwd.cpp @@ -0,0 +1,351 @@ +#include +#include +#include +#include +#include +#include + +#include "conv_util.hpp" +#include "element_wise_operation.hpp" +#include "fill.hpp" +#include "profile_convnd_fwd.hpp" +#include "tensor_layout.hpp" + +namespace { + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +enum struct ConvDataLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +namespace ctl = ck::tensor_layout::convolution; + +template +struct ConvolutionLayouts; + +template <> +struct ConvolutionLayouts<1, ConvDataLayout::NHWC> +{ + typedef ctl::NWC Input; + typedef ctl::KXC Weight; + typedef ctl::NWK Output; +}; +template <> +struct ConvolutionLayouts<2, ConvDataLayout::NHWC> +{ + typedef ctl::NHWC Input; + typedef ctl::KYXC Weight; + typedef ctl::NHWK Output; +}; +template <> +struct ConvolutionLayouts<3, ConvDataLayout::NHWC> +{ + typedef ctl::NDHWC Input; + typedef ctl::KZYXC Weight; + typedef ctl::NDHWK Output; +}; +template <> +struct ConvolutionLayouts<1, ConvDataLayout::NCHW> +{ + typedef ctl::NCW Input; + typedef ctl::KCX Weight; + typedef ctl::NKW Output; +}; +template <> +struct ConvolutionLayouts<2, ConvDataLayout::NCHW> +{ + typedef ctl::NCHW Input; + typedef ctl::KCYX Weight; + typedef ctl::NKHW Output; +}; +template <> +struct ConvolutionLayouts<3, ConvDataLayout::NCHW> +{ + typedef ctl::NCDHW Input; + typedef ctl::KCZYX Weight; + typedef ctl::NKDHW Output; +}; + +void print_use_msg() +{ + std::cout << "arg1: tensor operation (conv_fwd: ForwardConvolution)\n" + << "arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n" + << "arg3: data layout (0: NCHW; 1: NHWC)\n" + << "arg4: verification (0=no, 1=yes)\n" + << "arg5: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: run kernel # of times (>1)\n" + << "arg8: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams parse_params(int num_dim_spatial, int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 9; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(1); + } + int arg_idx = 9; + + return ck::utils::conv::parse_conv_params(num_dim_spatial, arg_idx, argv); +} + +template +void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params, + bool do_verification, + bool do_log, + bool time_kernel, + int init_method, + ConvLayouts) +{ + using namespace std::placeholders; + using namespace ck::utils; + + std::unique_ptr> conv_instance; + + switch(init_method) + { + case 0: + conv_instance = + std::make_unique>(params, false); + break; + case 1: + conv_instance = std::make_unique< + conv::ConvFwdOpInstance, + ck::utils::FillUniform>>( + params, true, ck::utils::FillUniform{}, ck::utils::FillUniform{}); + break; + case 2: + conv_instance = std::make_unique< + conv::ConvFwdOpInstance, + ck::utils::FillUniform>>( + params, + true, + ck::utils::FillUniform{}, + ck::utils::FillUniform{}); + break; + default: throw std::runtime_error("Unsupported init method!"); + } + + auto reference_conv_fwd_fun = std::bind( + conv::run_reference_convolution_forward, + params, + _1, + _2, + _3); + OpInstanceRunEngine run_engine(*conv_instance, + reference_conv_fwd_fun); + auto best_conf = run_engine.Profile( + conv::ConvolutionFwdInstances::template Get(), + time_kernel, + do_verification, + do_log); + + std::cout << "Best configuration parameters:" + << "\nname: " << best_conf.best_op_name << "\navg_time: " << best_conf.best_avg_time + << "\ntflops: " << best_conf.best_tflops << "\nGB/s: " << best_conf.best_gb_per_sec + << std::endl; +} + +template +void profile_convnd_instances(ConvDataType data_type, + ConvDataLayout data_layout, + const ck::utils::conv::ConvParams& params, + bool do_verification, + bool do_log, + bool time_kernel, + int init_method) +{ + switch(data_layout) + { + case ConvDataLayout::NHWC: { + switch(data_type) + { + case ConvDataType::F32_F32_F32: + profile_convnd_instances_impl( + params, + do_verification, + do_log, + time_kernel, + init_method, + ConvolutionLayouts{}); + break; + case ConvDataType::F16_F16_F16: + profile_convnd_instances_impl( + params, + do_verification, + do_log, + time_kernel, + init_method, + ConvolutionLayouts{}); + break; + case ConvDataType::BF16_BF16_BF16: + profile_convnd_instances_impl( + params, + do_verification, + do_log, + time_kernel, + init_method, + ConvolutionLayouts{}); + break; + case ConvDataType::INT8_INT8_INT8: + profile_convnd_instances_impl( + params, + do_verification, + do_log, + time_kernel, + init_method, + ConvolutionLayouts{}); + break; + } + break; + } + case ConvDataLayout::NCHW: { + switch(data_type) + { + case ConvDataType::F32_F32_F32: + profile_convnd_instances_impl( + params, + do_verification, + do_log, + time_kernel, + init_method, + ConvolutionLayouts{}); + break; + case ConvDataType::F16_F16_F16: + profile_convnd_instances_impl( + params, + do_verification, + do_log, + time_kernel, + init_method, + ConvolutionLayouts{}); + break; + case ConvDataType::BF16_BF16_BF16: + profile_convnd_instances_impl( + params, + do_verification, + do_log, + time_kernel, + init_method, + ConvolutionLayouts{}); + break; + case ConvDataType::INT8_INT8_INT8: + profile_convnd_instances_impl( + params, + do_verification, + do_log, + time_kernel, + init_method, + ConvolutionLayouts{}); + break; + } + break; + } + } +} + +} // namespace + +int ck::profiler::profile_convnd_fwd(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + ConvDataType data_type{ConvDataType::F32_F32_F32}; + ConvDataLayout data_layout{ConvDataLayout::NHWC}; + bool do_verification{true}; + int init_method{2}; + bool do_log{false}; + bool time_kernel{false}; + int num_dim_spatial{2}; + ConvParams params; + + if(argc >= 4) + { + data_type = static_cast(std::stoi(argv[2])); + data_layout = static_cast(std::stoi(argv[3])); + } + if(argc >= 9) + { + do_verification = std::stoi(argv[4]); + init_method = std::stoi(argv[5]); + do_log = std::stoi(argv[6]); + time_kernel = std::stoi(argv[7]); + num_dim_spatial = std::stoi(argv[8]); + } + if(argc >= 10) + { + params = parse_params(num_dim_spatial, argc, argv); + } + + // TODO Print nice message what is being profiled. + + switch(num_dim_spatial) + { + case 1: + profile_convnd_instances<1>( + data_type, data_layout, params, do_verification, do_log, time_kernel, init_method); + break; + case 2: + profile_convnd_instances<2>( + data_type, data_layout, params, do_verification, do_log, time_kernel, init_method); + break; + case 3: + profile_convnd_instances<3>( + data_type, data_layout, params, do_verification, do_log, time_kernel, init_method); + break; + default: + throw std::runtime_error("profile_conv_fwd: unsupported num_dim_spatial value: " + + std::to_string(num_dim_spatial)); + } + + return 0; +} diff --git a/profiler/src/profile_gemm.cpp b/profiler/src/profile_gemm.cpp new file mode 100644 index 00000000000..55bc98f4b10 --- /dev/null +++ b/profiler/src/profile_gemm.cpp @@ -0,0 +1,392 @@ +#include +#include +#include +#include +#include +#include +#include "profile_gemm_impl.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +int profile_gemm(int argc, char* argv[]) +{ + if(!(argc == 14 || argc == 15)) + { + printf("arg1: tensor operation (gemm: GEMM)\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: split k into mulitiple batch\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + int KBatch = 1; + if(argc == 15) + KBatch = std::stoi(argv[14]); + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else + { + throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_gemm_bias_2d.cpp b/profiler/src/profile_gemm_bias_2d.cpp new file mode 100644 index 00000000000..51dba85f326 --- /dev/null +++ b/profiler/src/profile_gemm_bias_2d.cpp @@ -0,0 +1,256 @@ +#include +#include +#include +#include +#include +#include +#include "profile_gemm_bias_2d_impl.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +int profile_gemm_bias_2d(int argc, char* argv[]) +{ + if(!(argc == 16 || argc == 17)) + { + printf("arg1: tensor operation (gemm: GEMM+Bias_2d)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: alpha\n"); + printf("arg15: beta\n"); + printf("arg16: split k into mulitiple batch\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + const float alpha = std::stof(argv[14]); + const float beta = std::stof(argv[15]); + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_gemm_bias_relu.cpp b/profiler/src/profile_gemm_bias_relu.cpp new file mode 100644 index 00000000000..bf035d9ad9a --- /dev/null +++ b/profiler/src/profile_gemm_bias_relu.cpp @@ -0,0 +1,143 @@ +#include +#include +#include +#include +#include +#include +#include "profile_gemm_bias_relu_impl.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +int profile_gemm_bias_relu(int argc, char* argv[]) +{ + if(!(argc == 14 || argc == 15)) + { + printf("arg1: tensor operation (gemm: GEMM+Bias+ReLU)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: split k into mulitiple batch\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_bias_relu_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_bias_relu_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_bias_relu_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_bias_relu_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_gemm_bias_relu_add.cpp b/profiler/src/profile_gemm_bias_relu_add.cpp new file mode 100644 index 00000000000..9c324f6cf95 --- /dev/null +++ b/profiler/src/profile_gemm_bias_relu_add.cpp @@ -0,0 +1,148 @@ +#include +#include +#include +#include +#include +#include +#include "profile_gemm_bias_relu_add_impl.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +int profile_gemm_bias_relu_add(int argc, char* argv[]) +{ + if(!(argc == 15 || argc == 16)) + { + printf("arg1: tensor operation (gemm: GEMM+Bias+ReLU+Add)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1\n"); + printf("arg15: split k into mulitiple batch\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + const int StrideC1 = std::stoi(argv[14]); + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_bias_relu_add_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_bias_relu_add_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_bias_relu_add_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_bias_relu_add_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_gemm_reduce.cpp b/profiler/src/profile_gemm_reduce.cpp new file mode 100644 index 00000000000..a23967acd7a --- /dev/null +++ b/profiler/src/profile_gemm_reduce.cpp @@ -0,0 +1,146 @@ +#include +#include +#include +#include +#include +#include +#include "profile_gemm_reduce_impl.hpp" + +int profile_gemm_reduce(int argc, char* argv[]) +{ + enum struct GemmMatrixLayout + { + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + }; + + enum struct GemmReduceDataType + { + F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F32_F32, // 1 + }; + + if(!(argc == 14 || argc == 15)) + { + printf("arg1: tensor operation (gemm: GEMM+Reduce)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: split k into mulitiple batch\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp new file mode 100644 index 00000000000..c3774962cc9 --- /dev/null +++ b/profiler/src/profile_grouped_gemm.cpp @@ -0,0 +1,157 @@ +#include +#include +#include +#include +#include +#include +#include "profile_grouped_gemm_impl.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +std::vector argToIntArray(char* input) +{ + std::vector out; + + std::istringstream in(input); + + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + + return out; +} + +int profile_grouped_gemm(int argc, char* argv[]) +{ + if(!(argc == 14)) + { + printf("arg1: tensor operation (grouped_gemm: Grouped GEMM)\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const auto Ms = argToIntArray(argv[8]); + const auto Ns = argToIntArray(argv[9]); + const auto Ks = argToIntArray(argv[10]); + + const auto StrideAs = argToIntArray(argv[11]); + const auto StrideBs = argToIntArray(argv[12]); + const auto StrideCs = argToIntArray(argv[13]); + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else + { + throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_reduce.cpp b/profiler/src/profile_reduce.cpp new file mode 100644 index 00000000000..bdbac4fab4f --- /dev/null +++ b/profiler/src/profile_reduce.cpp @@ -0,0 +1,427 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "data_type_enum.hpp" +#include "reduction_enums.hpp" + +#include "host_common_util.hpp" +#include "profile_reduce_impl.hpp" + +using namespace std; + +using ck::ReduceTensorOp; + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"reduceDims", required_argument, nullptr, 'R'}, + {"reduceOp", required_argument, nullptr, 'O'}, + {"compType", required_argument, nullptr, 'C'}, + {"outType", required_argument, nullptr, 'W'}, + {"nanOpt", required_argument, nullptr, 'N'}, + {"indicesOpt", required_argument, nullptr, 'I'}, + {"scales", required_argument, nullptr, 'S'}, + {"half", no_argument, nullptr, '?'}, + {"double", no_argument, nullptr, '?'}, + {"int8", no_argument, nullptr, '?'}, + {"bf16", no_argument, nullptr, '?'}, + {"dumpout", required_argument, nullptr, 'o'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +static void check_reduce_dims(const int rank, const std::vector& reduceDims) +{ + for(auto dim : reduceDims) + { + if(dim < 0 || dim >= rank) + throw std::runtime_error("Invalid dimension index specified for Reducing"); + }; + + unsigned int flag = 0; + + for(auto dim : reduceDims) + { + if(flag & (0x1 << dim)) + throw std::runtime_error("All toReduce dimensions should be different!"); + flag = flag | (0x1 << dim); + }; +}; + +class ReduceProfilerArgs +{ + private: + int option_index = 0; + + public: + bool use_half = false; + bool use_double = false; + bool use_int8 = false; + bool use_bf16 = false; + + std::vector inLengths; + std::vector outLengths; + std::vector reduceDims; + + std::vector scales; + + ReduceTensorOp reduceOp = ReduceTensorOp::ADD; + ck::DataTypeEnum compTypeId = ck::DataTypeEnum::Float; + ck::DataTypeEnum outTypeId = ck::DataTypeEnum::Float; + + bool compType_assigned = false; + bool outType_assigned = false; + + int nanOpt = 0; + int indicesOpt = 0; + bool do_verification = false; + bool do_dumpout = false; + + int init_method; + bool time_kernel; + + ReduceProfilerArgs() = default; + ~ReduceProfilerArgs() = default; + + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" + << std::endl; + std::cout << "--reduceDims or -R, comma separated list of to-reduce dimensions" + << std::endl; + std::cout << "--reduceOp or -O, enum value indicating the reduction operations" + << std::endl; + std::cout << "--compType or -C, enum value indicating the type of accumulated values used " + "during the reduction" + << std::endl; + std::cout << "--outType or -W, optional enum value indicating the type of the reduced " + "output, which could be float when the input data is half" + << std::endl; + std::cout + << "--nanOpt or -N, 1/0 value indicates the selection to use or not use Nan-Propagation" + << std::endl; + std::cout << "--indicesOpt or -I, 1/0 value indicates the selection to use or not use " + "index in reduction" + << std::endl; + std::cout << "--scales or -S, comma separated two float values for alpha and beta" + << std::endl; + std::cout << "--half, use fp16 for the input and output tensor data types" << std::endl; + std::cout << "--double, use fp64 for the input and output tensor data types" << std::endl; + std::cout << "--int8, use int8 for the input and output tensor data types" << std::endl; + std::cout << "--bf16, use bfloat16 for the input and output tensor data types" << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " + "comparing with the host-based reduction" + << std::endl; + std::cout << "--dumpout or -o, 1/0 to indicate where to save the reduction result to files " + "for further analysis" + << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + optind++; // to skip the "reduce" module name + + while(1) + { + ch = getopt_long(argc, argv, "D:R:O:C:W:N:I:S:v:o:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'R': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceDims = getTypeValuesFromString(optarg); + break; + case 'O': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceOp = static_cast(std::atoi(optarg)); + break; + case 'C': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + compTypeId = static_cast(std::atoi(optarg)); + compType_assigned = true; + break; + case 'W': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + outTypeId = static_cast(std::atoi(optarg)); + outType_assigned = true; + break; + case 'N': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + nanOpt = std::atoi(optarg); + break; + case 'I': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + indicesOpt = std::atoi(optarg); + break; + case 'S': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + scales = getTypeValuesFromString(optarg); + + if(scales.size() != 2) + throw std::runtime_error("Invalid option format!"); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case 'o': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_dumpout = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "half") + use_half = true; + else if(std::string(long_options[option_index].name) == "double") + use_double = true; + else if(std::string(long_options[option_index].name) == "int8") + use_int8 = true; + else if(std::string(long_options[option_index].name) == "bf16") + use_bf16 = true; + else if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + + default: + show_usage(argv[0]); + std::cerr << "Invalid cmd-line options!" << std::endl; + return (-1); + }; + }; + + if(optind + 2 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind])); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + if(reduceOp == ReduceTensorOp::MIN || reduceOp == ReduceTensorOp::MAX || + reduceOp == ReduceTensorOp::AMAX) + { + // for indexable operations, no need to assign compType and outType, just let them be + // same as inType + compType_assigned = false; + outType_assigned = false; + }; + + return (0); + }; + +}; // end of class AppArgs + +int profile_reduce(int argc, char* argv[]) +{ + using ck::DataTypeEnum; + using ck::profiler::profile_reduce_impl; + + ReduceProfilerArgs args; + + if(args.processArgs(argc, argv) < 0) + return (-1); + + int rank = args.inLengths.size(); + + check_reduce_dims(rank, args.reduceDims); + + if(args.reduceOp == ReduceTensorOp::MUL || args.reduceOp == ReduceTensorOp::NORM1) + throw std::runtime_error("MUL and NORM1 are not supported by composable kernel!"); + + if(args.use_half) + { + if(!args.compType_assigned) + args.compTypeId = DataTypeEnum::Half; + + if(args.outType_assigned && + (args.outTypeId != DataTypeEnum::Half && args.outTypeId != DataTypeEnum::Float)) + args.outTypeId = DataTypeEnum::Float; + + if(!args.outType_assigned) + args.outTypeId = DataTypeEnum::Half; + + if(args.compTypeId == DataTypeEnum::Half) + { + profile_reduce_impl( + args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else if(args.compTypeId == DataTypeEnum::Float) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else + throw std::runtime_error("Invalid compType assignment!"); + } + else if(args.use_double) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else if(args.use_int8) + { + if(!args.compType_assigned) + args.compTypeId = DataTypeEnum::Int8; + + if(args.outType_assigned && + (args.outTypeId != DataTypeEnum::Int8 && args.outTypeId != DataTypeEnum::Int32)) + args.outTypeId = DataTypeEnum::Int32; + + if(!args.outType_assigned) + args.outTypeId = DataTypeEnum::Int8; + + if(args.compTypeId == DataTypeEnum::Int8) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else if(args.compTypeId == DataTypeEnum::Int32) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else + throw std::runtime_error("Invalid compType assignment!"); + } + else if(args.use_bf16) + { + if(args.outType_assigned && + (args.outTypeId != DataTypeEnum::BFloat16 && args.outTypeId != DataTypeEnum::Float)) + args.outTypeId = DataTypeEnum::Float; + + if(!args.outType_assigned) + args.outTypeId = DataTypeEnum::BFloat16; + + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else + { + if(args.compTypeId == DataTypeEnum::Float) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else if(args.compTypeId == DataTypeEnum::Double) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else + throw std::runtime_error("Invalid compType assignment!"); + }; + + return (0); +}; diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp new file mode 100644 index 00000000000..d16e28ee237 --- /dev/null +++ b/profiler/src/profiler.cpp @@ -0,0 +1,116 @@ +#include +#include +#include +#include +#include + +#include "profile_convnd_fwd.hpp" + +int profile_gemm(int, char*[]); +int profile_gemm_bias_2d(int, char*[]); +int profile_gemm_bias_relu(int, char*[]); +int profile_gemm_bias_relu_add(int, char*[]); +int profile_gemm_reduce(int, char*[]); +int profile_batched_gemm(int, char*[]); +int profile_grouped_gemm(int, char*[]); +int profile_conv_fwd(int, char*[]); +int profile_conv_fwd_bias_relu(int, char*[]); +int profile_conv_fwd_bias_relu_add(int, char*[]); +int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); +int profile_convnd_bwd_data(int, char*[], int); +int profile_reduce(int, char*[]); +int profile_conv_bwd_weight(int, char*[]); +int profile_batched_gemm_reduce(int, char*[]); + +int main(int argc, char* argv[]) +{ + if(strcmp(argv[1], "gemm") == 0) + { + return profile_gemm(argc, argv); + } + else if(strcmp(argv[1], "gemm_bias_2d") == 0) + { + return profile_gemm_bias_2d(argc, argv); + } + else if(strcmp(argv[1], "gemm_bias_relu") == 0) + { + return profile_gemm_bias_relu(argc, argv); + } + else if(strcmp(argv[1], "gemm_bias_relu_add") == 0) + { + return profile_gemm_bias_relu_add(argc, argv); + } + else if(strcmp(argv[1], "gemm_reduce") == 0) + { + return profile_gemm_reduce(argc, argv); + } + else if(strcmp(argv[1], "batched_gemm") == 0) + { + return profile_batched_gemm(argc, argv); + } + else if(strcmp(argv[1], "batched_gemm_reduce") == 0) + { + return profile_batched_gemm_reduce(argc, argv); + } + else if(strcmp(argv[1], "grouped_gemm") == 0) + { + return profile_grouped_gemm(argc, argv); + } + else if(strcmp(argv[1], "conv_fwd") == 0) + { + return ck::profiler::profile_convnd_fwd(argc, argv); + } + else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0) + { + return profile_conv_fwd_bias_relu(argc, argv); + } + else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0) + { + return profile_conv_fwd_bias_relu_add(argc, argv); + } + else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0) + { + return profile_conv_fwd_bias_relu_atomic_add(argc, argv); + } + else if(strcmp(argv[1], "conv1d_bwd_data") == 0) + { + return profile_convnd_bwd_data(argc, argv, 1); + } + else if(strcmp(argv[1], "conv2d_bwd_data") == 0) + { + return profile_convnd_bwd_data(argc, argv, 2); + } + else if(strcmp(argv[1], "conv3d_bwd_data") == 0) + { + return profile_convnd_bwd_data(argc, argv, 3); + } + else if(strcmp(argv[1], "reduce") == 0) + { + return profile_reduce(argc, argv); + } + else if(strcmp(argv[1], "conv2d_bwd_weight") == 0) + { + return profile_conv_bwd_weight(argc, argv); + } + else + { + // clang-format off + printf("arg1: tensor operation (gemm: GEMM\n" + " gemm_bias_2d: GEMM+Bias(2D)\n" + " gemm_bias_relu: GEMM+Bias+ReLU\n" + " gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n" + " gemm_reduce: GEMM+Reduce\n" + " grouped_gemm: Grouped GEMM\n" + " conv_fwd: ForwardConvolution\n" + " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n" + " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n" + " conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n" + " conv1d_bwd_data: BackwardConvolution data 1 dim\n" + " conv2d_bwd_data: BackwardConvolution data 2 dim\n" + " conv3d_bwd_data: BackwardConvolution data 3 dim\n" + " reduce: Reduce\n" + " conv2d_bwd_weight: Backward Weight Convolution 2d\n"); + // clang-format on + } + return 0; +} diff --git a/rbuild.ini b/rbuild.ini new file mode 100644 index 00000000000..3649cedf0ae --- /dev/null +++ b/rbuild.ini @@ -0,0 +1,8 @@ +[develop] +cxx = ${rocm_path}/bin/hipcc +cc = ${rocm_path}/llvm/bin/clang +ignore = pcre +deps = + -f dev-requirements.txt +define = + BUILD_DEV=On diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000000..b91bf2e553a --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +danmar/cppcheck@dd05839a7e63ef04afd34711cb3e1e0ef742882f diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh new file mode 100644 index 00000000000..71df7d10e5c --- /dev/null +++ b/script/clang-format-overwrite.sh @@ -0,0 +1,2 @@ +#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' +git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' diff --git a/script/cmake-rocm.sh b/script/cmake-rocm.sh index ebfa2b9f693..86b62368967 100755 --- a/script/cmake-rocm.sh +++ b/script/cmake-rocm.sh @@ -3,16 +3,18 @@ rm -f CMakeCache.txt rm -f *.cmake rm -rf CMakeFiles -MY_PROJECT_SOURCE=../../.. +MY_PROJECT_SOURCE=../ MY_PROJECT_INSTALL=../install.dir cmake \ -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ --D HALF_INCLUDE_DIR="/root/workspace/external/half/include" \ --D BUILD_DEV=ON \ +-D BUILD_DEV=OFF \ -D CMAKE_BUILD_TYPE=Release \ --D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ +-D CMAKE_CXX_FLAGS=" -O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ ${MY_PROJECT_SOURCE} + +#-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3 -ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ +#-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \ diff --git a/script/conv2d_fwd.sh b/script/conv2d_fwd.sh new file mode 100755 index 00000000000..acc91e194fd --- /dev/null +++ b/script/conv2d_fwd.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j $1 + +DRIVER=example/$1 +VERIFY=$2 +INIT=$3 +REPEAT=$4 + +# test +######## verify init repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ + $DRIVER $VERIFY $INIT $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT 128 256 64 1 1 1 1 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT 256 64 3 7 7 230 230 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT 256 64 3 7 7 224 224 2 2 1 1 3 3 3 3 + + N=$5 + +# Resnet50 +######## verify init repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +#$DRIVER $VERIFY $INIT $REPEAT $N 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 256 1 1 56 56 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE diff --git a/script/conv_driver.sh b/script/conv_driver.sh new file mode 100755 index 00000000000..8805e0cc990 --- /dev/null +++ b/script/conv_driver.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j conv_fwd_driver_offline +#make -j conv_bwd_driver_offline +#make -j conv_wrw_driver_offline + + DRIVER="./host/driver_offline/conv_fwd_driver_offline" +#DRIVER="./host/driver_offline/conv_bwd_driver_offline" +#DRIVER="./host/driver_offline/conv_wrw_driver_offline" + +LAYOUT=$1 +ALGO=$2 +VERIFY=$3 +INIT=$4 +LOG=$5 +REPEAT=$6 + + DESIRED_GRID_SIZE=$7 + +######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE + $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE + $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE + $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE + $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 32 256 3 3 1 1 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 32 256 1 1 1 1 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE + $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 2 2 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 1 1 2 2 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE + +# Resnet50 +######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE diff --git a/script/count_vgpr.sh b/script/count_vgpr.sh index 4fbfec02783..07debc53a8c 100755 --- a/script/count_vgpr.sh +++ b/script/count_vgpr.sh @@ -1,259 +1,20 @@ #!/bin/bash FILE=$1 -echo v0 $( grep -w v0 $FILE | wc -l ) -echo v1 $( grep -w v1 $FILE | wc -l ) -echo v2 $( grep -w v2 $FILE | wc -l ) -echo v3 $( grep -w v3 $FILE | wc -l ) -echo v4 $( grep -w v4 $FILE | wc -l ) -echo v5 $( grep -w v5 $FILE | wc -l ) -echo v6 $( grep -w v6 $FILE | wc -l ) -echo v7 $( grep -w v7 $FILE | wc -l ) -echo v8 $( grep -w v8 $FILE | wc -l ) -echo v9 $( grep -w v9 $FILE | wc -l ) -echo v10 $( grep -w v10 $FILE | wc -l ) -echo v11 $( grep -w v11 $FILE | wc -l ) -echo v12 $( grep -w v12 $FILE | wc -l ) -echo v13 $( grep -w v13 $FILE | wc -l ) -echo v14 $( grep -w v14 $FILE | wc -l ) -echo v15 $( grep -w v15 $FILE | wc -l ) -echo v16 $( grep -w v16 $FILE | wc -l ) -echo v17 $( grep -w v17 $FILE | wc -l ) -echo v18 $( grep -w v18 $FILE | wc -l ) -echo v19 $( grep -w v19 $FILE | wc -l ) -echo v20 $( grep -w v20 $FILE | wc -l ) -echo v21 $( grep -w v21 $FILE | wc -l ) -echo v22 $( grep -w v22 $FILE | wc -l ) -echo v23 $( grep -w v23 $FILE | wc -l ) -echo v24 $( grep -w v24 $FILE | wc -l ) -echo v25 $( grep -w v25 $FILE | wc -l ) -echo v26 $( grep -w v26 $FILE | wc -l ) -echo v27 $( grep -w v27 $FILE | wc -l ) -echo v28 $( grep -w v28 $FILE | wc -l ) -echo v29 $( grep -w v29 $FILE | wc -l ) -echo v30 $( grep -w v30 $FILE | wc -l ) -echo v31 $( grep -w v31 $FILE | wc -l ) -echo v32 $( grep -w v32 $FILE | wc -l ) -echo v33 $( grep -w v33 $FILE | wc -l ) -echo v34 $( grep -w v34 $FILE | wc -l ) -echo v35 $( grep -w v35 $FILE | wc -l ) -echo v36 $( grep -w v36 $FILE | wc -l ) -echo v37 $( grep -w v37 $FILE | wc -l ) -echo v38 $( grep -w v38 $FILE | wc -l ) -echo v39 $( grep -w v39 $FILE | wc -l ) -echo v40 $( grep -w v40 $FILE | wc -l ) -echo v41 $( grep -w v41 $FILE | wc -l ) -echo v42 $( grep -w v42 $FILE | wc -l ) -echo v43 $( grep -w v43 $FILE | wc -l ) -echo v44 $( grep -w v44 $FILE | wc -l ) -echo v45 $( grep -w v45 $FILE | wc -l ) -echo v46 $( grep -w v46 $FILE | wc -l ) -echo v47 $( grep -w v47 $FILE | wc -l ) -echo v48 $( grep -w v48 $FILE | wc -l ) -echo v49 $( grep -w v49 $FILE | wc -l ) -echo v50 $( grep -w v50 $FILE | wc -l ) -echo v51 $( grep -w v51 $FILE | wc -l ) -echo v52 $( grep -w v52 $FILE | wc -l ) -echo v53 $( grep -w v53 $FILE | wc -l ) -echo v54 $( grep -w v54 $FILE | wc -l ) -echo v55 $( grep -w v55 $FILE | wc -l ) -echo v56 $( grep -w v56 $FILE | wc -l ) -echo v57 $( grep -w v57 $FILE | wc -l ) -echo v58 $( grep -w v58 $FILE | wc -l ) -echo v59 $( grep -w v59 $FILE | wc -l ) -echo v60 $( grep -w v60 $FILE | wc -l ) -echo v61 $( grep -w v61 $FILE | wc -l ) -echo v62 $( grep -w v62 $FILE | wc -l ) -echo v63 $( grep -w v63 $FILE | wc -l ) -echo v64 $( grep -w v64 $FILE | wc -l ) -echo v65 $( grep -w v65 $FILE | wc -l ) -echo v66 $( grep -w v66 $FILE | wc -l ) -echo v67 $( grep -w v67 $FILE | wc -l ) -echo v68 $( grep -w v68 $FILE | wc -l ) -echo v69 $( grep -w v69 $FILE | wc -l ) -echo v70 $( grep -w v70 $FILE | wc -l ) -echo v71 $( grep -w v71 $FILE | wc -l ) -echo v72 $( grep -w v72 $FILE | wc -l ) -echo v73 $( grep -w v73 $FILE | wc -l ) -echo v74 $( grep -w v74 $FILE | wc -l ) -echo v75 $( grep -w v75 $FILE | wc -l ) -echo v76 $( grep -w v76 $FILE | wc -l ) -echo v77 $( grep -w v77 $FILE | wc -l ) -echo v78 $( grep -w v78 $FILE | wc -l ) -echo v79 $( grep -w v79 $FILE | wc -l ) -echo v80 $( grep -w v80 $FILE | wc -l ) -echo v81 $( grep -w v81 $FILE | wc -l ) -echo v82 $( grep -w v82 $FILE | wc -l ) -echo v83 $( grep -w v83 $FILE | wc -l ) -echo v84 $( grep -w v84 $FILE | wc -l ) -echo v85 $( grep -w v85 $FILE | wc -l ) -echo v86 $( grep -w v86 $FILE | wc -l ) -echo v87 $( grep -w v87 $FILE | wc -l ) -echo v88 $( grep -w v88 $FILE | wc -l ) -echo v89 $( grep -w v89 $FILE | wc -l ) -echo v90 $( grep -w v90 $FILE | wc -l ) -echo v91 $( grep -w v91 $FILE | wc -l ) -echo v92 $( grep -w v92 $FILE | wc -l ) -echo v93 $( grep -w v93 $FILE | wc -l ) -echo v94 $( grep -w v94 $FILE | wc -l ) -echo v95 $( grep -w v95 $FILE | wc -l ) -echo v96 $( grep -w v96 $FILE | wc -l ) -echo v97 $( grep -w v97 $FILE | wc -l ) -echo v98 $( grep -w v98 $FILE | wc -l ) -echo v99 $( grep -w v99 $FILE | wc -l ) -echo v100 $( grep -w v100 $FILE | wc -l ) -echo v101 $( grep -w v101 $FILE | wc -l ) -echo v102 $( grep -w v102 $FILE | wc -l ) -echo v103 $( grep -w v103 $FILE | wc -l ) -echo v104 $( grep -w v104 $FILE | wc -l ) -echo v105 $( grep -w v105 $FILE | wc -l ) -echo v106 $( grep -w v106 $FILE | wc -l ) -echo v107 $( grep -w v107 $FILE | wc -l ) -echo v108 $( grep -w v108 $FILE | wc -l ) -echo v109 $( grep -w v109 $FILE | wc -l ) -echo v110 $( grep -w v110 $FILE | wc -l ) -echo v111 $( grep -w v111 $FILE | wc -l ) -echo v112 $( grep -w v112 $FILE | wc -l ) -echo v113 $( grep -w v113 $FILE | wc -l ) -echo v114 $( grep -w v114 $FILE | wc -l ) -echo v115 $( grep -w v115 $FILE | wc -l ) -echo v116 $( grep -w v116 $FILE | wc -l ) -echo v117 $( grep -w v117 $FILE | wc -l ) -echo v118 $( grep -w v118 $FILE | wc -l ) -echo v119 $( grep -w v119 $FILE | wc -l ) -echo v120 $( grep -w v120 $FILE | wc -l ) -echo v121 $( grep -w v121 $FILE | wc -l ) -echo v122 $( grep -w v122 $FILE | wc -l ) -echo v123 $( grep -w v123 $FILE | wc -l ) -echo v124 $( grep -w v124 $FILE | wc -l ) -echo v125 $( grep -w v125 $FILE | wc -l ) -echo v126 $( grep -w v126 $FILE | wc -l ) -echo v127 $( grep -w v127 $FILE | wc -l ) -echo v128 $( grep -w v128 $FILE | wc -l ) -echo v129 $( grep -w v129 $FILE | wc -l ) -echo v130 $( grep -w v130 $FILE | wc -l ) -echo v131 $( grep -w v131 $FILE | wc -l ) -echo v132 $( grep -w v132 $FILE | wc -l ) -echo v133 $( grep -w v133 $FILE | wc -l ) -echo v134 $( grep -w v134 $FILE | wc -l ) -echo v135 $( grep -w v135 $FILE | wc -l ) -echo v136 $( grep -w v136 $FILE | wc -l ) -echo v137 $( grep -w v137 $FILE | wc -l ) -echo v138 $( grep -w v138 $FILE | wc -l ) -echo v139 $( grep -w v139 $FILE | wc -l ) -echo v140 $( grep -w v140 $FILE | wc -l ) -echo v141 $( grep -w v141 $FILE | wc -l ) -echo v142 $( grep -w v142 $FILE | wc -l ) -echo v143 $( grep -w v143 $FILE | wc -l ) -echo v144 $( grep -w v144 $FILE | wc -l ) -echo v145 $( grep -w v145 $FILE | wc -l ) -echo v146 $( grep -w v146 $FILE | wc -l ) -echo v147 $( grep -w v147 $FILE | wc -l ) -echo v148 $( grep -w v148 $FILE | wc -l ) -echo v149 $( grep -w v149 $FILE | wc -l ) -echo v150 $( grep -w v150 $FILE | wc -l ) -echo v151 $( grep -w v151 $FILE | wc -l ) -echo v152 $( grep -w v152 $FILE | wc -l ) -echo v153 $( grep -w v153 $FILE | wc -l ) -echo v154 $( grep -w v154 $FILE | wc -l ) -echo v155 $( grep -w v155 $FILE | wc -l ) -echo v156 $( grep -w v156 $FILE | wc -l ) -echo v157 $( grep -w v157 $FILE | wc -l ) -echo v158 $( grep -w v158 $FILE | wc -l ) -echo v159 $( grep -w v159 $FILE | wc -l ) -echo v160 $( grep -w v160 $FILE | wc -l ) -echo v161 $( grep -w v161 $FILE | wc -l ) -echo v162 $( grep -w v162 $FILE | wc -l ) -echo v163 $( grep -w v163 $FILE | wc -l ) -echo v164 $( grep -w v164 $FILE | wc -l ) -echo v165 $( grep -w v165 $FILE | wc -l ) -echo v166 $( grep -w v166 $FILE | wc -l ) -echo v167 $( grep -w v167 $FILE | wc -l ) -echo v168 $( grep -w v168 $FILE | wc -l ) -echo v169 $( grep -w v169 $FILE | wc -l ) -echo v170 $( grep -w v170 $FILE | wc -l ) -echo v171 $( grep -w v171 $FILE | wc -l ) -echo v172 $( grep -w v172 $FILE | wc -l ) -echo v173 $( grep -w v173 $FILE | wc -l ) -echo v174 $( grep -w v174 $FILE | wc -l ) -echo v175 $( grep -w v175 $FILE | wc -l ) -echo v176 $( grep -w v176 $FILE | wc -l ) -echo v177 $( grep -w v177 $FILE | wc -l ) -echo v178 $( grep -w v178 $FILE | wc -l ) -echo v179 $( grep -w v179 $FILE | wc -l ) -echo v180 $( grep -w v180 $FILE | wc -l ) -echo v181 $( grep -w v181 $FILE | wc -l ) -echo v182 $( grep -w v182 $FILE | wc -l ) -echo v183 $( grep -w v183 $FILE | wc -l ) -echo v184 $( grep -w v184 $FILE | wc -l ) -echo v185 $( grep -w v185 $FILE | wc -l ) -echo v186 $( grep -w v186 $FILE | wc -l ) -echo v187 $( grep -w v187 $FILE | wc -l ) -echo v188 $( grep -w v188 $FILE | wc -l ) -echo v189 $( grep -w v189 $FILE | wc -l ) -echo v190 $( grep -w v190 $FILE | wc -l ) -echo v191 $( grep -w v191 $FILE | wc -l ) -echo v192 $( grep -w v192 $FILE | wc -l ) -echo v193 $( grep -w v193 $FILE | wc -l ) -echo v194 $( grep -w v194 $FILE | wc -l ) -echo v195 $( grep -w v195 $FILE | wc -l ) -echo v196 $( grep -w v196 $FILE | wc -l ) -echo v197 $( grep -w v197 $FILE | wc -l ) -echo v198 $( grep -w v198 $FILE | wc -l ) -echo v199 $( grep -w v199 $FILE | wc -l ) -echo v200 $( grep -w v200 $FILE | wc -l ) -echo v201 $( grep -w v201 $FILE | wc -l ) -echo v202 $( grep -w v202 $FILE | wc -l ) -echo v203 $( grep -w v203 $FILE | wc -l ) -echo v204 $( grep -w v204 $FILE | wc -l ) -echo v205 $( grep -w v205 $FILE | wc -l ) -echo v206 $( grep -w v206 $FILE | wc -l ) -echo v207 $( grep -w v207 $FILE | wc -l ) -echo v208 $( grep -w v208 $FILE | wc -l ) -echo v209 $( grep -w v209 $FILE | wc -l ) -echo v210 $( grep -w v210 $FILE | wc -l ) -echo v211 $( grep -w v211 $FILE | wc -l ) -echo v212 $( grep -w v212 $FILE | wc -l ) -echo v213 $( grep -w v213 $FILE | wc -l ) -echo v214 $( grep -w v214 $FILE | wc -l ) -echo v215 $( grep -w v215 $FILE | wc -l ) -echo v216 $( grep -w v216 $FILE | wc -l ) -echo v217 $( grep -w v217 $FILE | wc -l ) -echo v218 $( grep -w v218 $FILE | wc -l ) -echo v219 $( grep -w v219 $FILE | wc -l ) -echo v220 $( grep -w v220 $FILE | wc -l ) -echo v221 $( grep -w v221 $FILE | wc -l ) -echo v222 $( grep -w v222 $FILE | wc -l ) -echo v223 $( grep -w v223 $FILE | wc -l ) -echo v224 $( grep -w v224 $FILE | wc -l ) -echo v225 $( grep -w v225 $FILE | wc -l ) -echo v226 $( grep -w v226 $FILE | wc -l ) -echo v227 $( grep -w v227 $FILE | wc -l ) -echo v228 $( grep -w v228 $FILE | wc -l ) -echo v229 $( grep -w v229 $FILE | wc -l ) -echo v230 $( grep -w v230 $FILE | wc -l ) -echo v231 $( grep -w v231 $FILE | wc -l ) -echo v232 $( grep -w v232 $FILE | wc -l ) -echo v233 $( grep -w v233 $FILE | wc -l ) -echo v234 $( grep -w v234 $FILE | wc -l ) -echo v235 $( grep -w v235 $FILE | wc -l ) -echo v236 $( grep -w v236 $FILE | wc -l ) -echo v237 $( grep -w v237 $FILE | wc -l ) -echo v238 $( grep -w v238 $FILE | wc -l ) -echo v239 $( grep -w v239 $FILE | wc -l ) -echo v240 $( grep -w v240 $FILE | wc -l ) -echo v241 $( grep -w v241 $FILE | wc -l ) -echo v242 $( grep -w v242 $FILE | wc -l ) -echo v243 $( grep -w v243 $FILE | wc -l ) -echo v244 $( grep -w v244 $FILE | wc -l ) -echo v245 $( grep -w v245 $FILE | wc -l ) -echo v246 $( grep -w v246 $FILE | wc -l ) -echo v247 $( grep -w v247 $FILE | wc -l ) -echo v248 $( grep -w v248 $FILE | wc -l ) -echo v249 $( grep -w v249 $FILE | wc -l ) -echo v250 $( grep -w v250 $FILE | wc -l ) -echo v251 $( grep -w v251 $FILE | wc -l ) -echo v252 $( grep -w v252 $FILE | wc -l ) -echo v253 $( grep -w v253 $FILE | wc -l ) -echo v254 $( grep -w v254 $FILE | wc -l ) -echo v255 $( grep -w v255 $FILE | wc -l ) +for num in {0..255} +do + base_pattern="(\[?${num}\b|\[\d*:${num}\])" + spattern="s${base_pattern}" + vpattern="v${base_pattern}" + apattern="a${base_pattern}" + scount=$(grep -P $spattern $FILE | wc -l) + vcount=$(grep -P $vpattern $FILE | wc -l) + acount=$(grep -P $apattern $FILE | wc -l) + bash -c "echo -n v${num} $vcount && \ + echo -n , s${num} $scount && \ + echo -n , a${num} $acount" + if [[ $scount -ne 0 || $vcount -ne 0 || $acount -ne 0 ]]; then + echo -n " *" + fi + echo "" +done diff --git a/script/example_gemm_xdl.sh b/script/example_gemm_xdl.sh new file mode 100755 index 00000000000..9e2d77d39b0 --- /dev/null +++ b/script/example_gemm_xdl.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=1 + + make -j gemm_xdl + + DRIVER="./example/gemm_xdl" + +VERIFY=$1 +INIT=$2 +LOG=$3 +REPEAT=$4 + +######### verify init log repeat M___ N___ K___ StrideA StrideB StrideC +#$DRIVER $VERIFY $INIT $LOG $REPEAT 960 1024 1024 1024 1024 1024 +#$DRIVER $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 +#$DRIVER $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 2048 2048 2048 + $DRIVER $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 4096 4096 4096 +#$DRIVER $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 8192 8192 8192 diff --git a/script/gemm.sh b/script/gemm.sh new file mode 100755 index 00000000000..395db86d091 --- /dev/null +++ b/script/gemm.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j $1 + +DRIVER=example/$1 +VERIFY=$2 +INIT=$3 +REPEAT=$4 + +######## verify init repeat M___ N___ K___ StrideA StrideB StrideC StrideC1 +#$DRIVER $VERIFY $INIT $REPEAT 256 256 256 256 256 256 256 +#$DRIVER $VERIFY $INIT $REPEAT 960 1024 1024 1024 1024 1024 1024 +#$DRIVER $VERIFY $INIT $REPEAT 1920 2048 2048 2048 2048 2048 2048 + $DRIVER $VERIFY $INIT $REPEAT 3840 4096 4096 4096 4096 4096 4096 +#$DRIVER $VERIFY $INIT $REPEAT 7680 8192 8192 8192 8192 8192 8192 +#$DRIVER $VERIFY $INIT $REPEAT 1024 1024 1024 1024 1024 1024 1024 +#$DRIVER $VERIFY $INIT $REPEAT 2048 2048 2048 2048 2048 2048 2048 diff --git a/script/gemm_driver.sh b/script/gemm_driver.sh new file mode 100755 index 00000000000..491c14cc87e --- /dev/null +++ b/script/gemm_driver.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j gemm_driver_offline + + DRIVER="./host/driver_offline/gemm_driver_offline" + +LAYOUT=$1 +ALGO=$2 +VERIFY=$3 +INIT=$4 +LOG=$5 +REPEAT=$6 + + M01=$7 + N01=$8 + +######### layout algo verify init log repeat M___ N___ K___ M01_ N01_ +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01 diff --git a/script/parse_perf_data.py b/script/parse_perf_data.py new file mode 100644 index 00000000000..a023a195266 --- /dev/null +++ b/script/parse_perf_data.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +import os, io, argparse, datetime +import numpy as np +import sqlalchemy +from sqlalchemy.types import NVARCHAR, Float, Integer +import pymysql +import pandas as pd +from sshtunnel import SSHTunnelForwarder + +def print_to_string(*args, **kwargs): + output = io.StringIO() + print(*args, file=output, **kwargs) + contents = output.getvalue() + output.close() + return contents + +def parse_args(): + parser = argparse.ArgumentParser(description='Parse results from tf benchmark runs') + parser.add_argument('filename', type=str, help='Log file to prase or directory containing log files') + args = parser.parse_args() + files = [] + if os.path.isdir(args.filename): + all_files = os.listdir(args.filename) + for name in all_files: + if not 'log' in name: + continue + files.append(os.path.join(args.filename, name)) + else: + files = [args.filename] + args.files = files + return args + +def main(): + args = parse_args() + tests = [] + kernels=[] + tflops=[] + dtype=[] + alayout=[] + blayout=[] + M=[] + N=[] + K=[] + StrideA=[] + StrideB=[] + StrideC=[] + #parse results, get the Tflops value for "Best Perf" kernels + glue="" + for filename in args.files: + for line in open(filename): + if 'Branch name' in line: + lst=line.split() + branch_name=lst[2] + for filename in args.files: + for line in open(filename): + if 'Best Perf' in line: + lst=line.split() + if len(lst)>=37: #the line is complete + tests.append(glue.join(lst[5:30])) + kernels.append(glue.join(lst[37:])) + tflops.append(lst[33]) + dtype.append(lst[5]) + alayout.append(lst[8]) + blayout.append(lst[11]) + M.append(lst[14]) + N.append(lst[17]) + K.append(lst[20]) + StrideA.append(lst[23]) + StrideB.append(lst[26]) + StrideC.append(lst[29]) + elif len(lst)<37 and len(lst)>=33: #the tflops are available + tests.append(glue.join(lst[5:30])) + kernels.append("N/A") + tflops.append(lst[33]) + dtype.append(lst[5]) + alayout.append(lst[8]) + blayout.append(lst[11]) + M.append(lst[14]) + N.append(lst[17]) + K.append(lst[20]) + StrideA.append(lst[23]) + StrideB.append(lst[26]) + StrideC.append(lst[29]) + print("warning: incomplete line:",lst) + elif len(lst)<33: #even the tflops are not available + print("Error in ckProfiler output!") + print("warning: incomplete line=",lst) + + #sort results + print("Number of tests:",len(tests)) + print("Branch name:",branch_name) + #sorted_tests = sorted(tests) + #print("sorted tests:",sorted_tests) + sorted_tflops = [x for _,x in sorted(zip(tests,tflops))] + #sorted_kernels = [x for _,x in sorted(zip(tests,kernels))] + test_list=list(range(1,len(tests)+1)) + + sql_hostname = '127.0.0.1' + sql_username = os.environ["dbuser"] + print("sql_username=",sql_username) + sql_password = os.environ["dbpassword"] + sql_main_database = 'miopen_perf' + sql_port = 3306 + ssh_host = os.environ["dbsship"] + print("ssh_host=",ssh_host) + ssh_user = os.environ["dbsshuser"] + print("ssh_user=",ssh_user) + ssh_port = int(os.environ["dbsshport"]) + ssh_pass = os.environ["dbsshpassword"] + + with SSHTunnelForwarder( + (ssh_host, ssh_port), + ssh_username=ssh_user, + ssh_password=ssh_pass, + remote_bind_address=(sql_hostname, sql_port)) as tunnel: + + sqlEngine = sqlalchemy.create_engine('mysql+pymysql://{0}:{1}@{2}:{3}/{4}'. + format(sql_username, sql_password, sql_hostname, tunnel.local_bind_port, sql_main_database)) + conn = sqlEngine.connect() + + #write the ck_gemm_test_params table + #only needed once the test set changes + ''' + sorted_dtypes = [x for _,x in sorted(zip(tests,dtype))] + sorted_alayout = [x for _,x in sorted(zip(tests,alayout))] + sorted_blayout = [x for _,x in sorted(zip(tests,blayout))] + sorted_M = [x for _,x in sorted(zip(tests,M))] + sorted_N = [x for _,x in sorted(zip(tests,N))] + sorted_K = [x for _,x in sorted(zip(tests,K))] + sorted_StrideA = [x for _,x in sorted(zip(tests,StrideA))] + sorted_StrideB = [x for _,x in sorted(zip(tests,StrideB))] + sorted_StrideC = [x for _,x in sorted(zip(tests,StrideC))] + ck_gemm_params=[test_list,sorted_dtypes,sorted_alayout,sorted_blayout, + sorted_M,sorted_N,sorted_K,sorted_StrideA,sorted_StrideB, + sorted_StrideC] + df=pd.DataFrame(np.transpose(ck_gemm_params),columns=['Test_number','Data_type', + 'Alayout','BLayout','M','N','K', 'StrideA','StrideB','StrideC']) + print(df) + + dtypes = { + 'Test_number': Integer(), + 'Data_type': NVARCHAR(length=5), + 'Alayout': NVARCHAR(length=12), + 'Blayout': NVARCHAR(length=12), + 'M': Integer(), + 'N': Integer(), + 'K': Integer(), + 'StrideA': Integer(), + 'StrideB': Integer(), + 'StrideC': Integer() + } + df.to_sql("ck_gemm_test_params",conn,if_exists='replace',index=False, dtype=dtypes) + ''' + + #read baseline results for the latest develop branch + query = '''SELECT * from ck_gemm_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_gemm_tflops where Branch_ID='develop' );''' + tflops_base = pd.read_sql_query(query, conn) + + #write new results to the db + testlist=[] + for i in range(1,len(tests)+1): + testlist.append("Test%i"%i) + ck_gemm_tflops=[str(branch_name),str(datetime.datetime.now())] + flops=pd.DataFrame(data=[ck_gemm_tflops],columns=['Branch_ID','Datetime']) + df_add=pd.DataFrame(data=[sorted_tflops],columns=testlist) + flops=pd.concat([flops,df_add],axis=1) + print("new tflops results:",flops) + flops.to_sql("ck_gemm_tflops",conn,if_exists='append',index=False) + conn.close() + + #compare the results to the baseline + regression=0 + base=tflops_base[testlist].to_numpy(dtype='float') + base_list=base[0] + ave_perf=0 + for i in range(len(base_list)): + # success criterion: + if base_list[i]>1.01*float(sorted_tflops[i]): + print("test # ",i,"shows regression by {:.3f}%".format( + (float(sorted_tflops[i])-base_list[i])/base_list[i]*100)) + regression=1 + ave_perf=ave_perf+float(sorted_tflops[i])/base_list[i] + if regression==0: + print("no regressions found") + ave_perf=ave_perf/len(base_list) + print("average performance relative to baseline:",ave_perf) + + #return 0 if performance criteria met, otherwise return 1 + + return regression + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/script/pool2d_fwd.sh b/script/pool2d_fwd.sh new file mode 100755 index 00000000000..10acf5394e6 --- /dev/null +++ b/script/pool2d_fwd.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j $1 + +DRIVER=example/$1 +VERIFY=$2 +INIT=$3 +REPEAT=$4 + +# test +######## verify init repeat N__ C___ Y X Hi__ Wi__ Strides LeftPads RightPads +#$DRIVER $VERIFY $INIT $REPEAT 128 192 3 3 71 71 2 2 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT 128 64 1 1 1 1 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT 256 3 7 7 230 230 2 2 0 0 0 0 + $DRIVER $VERIFY $INIT $REPEAT 256 1024 14 14 14 14 1 1 0 0 0 0 + + N=$5 + +# Resnet50 +######## verify init repeat N__ C___ Y X Hi__ Wi__ Strides LeftPads RightPads +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 1 1 14 14 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 1 1 14 14 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 1 1 14 14 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 128 3 3 28 28 1 1 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT $N 128 1 1 28 28 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 128 3 3 58 58 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 2048 1 1 7 7 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1 1 14 14 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 3 3 14 14 1 1 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 3 3 30 30 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1 1 56 56 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1 1 56 56 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1 1 56 56 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 3 3 16 16 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1 1 28 28 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1 1 28 28 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1 1 28 28 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1 1 7 7 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 3 3 7 7 1 1 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT $N 64 1 1 56 56 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 64 1 1 56 56 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 64 3 3 56 56 1 1 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT $N 3 7 7 230 230 2 2 0 0 0 0 diff --git a/script/profile_conv.sh b/script/profile_conv.sh new file mode 100755 index 00000000000..f3a6d2c70cb --- /dev/null +++ b/script/profile_conv.sh @@ -0,0 +1,177 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j ckProfiler + + DRIVER="./profiler/ckProfiler" + +OP=$1 +DATATYPE=$2 +IN_LAYOUT=$3 +WEI_LAYOUT=$4 +OUT_LAYOUT=$5 +VERIFY=$6 +INIT=$7 +LOG=$8 +REPEAT=$9 + +# test +######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 28 28 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE + + N=${10} + +# Resnet50 from Bing +######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 56 56 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 28 28 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 14 14 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 8 7 7 224 224 2 2 1 1 3 3 3 3 + + +# Resnet50 from Bing +#################### op____________________ datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 56 56 2 2 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 28 28 2 2 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 14 14 2 2 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 + + +# Resnet50 +######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 256 1 1 56 56 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 230 230 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE + +# SSD +######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 3 7 7 300 300 2 2 1 1 3 3 3 3 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 64 3 3 75 75 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 64 3 3 75 75 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 64 3 3 75 75 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 64 3 3 75 75 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 64 3 3 75 75 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 64 3 3 75 75 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 64 1 1 75 75 2 2 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 64 3 3 75 75 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 128 1 1 38 38 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 1 1 38 38 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 512 256 3 3 38 38 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 512 1 1 19 19 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 512 256 3 3 19 19 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 512 1 1 10 10 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 128 3 3 10 10 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 256 1 1 5 5 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 128 3 3 5 5 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 256 1 1 3 3 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 128 3 3 3 3 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 340 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 510 512 3 3 19 19 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 510 512 3 3 10 10 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 510 256 3 3 5 5 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 340 256 3 3 3 3 1 1 1 1 1 1 1 1 diff --git a/script/profile_gemm.sh b/script/profile_gemm.sh new file mode 100755 index 00000000000..b816c5101f5 --- /dev/null +++ b/script/profile_gemm.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +#make -j ckProfiler +DRIVER="../build/bin/ckProfiler" +echo $DRIVER +OP=$1 +DATATYPE=$2 +LAYOUT=$3 +VERIFY=$4 +INIT=$5 +LOG=$6 +REPEAT=$7 + +######## op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 256 256 256 256 256 256 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 1024 1024 1024 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 2048 2048 2048 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 4096 4096 4096 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 8192 8192 8192 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 6656 8192 8192 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3328 4096 4096 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1664 2048 2048 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 832 1024 1024 -1 -1 -1 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7040 8192 8192 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 5120 5632 4096 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2560 2816 2048 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1280 1408 1024 -1 -1 -1 diff --git a/script/profile_reduce_no_index.sh b/script/profile_reduce_no_index.sh new file mode 100755 index 00000000000..580a7ca1ee2 --- /dev/null +++ b/script/profile_reduce_no_index.sh @@ -0,0 +1,83 @@ +#!/bin/bash + +PRECISION= +##PRECISION=--half +##PRECISION=--double +##PRECISION=--int8 +##PRECISION=--bf16 + +if [ -n $PRECISION ] && [ "$PRECISION" = "--half" -o "$PRECISION" = "--bf16" ]; then + ACCTYPE="-C 1" +elif [ -n $PRECISION ] && [ "$PRECISION" = "--int8" ]; then + ACCTYPE="-C 2" +fi + + +driver="./bin/ckProfiler" + +VERIFY="-v $1" +INIT=$2 +NREPEAT=$3 + + +#### 0 - ADD, 5 - AVG, 7 - NORM2 +Operations="0 5 7" + +#### 0 - ADD, 5 - AVG, for int8, no NORM2 supported +if [ -n $PRECISION ] && [ "$PRECISION" = "--int8" ]; then + Operations=5 +fi + +## for generic validation +for op in $Operations; do + set -x + ####### datatype layout reduce dims op acctype verify init repeats + $driver reduce $PRECISION -D 64,4,280,82 -R 0,1,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 1,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0,1,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,22960 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,22960 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 4,1469440 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 4,1469440 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + set +x +done + +#### 0 - ADD, 5 - AVG, 7 - NORM2 +Operations=5 + +## for performance evaluation (resnet50 NHWC => C) +for op in $Operations; do + set -x + ####### datatype layout reduce dims op acctype verify init repeats + $driver reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + set +x +done + diff --git a/script/profile_reduce_with_index.sh b/script/profile_reduce_with_index.sh new file mode 100755 index 00000000000..d4671e39817 --- /dev/null +++ b/script/profile_reduce_with_index.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +PRECISION= +##PRECISION=--half +##PRECISION=--double +##PRECISION=--int8 +##PRECISION=--bf16 + +driver="./bin/ckProfiler" + +VERIFY="-v $1" +INIT=$2 +NREPEAT=$3 + +#### 2 - MIN, 3 - MAX, 4 - AMAX +Operations="2 4" + +## for generic validation +for op in $Operations; do + for use_idx in 0 1; do + set -x + ####### datatype layout reduce dims op use index verify init repeats + $driver reduce $PRECISION -D 64,4,280,82 -R 0,1,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 1,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0,1,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,22960 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,22960 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 4,1469440 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 4,1469440 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + set +x + done +done + +Operations=2 + +## for performance evaluation (resnet50 NHWC => C) +for op in $Operations; do + for use_idx in 0 1; do + set -x + ####### datatype layout reduce dims op use index verify init repeats + $driver reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + set +x + done +done + diff --git a/script/run.sh b/script/run.sh deleted file mode 100755 index 1ff56b22953..00000000000 --- a/script/run.sh +++ /dev/null @@ -1,137 +0,0 @@ -#!/bin/bash - -## GPU visibility - export ROCR_VISIBLE_DEVICE=0 - export GPU_DEVICE_ORDINAL=0 - - make -j conv_fwd_driver_offline -#make -j conv_bwd_driver_offline -#make -j conv_wrw_driver_offline -#make -j gemm_driver_offline - -DRIVER="./host/driver_offline/conv_fwd_driver_offline" -LAYOUT=$1 -ALGO=$2 -VERIFY=$3 -INIT=$4 -LOG=$5 -REPEAT=$6 - -#M01=$7 -#N01=$8 - - KBATCH=$7 - -######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 - -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 - -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 - -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 - -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1 - -######### layout algo verify init log repeat M___ N___ K___ -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01 - -# Resnet50 -######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1 - -# 256x128x32 c64 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH - - - -# 128x128x32 c64 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 448 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH - - -# 128x64x32 c64 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112 - -# 64x128x32 c64 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH - -# 64x64x32 c32 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 448 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 448 diff --git a/script/test_convnd_fwd.sh b/script/test_convnd_fwd.sh new file mode 100644 index 00000000000..1bd7a6b5d71 --- /dev/null +++ b/script/test_convnd_fwd.sh @@ -0,0 +1,110 @@ +#!/usr/bin/env bash + +# set -e + +DIM1=False +DIM2=True +DIM3=False +DATE=220317 +GIT_HASH=4e6dfda +LOG_DIR=${DATE}_${GIT_HASH} +SUFFIX=${GIT_HASH} + + +#-------------------------------------------------------------------------- +# Commandline arguments parsing +# like: cmd -key[--key] value +#-------------------------------------------------------------------------- + +POSITIONAL=() +while [[ $# -gt 0 ]] +do +key="$1" + +case $key in + -d1|--d1) + DIM1=True + echo DIM1: "${DIM1}" + shift # past argument + ;; + -d2|--d2) + DIM2=True + echo DIM2: "${DIM2}" + shift # past argument + ;; + -d3|--d3) + DIM3=True + echo DIM3: "${DIM3}" + shift # past argument + ;; + -all|--all) + DIM1=True + DIM2=True + DIM3=True + echo DIM1: "${DIM1}" + echo DIM2: "${DIM2}" + echo DIM3: "${DIM3}" + shift # past argument + ;; + -s|--suffix) + SUFFIX=${SUFFIX}_"$2" + echo SUFFIX: "${SUFFIX}" + shift # past argument + shift # past value + ;; + *) # unknown option + POSITIONAL+=("$1") # save it in an array for later + shift # past argument + ;; +esac +done +set -- "${POSITIONAL[@]}" # restore positional parameters + +#-------------------------------------------------------------------------- + +# NUMACTL="numactl --cpunodebind=1 --membind=1" +NUMACTL= +# ENV_CONF= +GPU=mi100 +PROF_ITER_COUNT=10000 +LOG_DIR_PATH=../log/${LOG_DIR} +set -x + +#------------------------------------------------------------------------------- +# 1D +#------------------------------------------------------------------------------- + +if [[ "${DIM1}" == "True" ]]; then + mkdir -p ${LOG_DIR_PATH} + echo ">>>>>>>> RUN test conv1d nwc <<<<<<<<<<" + CMD="./../build/bin/test_conv1d_fwd" + ${NUMACTL} ${CMD} 2>&1 \ + | tee ${LOG_DIR_PATH}/test_conv1d_fwd_nwc_${SUFFIX}_${GPU}.log + +fi + +#------------------------------------------------------------------------------- +# 2D +#------------------------------------------------------------------------------- + +if [[ "${DIM2}" == "True" ]]; then + mkdir -p ${LOG_DIR_PATH} + echo ">>>>>>>> RUN test conv2d nhwc <<<<<<<<<<" + CMD="./../build/bin/test_conv2d_fwd" + ${NUMACTL} ${CMD} 2>&1 \ + | tee ${LOG_DIR_PATH}/test_conv2d_fwd_nhwc_${SUFFIX}_${GPU}.log + +fi + +#------------------------------------------------------------------------------- +# 3D +#------------------------------------------------------------------------------- + +if [[ "${DIM3}" == "True" ]]; then + mkdir -p ${LOG_DIR_PATH} + echo ">>>>>>>> RUN test conv3d ndhwc <<<<<<<<<<" + CMD="./../build/bin/test_conv3d_fwd" + ${NUMACTL} ${CMD} 2>&1 \ + | tee ${LOG_DIR_PATH}/test_conv3d_fwd_ndhwc_${SUFFIX}_${GPU}.log + +fi diff --git a/script/test_reduce_no_index.sh b/script/test_reduce_no_index.sh new file mode 100755 index 00000000000..b9563038370 --- /dev/null +++ b/script/test_reduce_no_index.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +## The following will be used for CI + +set -x + +## for float +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 0 2 + +## for float64 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 6 2 + +## for float16 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 1 2 + +## for int8_t +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 3 2 + +## for bfloat16 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 5 2 + +set +x + diff --git a/script/test_reduce_with_index.sh b/script/test_reduce_with_index.sh new file mode 100755 index 00000000000..b0843ba6c1b --- /dev/null +++ b/script/test_reduce_with_index.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +## The following will be used for CI + +set -x + +## for float +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 2 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 3 0 2 + +## for float64 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 2 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 3 6 2 + +## for float16 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 2 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 3 1 2 + +## for int8_t +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 2 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 3 3 2 + +## for bfloat16 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 2 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 3 5 2 + +set +x + diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 00000000000..b05ec8d3287 --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,68 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/ + ${PROJECT_SOURCE_DIR}/include/ck + ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility + ${PROJECT_SOURCE_DIR}/include/ck/tensor_description + ${PROJECT_SOURCE_DIR}/include/ck/tensor + ${PROJECT_SOURCE_DIR}/include/ck/problem_transform + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/device + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/grid + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/block + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/warp + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/thread + ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element + ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor + ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance + ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance/gpu/reduce + ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu + ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu + ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility + ${PROJECT_SOURCE_DIR}/test/include + ${PROJECT_SOURCE_DIR}/profiler/include + ${PROJECT_SOURCE_DIR}/external/include/half +) + +include(googletest) + +add_custom_target(tests) + + +function(add_test_executable TEST_NAME) + message("adding test ${TEST_NAME}") + add_executable(${TEST_NAME} ${ARGN}) + add_test(NAME ${TEST_NAME} COMMAND $ ) + add_dependencies(tests ${TEST_NAME}) + add_dependencies(check ${TEST_NAME}) +endfunction(add_test_executable TEST_NAME) + +include(GoogleTest) + +function(add_gtest_executable TEST_NAME) + message("adding gtest ${TEST_NAME}") + add_executable(${TEST_NAME} ${ARGN}) + add_dependencies(tests ${TEST_NAME}) + add_dependencies(check ${TEST_NAME}) + # suppress gtest warnings + target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(${TEST_NAME} PRIVATE gtest_main) + gtest_discover_tests(${TEST_NAME}) +endfunction(add_gtest_executable TEST_NAME) + + +add_subdirectory(magic_number_division) +add_subdirectory(space_filling_curve) +add_subdirectory(conv_util) +add_subdirectory(reference_conv_fwd) +add_subdirectory(gemm) +add_subdirectory(gemm_split_k) +add_subdirectory(gemm_reduce) +add_subdirectory(batched_gemm) +add_subdirectory(batched_gemm_reduce) +add_subdirectory(grouped_gemm) +add_subdirectory(convnd_fwd) +add_subdirectory(reduce) +add_subdirectory(conv2d_bwd_weight) +add_subdirectory(convnd_bwd_data) +add_subdirectory(block_to_ctile_map) +# DONOT add client_app, that is tested via CI independently diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt new file mode 100644 index 00000000000..b70e3aae9b2 --- /dev/null +++ b/test/batched_gemm/CMakeLists.txt @@ -0,0 +1,4 @@ +add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) +target_link_libraries(test_batched_gemm_fp16 PRIVATE host_tensor) +target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance) + diff --git a/test/batched_gemm/batched_gemm_fp16.cpp b/test/batched_gemm/batched_gemm_fp16.cpp new file mode 100644 index 00000000000..c039e344d29 --- /dev/null +++ b/test/batched_gemm/batched_gemm_fp16.cpp @@ -0,0 +1,41 @@ +#include + +#include "profile_batched_gemm_impl.hpp" + +namespace { +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +} // namespace + +int main() +{ + int M = 512; + int N = 256; + int K = 128; + int BatchCount = 3; + + bool pass = true; + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, N, N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, K, N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, N, N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, K, N, BatchCount); + + std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/batched_gemm/batched_gemm_util.hpp b/test/batched_gemm/batched_gemm_util.hpp new file mode 100644 index 00000000000..0a5c471d401 --- /dev/null +++ b/test/batched_gemm/batched_gemm_util.hpp @@ -0,0 +1,106 @@ +#ifndef BATCHED_GEMM_UTILS_HPP +#define BATCHED_GEMM_UTILS_HPP + +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace batched_gemm_util { + +struct GemmParams +{ + GemmParams() + : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0) + { + } + + ck::index_t M; + ck::index_t N; + ck::index_t K; + + ck::index_t StrideA; + ck::index_t StrideB; + ck::index_t StrideC; + + float alpha; + float beta; +}; + +template +void RunHostBatchedGemm(const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + auto ref_batched_gemm = BatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = + ref_batched_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); +} + +template +void RunDeviceBatchedGemm(DeviceGemmPtr& batched_gemm_ptr, + const ck::batched_gemm_util::GemmParams& params, + const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); + DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); + DeviceMem c_g_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); + + a_g_m_k_device_buf.ToDevice(A.mData.data()); + b_g_k_n_device_buf.ToDevice(B.mData.data()); + + const auto batch_count = A.mDesc.GetLengths()[0]; + auto invoker_ptr = batched_gemm_ptr->MakeInvokerPointer(); + auto argument_ptr = batched_gemm_ptr->MakeArgumentPointer( + static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_g_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_g_m_n_device_buf.GetDeviceBuffer()), + params.M, + params.N, + params.K, + params.StrideA, + params.StrideB, + params.StrideC, + a_element_op, + b_element_op, + c_element_op, + batch_count); + + if(!batched_gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + invoker_ptr->Run(argument_ptr.get()); + c_g_m_n_device_buf.FromDevice(C.mData.data()); +} + +} // namespace batched_gemm_util +} // namespace ck +#endif diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000000..3ecf19491be --- /dev/null +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -0,0 +1,9 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/profiler/include + ${PROJECT_SOURCE_DIR}/test/include + ${PROJECT_SOURCE_DIR}/external/include/half +) + +add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) +target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE host_tensor) +target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance) diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp new file mode 100644 index 00000000000..7b311cff170 --- /dev/null +++ b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp @@ -0,0 +1,64 @@ +#include + +#include "profile_batched_gemm_reduce_impl.hpp" + +int main() +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + int M = 512; + int N = 256; + int K = 128; + + int BatchCount = 3; + + bool pass = true; + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, K, N, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, K, K, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, M, N, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, M, K, N, BatchCount); + + if(pass) + { + std::cout << "test BatchedGEMM+Reduce fp16: Pass" << std::endl; + return 0; + } + else + { + std::cout << "test BatchedGEMM+Reduce fp16: Fail" << std::endl; + return -1; + } +} diff --git a/test/block_to_ctile_map/CMakeLists.txt b/test/block_to_ctile_map/CMakeLists.txt new file mode 100644 index 00000000000..97dfbb2b552 --- /dev/null +++ b/test/block_to_ctile_map/CMakeLists.txt @@ -0,0 +1 @@ +add_gtest_executable(test_block_to_ctile_map test_block_to_ctile_map.cpp) \ No newline at end of file diff --git a/test/block_to_ctile_map/test_block_to_ctile_map.cpp b/test/block_to_ctile_map/test_block_to_ctile_map.cpp new file mode 100644 index 00000000000..662d2a0fa57 --- /dev/null +++ b/test/block_to_ctile_map/test_block_to_ctile_map.cpp @@ -0,0 +1,318 @@ +#include +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "gtest/gtest.h" +#include +#include + +using namespace ck; + +static auto I0 = Number<0>{}; +static auto I1 = Number<1>{}; +static auto I2 = Number<2>{}; + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1) +{ + const index_t M = 384; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + const index_t M01 = 4; + const index_t N01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01, + N01); + + BlockToCTileMap_M00_N00_M01_N01 tile_map( + c_grid_desc_m_n, M01, N01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {0, 1, 1}, + {0, 2, 1}, + {0, 3, 0}, + {1, 0, 1}, + {1, 1, 1}, + {1, 2, 1}, + {1, 3, 0}, + {2, 0, 1}, + {2, 1, 1}, + {2, 2, 1}, + {2, 3, 0}, + {3, 0, 0}, + {3, 1, 0}, + {3, 2, 0}, + {3, 3, 0} + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck0) +{ + const index_t M = 384; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + + const index_t M01 = 4; + const index_t N01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01, + N01); + + BlockToCTileMap_M00_N00_M01_N01 + tile_map(c_grid_desc_m_n, M01, N01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == false); +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck1) +{ + const index_t M = 384; + const index_t N = 512; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + const index_t M01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01 tile_map( + c_grid_desc_m_n, M01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {1, 0, 1}, + {2, 0, 1}, + {3, 0, 0}, + {0, 1, 1}, + {1, 1, 1}, + {2, 1, 1}, + {3, 1, 0}, + {0, 2, 1}, + {1, 2, 1}, + {2, 2, 1}, + {3, 2, 0}, + {0, 3, 1}, + {1, 3, 1}, + {2, 3, 1}, + {3, 3, 0} + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck0) +{ + const index_t M = 512; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + // clang-format off + std::vector> expected_m0_gridsize_validity = { + {5, 15, false}, + {4, 12, true}, + {3, 18, false}, + {2, 12, true}, + {1, 12, true} + }; + // clang-format on + + for(auto e : expected_m0_gridsize_validity) + { + const index_t M01 = std::get<0>(e); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01 tile_map( + c_grid_desc_m_n, M01); + + EXPECT_EQ(tile_map.CalculateGridSize(c_grid_desc_m_n), std::get<1>(e)); + EXPECT_EQ(tile_map.CheckValidity(c_grid_desc_m_n), std::get<2>(e)); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01Adapt) +{ + const index_t M = 768; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + constexpr index_t M01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01Adapt tile_map( + c_grid_desc_m_n, M01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {1, 0, 1}, + {2, 0, 1}, + {3, 0, 1}, + {0, 1, 1}, + {1, 1, 1}, + {2, 1, 1}, + {3, 1, 1}, + {0, 2, 1}, + {1, 2, 1}, + {2, 2, 1}, + {3, 2, 1}, + {4, 0, 1}, + {5, 0, 1}, + {4, 1, 1}, + {5, 1, 1}, + {4, 2, 1}, + {5, 2, 1}, + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_KSplit_M00_N0_M01Adapt) +{ + const index_t M = 768; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + constexpr index_t M01 = 4; + const index_t KSplit = 3; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_KSplit_M00_N0_M01Adapt + tile_map(c_grid_desc_m_n, M01, KSplit); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18 * KSplit); + + std::vector> expected_ksplitidx_m0idx_n0idx_valid = { + {0, 0, 0, 1}, {0, 1, 0, 1}, {0, 2, 0, 1}, {0, 3, 0, 1}, {0, 0, 1, 1}, {0, 1, 1, 1}, + {0, 2, 1, 1}, {0, 3, 1, 1}, {0, 0, 2, 1}, {0, 1, 2, 1}, {0, 2, 2, 1}, {0, 3, 2, 1}, + {0, 4, 0, 1}, {0, 5, 0, 1}, {0, 4, 1, 1}, {0, 5, 1, 1}, {0, 4, 2, 1}, {0, 5, 2, 1}, + {1, 0, 0, 1}, {1, 1, 0, 1}, {1, 2, 0, 1}, {1, 3, 0, 1}, {1, 0, 1, 1}, {1, 1, 1, 1}, + {1, 2, 1, 1}, {1, 3, 1, 1}, {1, 0, 2, 1}, {1, 1, 2, 1}, {1, 2, 2, 1}, {1, 3, 2, 1}, + {1, 4, 0, 1}, {1, 5, 0, 1}, {1, 4, 1, 1}, {1, 5, 1, 1}, {1, 4, 2, 1}, {1, 5, 2, 1}, + {2, 0, 0, 1}, {2, 1, 0, 1}, {2, 2, 0, 1}, {2, 3, 0, 1}, {2, 0, 1, 1}, {2, 1, 1, 1}, + {2, 2, 1, 1}, {2, 3, 1, 1}, {2, 0, 2, 1}, {2, 1, 2, 1}, {2, 2, 2, 1}, {2, 3, 2, 1}, + {2, 4, 0, 1}, {2, 5, 0, 1}, {2, 4, 1, 1}, {2, 5, 1, 1}, {2, 4, 2, 1}, {2, 5, 2, 1}, + }; + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto ksplitm0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", ksplit, m0, n0 = " << ksplitm0n0_idx[I0] << ", " + << ksplitm0n0_idx[I1] << ", " << ksplitm0n0_idx[I2]; + std::cout << ", valid = " + << tile_map.ValidCTileIndex(ksplitm0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_ksplitidx_m0idx_n0idx_valid[i] == + std::vector{ksplitm0n0_idx[I0], + ksplitm0n0_idx[I1], + ksplitm0n0_idx[I2], + tile_map.ValidCTileIndex(ksplitm0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} diff --git a/test/client_app/CMakeLists.txt b/test/client_app/CMakeLists.txt new file mode 100644 index 00000000000..f8dd8c4e0ad --- /dev/null +++ b/test/client_app/CMakeLists.txt @@ -0,0 +1,11 @@ +cmake_minimum_required(VERSION 3.15) +project(ck_app) +add_compile_options(-std=c++14) + +find_package(composable_kernel 1.0.0 COMPONENTS device_operations host_tensor) +find_package(hip REQUIRED PATHS /opt/rocm) +message(STATUS "Build with HIP ${hip_VERSION}") + +add_executable(test_client_app client_app.cpp) + +target_link_libraries(test_client_app PRIVATE composable_kernel::device_operations composable_kernel::host_tensor hip::host) diff --git a/test/client_app/client_app.cpp b/test/client_app/client_app.cpp new file mode 100644 index 00000000000..665a103f706 --- /dev/null +++ b/test/client_app/client_app.cpp @@ -0,0 +1,77 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "client_app_impl.hpp" + +int main(int argc, char* argv[]) +{ + if(argc != 25) + { + printf("arg1: tensor operation (conv_fwd: ForwardConvolution)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const ConvDataType data_type = static_cast(std::stoi(argv[2])); + const int in_layout = static_cast(std::stoi(argv[3])); + const int wei_layout = static_cast(std::stoi(argv[4])); + const int out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + ck::app::profile_conv_fwd_impl(do_verification, + init_method, + do_log, + time_kernel, + data_type, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + return 1; +} diff --git a/test/client_app/client_app_impl.hpp b/test/client_app/client_app_impl.hpp new file mode 100644 index 00000000000..f9e4145ba01 --- /dev/null +++ b/test/client_app/client_app_impl.hpp @@ -0,0 +1,214 @@ +#pragma once + +#include "host_interface.hpp" + +enum ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +enum ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +void check_hip_error(void) +{ + hipError_t err = hipGetLastError(); + if(err != hipSuccess) + { + std::cerr << "Error: " << hipGetErrorString(err) << std::endl; + exit(err); + } +} +std::string getDeviceName(int device) +{ + struct hipDeviceProp_t prop; + hipGetDeviceProperties(&prop, device); + check_hip_error(); + return std::string(prop.name); +} + +int getDriver(void) +{ + int driver; + hipDriverGetVersion(&driver); + check_hip_error(); + return driver; +} + +namespace ck { +namespace app { +struct DeviceMem +{ + DeviceMem() = delete; + DeviceMem(std::size_t mem_size); + void* GetDeviceBuffer(); + void ToDevice(const void* p); + void FromDevice(void* p); + ~DeviceMem(); + + void* mpDeviceBuf; + std::size_t mMemSize; +}; + +DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) +{ + hipGetErrorString(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); +} + +void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; } + +void DeviceMem::ToDevice(const void* p) +{ + hipGetErrorString( + hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); +} + +void DeviceMem::FromDevice(void* p) +{ + hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); +} + +DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); } + +void profile_conv_fwd_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ConvDataType data_type, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + const auto in_sz = N * C * Hi * Wi; + const auto wei_sz = K * C * Y * X; + const auto out_sz = N * K * Ho * Wo; + + using WeiDataType = float; + using InDataType = float; + using OutDataType = float; + + app::DeviceMem in_device_buf(sizeof(InDataType) * in_sz); + app::DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_sz); + app::DeviceMem out_device_buf(sizeof(OutDataType) * out_sz); + // data is already on device! + + // add device Conv instances + std::vector conv_ptrs; + if(data_type == F16_F16_F16) + { + add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs); + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs); + } + else if(data_type == BF16_BF16_BF16) + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(conv_ptrs); + else if(data_type == F32_F32_F32) + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(conv_ptrs); + else if(data_type == INT8_INT8_INT8) + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(conv_ptrs); + else + throw std::runtime_error("wrong! Invalid data type"); + if(conv_ptrs.empty()) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + int deviceIndex = 0; + hipSetDevice(deviceIndex); + check_hip_error(); + + StreamConfig stream_config{nullptr, time_kernel}; + hipStreamCreate(&stream_config.stream_id_); + check_hip_error(); + + // profile device Conv instances + for(auto& conv_ptr : conv_ptrs) + { + auto argument_ptr = + conv_ptr.MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = conv_ptr.MakeInvokerPointer(); + + if(conv_ptr.IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = conv_ptr.GetTypeString(); + float ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace app +} // namespace ck diff --git a/test/conv2d_bwd_data/CMakeLists.txt b/test/conv2d_bwd_data/CMakeLists.txt new file mode 100644 index 00000000000..1b5c03afa30 --- /dev/null +++ b/test/conv2d_bwd_data/CMakeLists.txt @@ -0,0 +1,3 @@ +add_test_executable(test_conv2d_bwd_data conv2d_bwd_data.cpp) +target_link_libraries(test_conv2d_bwd_data PRIVATE host_tensor) +target_link_libraries(test_conv2d_bwd_data PRIVATE device_conv2d_bwd_data_instance) diff --git a/test/conv2d_bwd_data/conv2d_bwd_data.cpp b/test/conv2d_bwd_data/conv2d_bwd_data.cpp new file mode 100644 index 00000000000..c8eb5413dcc --- /dev/null +++ b/test/conv2d_bwd_data/conv2d_bwd_data.cpp @@ -0,0 +1,327 @@ +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "device_conv_bwd_data.hpp" +#include "element_wise_operation.hpp" +#include "reference_conv_bwd_data.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ck::bhalf_t; +using INT8 = int8_t; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using DeviceConvBwdDataNoOpPtr = + DeviceConvBwdDataPtr; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector&); + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +template +static bool check_out(const Tensor& ref, const Tensor& result) +{ + float max_diff = 1e-6; + + for(int i = 0; i < ref.mData.size(); ++i) + { + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + return false; + } + } + + return true; +} + +int main(int argc, char* argv[]) +{ + int data_type = 0; + int init_method = 0; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 1) + { + data_type = 1; + init_method = 1; + } + else if(argc == 3) + { + data_type = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + } + else if(argc == 18) + { + data_type = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + + N = std::stoi(argv[3]); + K = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + conv_stride_h = std::stoi(argv[10]); + conv_stride_w = std::stoi(argv[11]); + conv_dilation_h = std::stoi(argv[12]); + conv_dilation_w = std::stoi(argv[13]); + in_left_pad_h = std::stoi(argv[14]); + in_left_pad_w = std::stoi(argv[15]); + in_right_pad_h = std::stoi(argv[16]); + in_right_pad_w = std::stoi(argv[17]); + } + else + { + printf("arg1: data type (0=fp32, 1=fp16, 2= bfp16, 3= int8_t )\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) { + using InDataType = decltype(input_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + using AccDataType = decltype(acc_type); + + using ReferenceConvBwdInstance = + ck::tensor_operation::host::ReferenceConvBwdData; + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + const std::vector input_spatial_lengths{{Hi, Wi}}; + const std::vector filter_spatial_lengths{{Y, X}}; + const std::vector output_spatial_lengths{{Ho, Wo}}; + const std::vector conv_filter_strides{{conv_stride_h, conv_stride_w}}; + const std::vector conv_filter_dilations{{conv_dilation_h, conv_dilation_w}}; + const std::vector input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::vector input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + }; + + Tensor out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo)); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X)); + Tensor in_n_c_hi_wi_host_result(f_host_tensor_descriptor(N, C, Hi, Wi)); + Tensor in_n_c_hi_wi_device_result(f_host_tensor_descriptor(N, C, Hi, Wi)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * + in_n_c_hi_wi_device_result.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + // reset input to zero + in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1{0}); + in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data()); + + // get host result + { + auto ref_conv = ReferenceConvBwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result, + wei_k_c_y_x, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device:: + DeviceConvBwdDataPtr; + + // add device Conv instances + std::vector conv_ptrs; + + if constexpr(ck::is_same_v, float> && + ck::is_same_v, float> && + ck::is_same_v, float>) + { + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, ck::bhalf_t> && + ck::is_same_v, ck::bhalf_t> && + ck::is_same_v, ck::bhalf_t>) + { + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, int8_t> && + ck::is_same_v, int8_t> && + ck::is_same_v, int8_t>) + { + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + } + + if(conv_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + // profile device Conv instances + bool success = true; + for(auto& conv_ptr : conv_ptrs) + { + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + invoker_ptr->Run(argument_ptr.get(), 1); + + in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); + + if(!check_out(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result)) + { + std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; + success = false; + } + else + { + std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; + } + } + else + { + std::cout << "Not support Info: " << conv_ptr->GetTypeString() << std::endl; + } + } + + if(success) + { + std::cout << "test conv2d bwd : Pass" << std::endl; + return 0; + } + else + { + std::cout << "test conv2d bwd: Fail " << std::endl; + return -1; + } + }; + + if(data_type == 0) + { + return Run(F32(), F32(), F32(), F32()); + } + else if(data_type == 1) + { + return Run(F16(), F16(), F16(), F32()); + } + else if(data_type == 2) + { + return Run(BF16(), BF16(), BF16(), F32()); + } + else if(data_type == 3) + { + return Run(INT8(), INT8(), INT8(), int()); + } + else + { + return 1; + } +} diff --git a/test/conv2d_bwd_weight/CMakeLists.txt b/test/conv2d_bwd_weight/CMakeLists.txt new file mode 100644 index 00000000000..ecd5336c1f3 --- /dev/null +++ b/test/conv2d_bwd_weight/CMakeLists.txt @@ -0,0 +1,7 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/profiler/include + ${PROJECT_SOURCE_DIR}/external/include/half +) + +add_test_executable(test_conv2d_bwd_weight conv2d_bwd_weight.cpp) +target_link_libraries(test_conv2d_bwd_weight PRIVATE host_tensor device_conv2d_bwd_weight_instance conv_util) diff --git a/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp b/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp new file mode 100644 index 00000000000..671980f49e4 --- /dev/null +++ b/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp @@ -0,0 +1,216 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "conv_util.hpp" +#include "profile_conv_bwd_weight_impl.hpp" + +int test_self() +{ + bool pass = true; + std::vector params; + + params.push_back({2, 128, 256, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + params.push_back({2, 128, 256, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + params.push_back({2, 128, 256, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + + for(auto& param : params) + { + // f32 + pass &= ck::profiler::profile_conv_bwd_weight_impl<2, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + + // fp16 + pass &= ck::profiler::profile_conv_bwd_weight_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + } + return pass; +} +int main(int argc, char* argv[]) +{ + int data_type = 1; + int init_method = 1; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + ck::index_t split_k = 1; + + bool pass = true; + if(argc == 1) + { + pass = test_self(); + } + else + { + if(argc == 3) + { + data_type = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + } + else if(argc == 19) + { + data_type = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + + N = std::stoi(argv[3]); + K = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + conv_stride_h = std::stoi(argv[10]); + conv_stride_w = std::stoi(argv[11]); + conv_dilation_h = std::stoi(argv[12]); + conv_dilation_w = std::stoi(argv[13]); + in_left_pad_h = std::stoi(argv[14]); + in_left_pad_w = std::stoi(argv[15]); + in_right_pad_h = std::stoi(argv[16]); + in_right_pad_w = std::stoi(argv[17]); + split_k = std::stoi(argv[18]); + } + else + { + printf("arg1: data type (0=fp32, 1=fp16, 2= bfp16, 3= int8_t )\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + ck::utils::conv::ConvParams param{2, + N, + K, + C, + {Y, X}, + {Hi, Wi}, + {conv_stride_h, conv_stride_w}, + {conv_dilation_h, conv_dilation_w}, + {in_left_pad_h, in_left_pad_w}, + {in_right_pad_h, in_right_pad_w}}; + if(data_type == 0) + { + pass = ck::profiler::profile_conv_bwd_weight_impl<2, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + init_method, + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + split_k); + } + else if(data_type == 1) + { + pass = ck::profiler::profile_conv_bwd_weight_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + init_method, + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + split_k); + } + else + { + std::cout << "Not support data type" << std::endl; + return 1; + } + } + + if(pass) + { + std::cout << "test conv2d bwd weight : Pass" << std::endl; + return 0; + } + else + { + std::cout << "test conv2d bwd weight: Fail " << std::endl; + return -1; + } +} diff --git a/test/conv_util/CMakeLists.txt b/test/conv_util/CMakeLists.txt new file mode 100644 index 00000000000..795c9ec0ac9 --- /dev/null +++ b/test/conv_util/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_conv_util conv_util.cpp) +target_link_libraries(test_conv_util PRIVATE host_tensor conv_util) diff --git a/test/conv_util/conv_util.cpp b/test/conv_util/conv_util.cpp new file mode 100644 index 00000000000..98f55b872e2 --- /dev/null +++ b/test/conv_util/conv_util.cpp @@ -0,0 +1,203 @@ +#include +#include +#include +#include + +#include "config.hpp" +#include "conv_util.hpp" +#include "tensor_layout.hpp" +#include "check_err.hpp" + +namespace { + +class TestConvUtil : public ::testing::Test +{ + public: + void SetNDParams(std::size_t ndims) + { + conv_params.num_dim_spatial_ = ndims; + conv_params.filter_spatial_lengths_ = std::vector(ndims, 3); + conv_params.input_spatial_lengths_ = std::vector(ndims, 71); + conv_params.conv_filter_strides_ = std::vector(ndims, 2); + conv_params.conv_filter_dilations_ = std::vector(ndims, 1); + conv_params.input_left_pads_ = std::vector(ndims, 1); + conv_params.input_right_pads_ = std::vector(ndims, 1); + } + + protected: + // ------- default 2D ------- + // input NCHW {128,192,71,71}, + // weights KCYX {256,192,3,3}, + // stride {2,2}, + // dilations {1,1}, + // padding {{1,1}, {1,1}} + ck::utils::conv::ConvParams conv_params; +}; + +} // namespace + +TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D) +{ + ck::utils::conv::ConvParams conv_params; + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{36, 36}, + "Error: ConvParams 2D default constructor.")); + + conv_params.conv_filter_strides_ = std::vector{1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{71, 71}, "Error: ConvParams 2D stride {1,1}.")); + + conv_params.conv_filter_strides_ = std::vector{2, 2}; + conv_params.input_left_pads_ = std::vector{2, 2}; + conv_params.input_right_pads_ = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{37, 37}, + "Error: ConvParams 2D padding left/right {2,2}.")); + + conv_params.conv_filter_dilations_ = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36, 36}, "Error: ConvParams 2D dilation {2,2}.")); + + conv_params.conv_filter_strides_ = std::vector{3, 3}; + conv_params.input_left_pads_ = std::vector{1, 1}; + conv_params.input_right_pads_ = std::vector{1, 1}; + conv_params.conv_filter_dilations_ = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE( + ck::utils::check_err(out_spatial_len, + std::vector{23, 23}, + "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}.")); +} + +TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D) +{ + SetNDParams(1); + + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36}, "Error: ConvParams 1D.")); + + conv_params.conv_filter_strides_ = std::vector{1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}.")); + + conv_params.conv_filter_strides_ = std::vector{2}; + conv_params.input_left_pads_ = std::vector{2}; + conv_params.input_right_pads_ = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{37}, + "Error: ConvParams 1D padding left/right {2}.")); + + conv_params.conv_filter_dilations_ = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36}, "Error: ConvParams 1D dilation {2}.")); + + conv_params.conv_filter_strides_ = std::vector{3}; + conv_params.input_left_pads_ = std::vector{1}; + conv_params.input_right_pads_ = std::vector{1}; + conv_params.conv_filter_dilations_ = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE( + ck::utils::check_err(out_spatial_len, + std::vector{23}, + "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}.")); +} + +TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D) +{ + SetNDParams(3); + + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D.")); + + conv_params.conv_filter_strides_ = std::vector{1, 1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{71, 71, 71}, + "Error: ConvParams 3D stride {1, 1, 1}.")); + + conv_params.conv_filter_strides_ = std::vector{2, 2, 2}; + conv_params.input_left_pads_ = std::vector{2, 2, 2}; + conv_params.input_right_pads_ = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{37, 37, 37}, + "Error: ConvParams 3D padding left/right {2, 2, 2}.")); + + conv_params.conv_filter_dilations_ = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{36, 36, 36}, + "Error: ConvParams 3D dilation {2, 2, 2}.")); + + conv_params.conv_filter_strides_ = std::vector{3, 3, 3}; + conv_params.input_left_pads_ = std::vector{1, 1, 1}; + conv_params.input_right_pads_ = std::vector{1, 1, 1}; + conv_params.conv_filter_dilations_ = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, + std::vector{23, 23, 23}, + "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}.")); +} + +TEST(ConvUtil, GetHostTensorDescriptor) +{ + namespace tl = ck::tensor_layout::convolution; + std::vector dims{2, 3, 4, 5}; + HostTensorDescriptor h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NHWC{}); + EXPECT_TRUE(ck::utils::check_err( + h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err( + h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!")); + + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCHW{}); + EXPECT_TRUE(ck::utils::check_err( + h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err( + h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!")); + + dims = std::vector{2, 3, 4}; + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NWC{}); + EXPECT_TRUE( + ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err( + h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!")); + + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCW{}); + EXPECT_TRUE( + ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err( + h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!")); + + dims = std::vector{2, 3, 4, 5, 6}; + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NDHWC{}); + EXPECT_TRUE( + ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err(h.GetStrides(), + {3 * 4 * 5 * 6, // N + 1, // C + 3 * 5 * 6, // D + 3 * 6, // H + 3}, // W + "Error: wrong NDHWC dimensions strides!")); + + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCDHW{}); + EXPECT_TRUE( + ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err(h.GetStrides(), + {3 * 4 * 5 * 6, // N + 4 * 5 * 6, // C + 5 * 6, // D + 6, // H + 1}, // W + "Error: wrong NCDHW dimensions strides!")); +} diff --git a/test/convnd_bwd_data/CMakeLists.txt b/test/convnd_bwd_data/CMakeLists.txt new file mode 100644 index 00000000000..55d71a41d32 --- /dev/null +++ b/test/convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,7 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/profiler/include + ${PROJECT_SOURCE_DIR}/external/include/half +) + +add_test_executable(test_convnd_bwd_data convnd_bwd_data.cpp) +target_link_libraries(test_convnd_bwd_data PRIVATE host_tensor device_convnd_bwd_data_instance conv_util) diff --git a/test/convnd_bwd_data/convnd_bwd_data.cpp b/test/convnd_bwd_data/convnd_bwd_data.cpp new file mode 100644 index 00000000000..7284680e0e5 --- /dev/null +++ b/test/convnd_bwd_data/convnd_bwd_data.cpp @@ -0,0 +1,330 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "profile_convnd_bwd_data_impl.hpp" + +int main() +{ + bool pass = true; + // check 1d + std::vector params; + params.push_back({1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); + params.push_back({1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); + params.push_back({1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); + + for(auto& param : params) + { + pass &= ck::profiler::profile_convnd_bwd_data_impl<1, + float, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<1, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<1, + ck::bhalf_t, + ck::bhalf_t, + ck::bhalf_t, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<1, + int8_t, + int8_t, + int8_t, + int, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + } + + // check 2d + params.clear(); + params.push_back({2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + params.push_back({2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + params.push_back({2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + + for(auto& param : params) + { + pass &= ck::profiler::profile_convnd_bwd_data_impl<2, + float, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<2, + ck::bhalf_t, + ck::bhalf_t, + ck::bhalf_t, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<2, + int8_t, + int8_t, + int8_t, + int, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + } + + // check 3d + params.clear(); + params.push_back( + {3, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + params.push_back( + {3, 128, 128, 256, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + params.push_back( + {3, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + for(auto& param : params) + { + pass &= ck::profiler::profile_convnd_bwd_data_impl<3, + float, + float, + float, + float, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<3, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<3, + ck::bhalf_t, + ck::bhalf_t, + ck::bhalf_t, + float, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<3, + int8_t, + int8_t, + int8_t, + int, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + } + + if(pass) + { + std::cout << "test convnd bwd : Pass" << std::endl; + return 0; + } + else + { + std::cout << "test convnd bwd: Fail " << std::endl; + return -1; + } +} diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt new file mode 100644 index 00000000000..34e698681b2 --- /dev/null +++ b/test/convnd_fwd/CMakeLists.txt @@ -0,0 +1,13 @@ +add_custom_target(test_convnd_fwd) + +add_gtest_executable(test_conv1d_fwd conv1d_fwd.cpp) +target_link_libraries(test_conv1d_fwd PRIVATE host_tensor device_conv1d_fwd_instance conv_util) +add_dependencies(test_convnd_fwd test_conv1d_fwd) + +add_gtest_executable(test_conv2d_fwd conv2d_fwd.cpp) +target_link_libraries(test_conv2d_fwd PRIVATE host_tensor device_conv2d_fwd_instance conv_util) +add_dependencies(test_convnd_fwd test_conv2d_fwd) + +add_gtest_executable(test_conv3d_fwd conv3d_fwd.cpp) +target_link_libraries(test_conv3d_fwd PRIVATE host_tensor device_conv3d_fwd_instance conv_util) +add_dependencies(test_convnd_fwd test_conv3d_fwd) diff --git a/test/convnd_fwd/conv1d_fwd.cpp b/test/convnd_fwd/conv1d_fwd.cpp new file mode 100644 index 00000000000..b6b6a89b2ce --- /dev/null +++ b/test/convnd_fwd/conv1d_fwd.cpp @@ -0,0 +1,93 @@ +#include +#include +#include +#include +#include "gtest/gtest.h" + +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "library/include/ck/library/utility/conv_util.hpp" +#include "conv_util.hpp" + +namespace { + +template +bool test_conv1d_nwc_instances(const std::vector& conv_ptrs) +{ + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 1; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{71}; + params.conv_filter_strides_ = std::vector{2}; + params.conv_filter_dilations_ = std::vector{1}; + params.input_left_pads_ = std::vector{1}; + params.input_right_pads_ = std::vector{1}; + + conv::ConvFwdOpInstance conv_instance(params); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + return run_engine.Test(conv_ptrs); +} + +} // anonymous namespace + +TEST(Conv1DFwdNWC, TestConv1D) +{ + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 1; + params.N_ = 2; + params.K_ = 16; + params.C_ = 4; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{16}; + params.conv_filter_strides_ = std::vector{1}; + params.conv_filter_dilations_ = std::vector{1}; + params.input_left_pads_ = std::vector{1}; + params.input_right_pads_ = std::vector{1}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<1>(conv_ptrs); + conv::ConvFwdOpInstance conv_instance( + params); + + auto reference_conv_fwd_fun = std::bind( + conv::run_reference_convolution_forward<1, float, float, float>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(1e-5); + run_engine.SetRtol(1e-4); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST(Conv1DFwdNWC, Bf16Iinstances) +{ + EXPECT_TRUE(test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<1>())); +} + +TEST(Conv1DFwdNWC, F16Instances) +{ + EXPECT_TRUE(test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<1>())); +} + +TEST(Conv1DFwdNWC, F32Instances) +{ + EXPECT_TRUE(test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<1>())); +} + +TEST(Conv1DFwdNWC, Int8Instances) +{ + EXPECT_TRUE(test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<1>())); +} diff --git a/test/convnd_fwd/conv2d_fwd.cpp b/test/convnd_fwd/conv2d_fwd.cpp new file mode 100644 index 00000000000..05e46147be1 --- /dev/null +++ b/test/convnd_fwd/conv2d_fwd.cpp @@ -0,0 +1,91 @@ +#include +#include +#include +#include +#include "gtest/gtest.h" + +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "conv_util.hpp" + +namespace { + +template +bool test_conv2d_nhwc_instances(const std::vector& conv_ptrs) +{ + using namespace std::placeholders; + using namespace ck::utils; + + conv::ConvParams params; + params.num_dim_spatial_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3}; + params.input_spatial_lengths_ = std::vector{71, 71}; + params.conv_filter_strides_ = std::vector{2, 2}; + params.conv_filter_dilations_ = std::vector{1, 1}; + params.input_left_pads_ = std::vector{1, 1}; + params.input_right_pads_ = std::vector{1, 1}; + + conv::ConvFwdOpInstance conv_instance(params); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + return run_engine.Test(conv_ptrs); +} + +} // anonymous namespace + +TEST(Conv2DFwdNHWC, TestConv2D) +{ + using namespace std::placeholders; + using namespace ck::utils; + + ck::utils::conv::ConvParams params; + params.N_ = 2; + params.K_ = 16; + params.C_ = 4; + params.input_spatial_lengths_ = std::vector{16, 16}; + params.conv_filter_strides_ = std::vector{1, 1}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<2>(conv_ptrs); + conv::ConvFwdOpInstance conv_instance(params); + + auto reference_conv_fwd_fun = std::bind( + conv::run_reference_convolution_forward<2, float, float, float>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(1e-5); + run_engine.SetRtol(1e-4); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST(Conv2DFwdNHWC, Bf16Instances) +{ + EXPECT_TRUE(test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<2>())); +} + +TEST(Conv2DFwdNHWC, F16Instances) +{ + EXPECT_TRUE(test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<2>())); +} + +TEST(Conv2DFwdNHWC, BF32Instances) +{ + EXPECT_TRUE(test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<2>())); +} + +TEST(Conv2DFwdNHWC, F32Instances) +{ + EXPECT_TRUE(test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<2>())); +} + +TEST(Conv2DFwdNHWC, Int8Instances) +{ + EXPECT_TRUE(test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<2>())); +} diff --git a/test/convnd_fwd/conv3d_fwd.cpp b/test/convnd_fwd/conv3d_fwd.cpp new file mode 100644 index 00000000000..c6f0e7ec07f --- /dev/null +++ b/test/convnd_fwd/conv3d_fwd.cpp @@ -0,0 +1,214 @@ +#include +#include +#include +#include +#include +#include "gtest/gtest.h" + +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "library/include/ck/library/utility/conv_util.hpp" +#include "conv_util.hpp" + +namespace { + +template +bool test_conv3d_ndhwc_instances(const std::vector& conv_ptrs) +{ + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + + conv::ConvParams params; + params.N_ = 64; + params.num_dim_spatial_ = 3; + params.filter_spatial_lengths_ = std::vector{3, 3, 2}; + params.input_spatial_lengths_ = std::vector{32, 32, 2}; + params.conv_filter_strides_ = std::vector{2, 2, 2}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{1, 1, 1}; + params.input_right_pads_ = std::vector{1, 1, 1}; + + conv::ConvFwdOpInstance conv_instance(params); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + return run_engine.Test(conv_ptrs); +} + +} // anonymous namespace + +TEST(Conv3DFwdNDHWC, TestConv3D) +{ + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + + conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 4; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{16, 16, 16}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{1, 1, 1}; + params.input_right_pads_ = std::vector{1, 1, 1}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); + conv::ConvFwdOpInstance conv_instance( + params); + + auto reference_conv_fwd_fun = std::bind( + conv::run_reference_convolution_forward<3, float, float, float>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(1e-5); + run_engine.SetRtol(1e-4); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST(Conv3DFwdNDHWC, InputOver2GB) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using namespace ck::utils; + + // >2GB Input + conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 32; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{32, 1000, 1000}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{1, 1, 1}; + params.input_right_pads_ = std::vector{1, 1, 1}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); + + auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, + nullptr, + nullptr, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + PassThrough{}, + PassThrough{}, + PassThrough{}); + EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); +} + +TEST(Conv3DFwdNDHWC, FiltersOver2GB) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using namespace ck::utils; + + // >2GB Filters + conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 32; + params.filter_spatial_lengths_ = std::vector{4, 1000, 1000}; + params.input_spatial_lengths_ = std::vector{16, 16, 16}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{1, 1, 1}; + params.input_right_pads_ = std::vector{1, 1, 1}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); + + auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, + nullptr, + nullptr, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + PassThrough{}, + PassThrough{}, + PassThrough{}); + EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); +} + +TEST(Conv3DFwdNDHWC, OutputOver2GB) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using namespace ck::utils; + + // >2GB Output + conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{1, 1, 1}; + params.input_spatial_lengths_ = std::vector{1000, 1000, 30}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{2, 2, 2}; + params.input_right_pads_ = std::vector{2, 2, 2}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); + auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, + nullptr, + nullptr, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + PassThrough{}, + PassThrough{}, + PassThrough{}); + EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); +} + +TEST(Conv3DFwdNDHWC, Bf16Instances) +{ + EXPECT_TRUE(test_conv3d_ndhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<3>())); +} + +TEST(Conv3DFwdNDHWC, F16Instances) +{ + EXPECT_TRUE(test_conv3d_ndhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<3>())); +} + +TEST(Conv3DFwdNDHWC, F32Instances) +{ + EXPECT_TRUE(test_conv3d_ndhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<3>())); +} + +TEST(Conv3DFwdNDHWC, Int8Instances) +{ + EXPECT_TRUE(test_conv3d_ndhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::Get<3>())); +} diff --git a/test/convnd_fwd/conv_util.hpp b/test/convnd_fwd/conv_util.hpp new file mode 100644 index 00000000000..09f641b4151 --- /dev/null +++ b/test/convnd_fwd/conv_util.hpp @@ -0,0 +1,81 @@ +#ifndef TEST_CONV_UTIL_HPP +#define TEST_CONV_UTIL_HPP + +#include + +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "sequence.hpp" + +namespace test { +namespace conv { + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using DeviceConvFwdNoOpPtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + WeiDataType, // + OutDataType, // + InDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + SpatialDims, // SptialDims + 64, // BlockSize + 16, // MPerBlock + 16, // NPerBlock + 4, // K0PerBlock + 1, // K1 + 16, // MPerXDL + 16, // NPerXDL + 1, // MXdlPerWave + 1, // NXdlPerWave + S<1, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockTransferAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector +// clang-format on + +template +void get_test_convolution_fwd_instance(std::vector& instances) +{ + using ConvInstanceT = DeviceConvNDFwdInstance; + instances.emplace_back(std::make_unique()); +} + +} // namespace conv +} // namespace test + +#endif diff --git a/test/gemm/CMakeLists.txt b/test/gemm/CMakeLists.txt new file mode 100644 index 00000000000..b8679e37157 --- /dev/null +++ b/test/gemm/CMakeLists.txt @@ -0,0 +1,29 @@ +# GEMM XDL +add_test_executable(test_gemm_xdl_fp32 gemm_xdl_fp32.cpp) +target_link_libraries(test_gemm_xdl_fp32 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_fp32 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_xdl_fp16 gemm_xdl_fp16.cpp) +target_link_libraries(test_gemm_xdl_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_fp16 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_xdl_bf16 gemm_xdl_bf16.cpp) +target_link_libraries(test_gemm_xdl_bf16 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_bf16 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_xdl_int8 gemm_xdl_int8.cpp) +target_link_libraries(test_gemm_xdl_int8 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_int8 PRIVATE device_gemm_instance) + +# GEMM DL +add_test_executable(test_gemm_dl_fp32 gemm_dl_fp32.cpp) +target_link_libraries(test_gemm_dl_fp32 PRIVATE host_tensor) +target_link_libraries(test_gemm_dl_fp32 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_dl_fp16 gemm_dl_fp16.cpp) +target_link_libraries(test_gemm_dl_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_dl_fp16 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_dl_int8 gemm_dl_int8.cpp) +target_link_libraries(test_gemm_dl_int8 PRIVATE host_tensor) +TArget_link_libraries(test_gemm_dl_int8 PRIVATE device_gemm_instance) diff --git a/test/gemm/gemm_dl_fp16.cpp b/test/gemm/gemm_dl_fp16.cpp new file mode 100644 index 00000000000..6165355ec41 --- /dev/null +++ b/test/gemm/gemm_dl_fp16.cpp @@ -0,0 +1,130 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "../gemm/gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = ck::half_t; + using BDataType = ck::half_t; + using CDataType = ck::half_t; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + + std::vector gemmPtrs; + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm/gemm_dl_fp32.cpp b/test/gemm/gemm_dl_fp32.cpp new file mode 100644 index 00000000000..cd0f8167315 --- /dev/null +++ b/test/gemm/gemm_dl_fp32.cpp @@ -0,0 +1,128 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "../gemm/gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = float; + using BDataType = float; + using CDataType = float; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm/gemm_dl_int8.cpp b/test/gemm/gemm_dl_int8.cpp new file mode 100644 index 00000000000..72b9f1440fe --- /dev/null +++ b/test/gemm/gemm_dl_int8.cpp @@ -0,0 +1,128 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "../gemm/gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = int8_t; + using BDataType = int8_t; + using CDataType = int8_t; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp new file mode 100644 index 00000000000..258ed60b08d --- /dev/null +++ b/test/gemm/gemm_util.hpp @@ -0,0 +1,345 @@ +#ifndef GEMM_UTILS_HPP +#define GEMM_UTILS_HPP + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "reference_gemm.hpp" +#include "tensor_layout.hpp" + +namespace ck { +namespace gemm_util { + +struct GemmParams +{ + GemmParams() + : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0) + { + } + + ck::index_t M; + ck::index_t N; + ck::index_t K; + + ck::index_t StrideA; + ck::index_t StrideB; + ck::index_t StrideC; + + float alpha; + float beta; +}; + +template +void RunHostGEMM(const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + auto ref_gemm = GemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); +} + +template +bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, + const ck::gemm_util::GemmParams& params, + const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); + + auto invoker_ptr = gemmPtr->MakeInvokerPointer(); + auto argument_ptr = + gemmPtr->MakeArgumentPointer(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + params.M, + params.N, + params.K, + params.StrideA, + params.StrideB, + params.StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(gemmPtr->IsSupportedArgument(argument_ptr.get())) + { + a_m_k_device_buf.ToDevice(A.mData.data()); + b_k_n_device_buf.ToDevice(B.mData.data()); + invoker_ptr->Run(argument_ptr.get()); + c_m_n_device_buf.FromDevice(C.mData.data()); + + return true; + } + else + { + std::cout << "device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; + + return false; + } +} + +template +struct TestGemm +{ + auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_k_n( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + auto f_generate_tensor_value = [](auto& tensor, auto type) { + using dataType = decltype(type); + + tensor.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }; + + f_generate_tensor_value(a_m_k, ADataType{}); + f_generate_tensor_value(b_k_n, BDataType{}); + + return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); + } + + auto operator()(DeviceGemmPtr_& gemmPtr) + { + std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name + << ", CLayout = " << CLayout{}.name << std::endl; + std::cout << gemmPtr->GetTypeString() << std::endl; + + // Arrange + ck::gemm_util::GemmParams params; + params.M = 1024; + params.N = 1024; + params.K = 1024; + params.StrideA = 1024; + params.StrideB = 1024; + params.StrideC = 1024; + + auto host_tensors = PrepareGemmTensor(params); + + const Tensor& a = std::get<0>(host_tensors); + const Tensor& b = std::get<1>(host_tensors); + Tensor& c_host = std::get<2>(host_tensors); + Tensor& c_device = std::get<3>(host_tensors); + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + ck::gemm_util::RunHostGEMM( + a, b, c_host, a_element_op, b_element_op, c_element_op); + + // Act + bool is_supported = ck::gemm_util::RunDeviceGEMM( + gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); + + if(is_supported) + { + // Assert + bool res = false; + if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + + return res; + } + else + { + return true; + } + } +}; + +template +struct TestGemmBF16 +{ + using BF16 = ck::bhalf_t; + + auto PrepareGemmTensorBF16(const ck::gemm_util::GemmParams& params) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + // use fp32 host kernel to verify bf16 device kernel + Tensor a_m_k_bf16( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_k_n_bf16( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_device_bf16( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + Tensor a_m_k_fp32( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_k_n_fp32( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_host_fp32( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_fp32( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + a_m_k_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + + bf16_to_f32_(a_m_k_bf16, a_m_k_fp32); + bf16_to_f32_(b_k_n_bf16, b_k_n_fp32); + + return std::make_tuple(a_m_k_bf16, + b_k_n_bf16, + c_m_n_device_bf16, + a_m_k_fp32, + b_k_n_fp32, + c_m_n_host_fp32, + c_m_n_device_fp32); + } + + auto operator()(DeviceGemmPtr_& gemmPtr) + { + // Arrange + ck::gemm_util::GemmParams params; + params.M = 1024; + params.N = 1024; + params.K = 1024; + params.StrideA = 1024; + params.StrideB = 1024; + params.StrideC = 1024; + + auto host_tensors = PrepareGemmTensorBF16(params); + const Tensor& a_bf16 = std::get<0>(host_tensors); + const Tensor& b_bf16 = std::get<1>(host_tensors); + Tensor& c_device_bf16 = std::get<2>(host_tensors); + Tensor& a_fp32 = std::get<3>(host_tensors); + Tensor& b_fp32 = std::get<4>(host_tensors); + Tensor& c_host_fp32 = std::get<5>(host_tensors); + Tensor& c_device_fp32 = std::get<6>(host_tensors); + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + // use fp32 host kernel to verify bf16 device kernel + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + ck::gemm_util::RunHostGEMM( + a_fp32, b_fp32, c_host_fp32, a_element_op, b_element_op, c_element_op); + + // Act + ck::gemm_util::RunDeviceGEMM(gemmPtr, + params, + a_bf16, + b_bf16, + c_device_bf16, + a_element_op, + b_element_op, + c_element_op); + + bf16_to_f32_(c_device_bf16, c_device_fp32); + + // Assert + bool res = ck::utils::check_err( + c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + + return res; + }; +}; + +} // namespace gemm_util +} // namespace ck +#endif diff --git a/test/gemm/gemm_xdl_bf16.cpp b/test/gemm/gemm_xdl_bf16.cpp new file mode 100644 index 00000000000..5461088b022 --- /dev/null +++ b/test/gemm/gemm_xdl_bf16.cpp @@ -0,0 +1,116 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector&); +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm/gemm_xdl_fp16.cpp b/test/gemm/gemm_xdl_fp16.cpp new file mode 100644 index 00000000000..aeffeafd3e3 --- /dev/null +++ b/test/gemm/gemm_xdl_fp16.cpp @@ -0,0 +1,155 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector&); + +void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = ck::half_t; + using BDataType = ck::half_t; + using CDataType = ck::half_t; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm/gemm_xdl_fp32.cpp b/test/gemm/gemm_xdl_fp32.cpp new file mode 100644 index 00000000000..10b5175c37c --- /dev/null +++ b/test/gemm/gemm_xdl_fp32.cpp @@ -0,0 +1,154 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = float; + using BDataType = float; + using CDataType = float; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm/gemm_xdl_int8.cpp b/test/gemm/gemm_xdl_int8.cpp new file mode 100644 index 00000000000..fbb1b1ac985 --- /dev/null +++ b/test/gemm/gemm_xdl_int8.cpp @@ -0,0 +1,128 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector&); +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = int8_t; + using BDataType = int8_t; + using CDataType = int8_t; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + std::vector gemmPtrs; + bool res = true; + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000000..e474af32301 --- /dev/null +++ b/test/gemm_reduce/CMakeLists.txt @@ -0,0 +1,9 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/profiler/include + ${PROJECT_SOURCE_DIR}/test/include + ${PROJECT_SOURCE_DIR}/external/include/half +) + +add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) +target_link_libraries(test_gemm_reduce_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance) diff --git a/test/gemm_reduce/gemm_reduce_fp16.cpp b/test/gemm_reduce/gemm_reduce_fp16.cpp new file mode 100644 index 00000000000..6c7bb9658fd --- /dev/null +++ b/test/gemm_reduce/gemm_reduce_fp16.cpp @@ -0,0 +1,46 @@ +#include + +#include "profile_gemm_reduce_impl.hpp" + +int main() +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + int M = 512; + int N = 256; + int K = 128; + + bool pass = true; + + pass = pass && + ck::profiler:: + profile_gemm_reduce_impl( + true, 1, false, false, M, N, K, K, N, N); + + pass = pass && + ck::profiler:: + profile_gemm_reduce_impl( + true, 1, false, false, M, N, K, K, K, N); + + pass = pass && + ck::profiler:: + profile_gemm_reduce_impl( + true, 1, false, false, M, N, K, M, N, N); + + pass = pass && + ck::profiler:: + profile_gemm_reduce_impl( + true, 1, false, false, M, N, K, M, K, N); + + if(pass) + { + std::cout << "test GEMM+Reduce fp16: Pass" << std::endl; + return 0; + } + else + { + std::cout << "test GEMM+Reduce fp16: Fail" << std::endl; + return -1; + } +} diff --git a/test/gemm_split_k/CMakeLists.txt b/test/gemm_split_k/CMakeLists.txt new file mode 100644 index 00000000000..40d422377bc --- /dev/null +++ b/test/gemm_split_k/CMakeLists.txt @@ -0,0 +1,3 @@ +add_test_executable(test_gemm_split_k gemm_split_k.cpp) +target_link_libraries(test_gemm_split_k PRIVATE host_tensor) +target_link_libraries(test_gemm_split_k PRIVATE device_gemm_instance) diff --git a/test/gemm_split_k/gemm_split_k.cpp b/test/gemm_split_k/gemm_split_k.cpp new file mode 100644 index 00000000000..b63361aa1b2 --- /dev/null +++ b/test/gemm_split_k/gemm_split_k.cpp @@ -0,0 +1,254 @@ +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "host_gemm.hpp" +#include "tensor_layout.hpp" +#include "device_gemm_xdl_splitk.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +template +static bool check_out(const Tensor& ref, const Tensor& result) +{ + float max_diff = 1e-6; + + for(std::size_t i = 0; i < ref.mData.size(); ++i) + { + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + return false; + } + } + + return true; +} + +struct gemmArgs +{ + GemmMatrixLayout layout; + int M; + int N; + int K; + int StrideA; + int StrideB; + int StrideC; + int KBatch; +}; + +int test_gemm(const gemmArgs& args) +{ + bool a_row_major, b_row_major, c_row_major; + + switch(args.layout) + { + case GemmMatrixLayout::MK_KN_MN: + a_row_major = true; + b_row_major = true; + c_row_major = true; + break; + case GemmMatrixLayout::MK_NK_MN: + a_row_major = true; + b_row_major = false; + c_row_major = true; + break; + case GemmMatrixLayout::KM_KN_MN: + a_row_major = false; + b_row_major = true; + c_row_major = true; + break; + case GemmMatrixLayout::KM_NK_MN: + a_row_major = false; + b_row_major = false; + c_row_major = true; + break; + default: printf("not supported layout"); return 1; + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, bool row_major) { + if(row_major) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major)); + Tensor b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major)); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major)); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major)); + + // init data + std::size_t num_thread = 1; + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + // set zero to c_device_buf + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + + host_gemm_mk_kn_mn(a_m_k, + b_k_n, + c_m_n_host_result, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + // add device GEMM instances + std::vector gemm_ptrs; + + if(args.layout == GemmMatrixLayout::MK_KN_MN) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + } + else if(args.layout == GemmMatrixLayout::MK_NK_MN) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + } + else if(args.layout == GemmMatrixLayout::KM_KN_MN) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + } + + bool success = false; + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + args.M, + args.N, + args.K, + args.StrideA, + args.StrideB, + args.StrideC, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + args.KBatch); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get()); + + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(!check_out(c_m_n_host_result, c_m_n_device_result)) + { + success = false; + break; + } + success = true; + } + } + auto error_code = 0; + if(success) + { + std::cout << "test split k : Pass" << std::endl; + } + else + { + std::cout << "test split k: Fail " << std::endl; + error_code = -1; // test needs to report failure + } + return error_code; +} + +int main(int argc, char* argv[]) +{ + std::vector test_cases; + if(argc == 1) + { + test_cases = {{GemmMatrixLayout::MK_KN_MN, 3, 3, 3, 3, 3, 3, 1}}; + // JD: Populate with more and meaningful + return 0; + } + else if(argc == 9) + { + const auto layout = static_cast(std::stoi(argv[1])); + + const int M = std::stoi(argv[2]); + const int N = std::stoi(argv[3]); + const int K = std::stoi(argv[4]); + + const int StrideA = std::stoi(argv[5]); + const int StrideB = std::stoi(argv[6]); + const int StrideC = std::stoi(argv[7]); + const int KBatch = std::stoi(argv[8]); + test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}}; + } + else + { + printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n"); + return -1; + } + for(const auto& kinder : test_cases) + { + const auto res = test_gemm(kinder); + if(!res) + return -1; + } + return 0; +} diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt new file mode 100644 index 00000000000..f04ee77062e --- /dev/null +++ b/test/grouped_gemm/CMakeLists.txt @@ -0,0 +1,3 @@ +add_test_executable(test_grouped_gemm_fp16 grouped_gemm_fp16.cpp) +target_link_libraries(test_grouped_gemm_fp16 PRIVATE host_tensor) +target_link_libraries(test_grouped_gemm_fp16 PRIVATE device_grouped_gemm_instance) diff --git a/test/grouped_gemm/grouped_gemm_fp16.cpp b/test/grouped_gemm/grouped_gemm_fp16.cpp new file mode 100644 index 00000000000..ef131ed8674 --- /dev/null +++ b/test/grouped_gemm/grouped_gemm_fp16.cpp @@ -0,0 +1,203 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_grouped_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGroupedGemmPtr_ = ck::tensor_operation::device::DeviceGroupedGemmPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_grouped_gemm_instance { +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +} +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace { + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) +{ + int group_count = rand() % 10 + 1; + + // GEMM shape + std::vector gemm_shapes; + std::vector p_a, p_b; + std::vector p_c; + + gemm_shapes.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + int M = 256 + 256 * (rand() % 10); + int N = 256 + 256 * (rand() % 10); + int K = 128 + 128 * (rand() % 10); + + int AStride = std::is_same::value ? K : M; + int BStride = std::is_same::value ? N : K; + int CStride = std::is_same::value ? N : M; + + gemm_shapes.push_back({M, N, K, AStride, BStride, CStride}); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + std::vector> a_tensors; + ; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + a_tensors.emplace_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{}))); + b_tensors.emplace_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{}))); + c_host_tensors.emplace_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); + c_device_tensors.emplace_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); + + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize())); + b_tensors_device.emplace_back( + std::make_unique(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize())); + c_tensors_device.emplace_back(std::make_unique( + sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize())); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); + } + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto c_element_op = PassThrough{}; + + // do GEMM + auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer(); + auto argument_ptr = groupedGemmPtr->MakeArgumentPointer( + p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); + + invoker_ptr->Run(argument_ptr.get()); + + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); + + using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + if(!groupedGemmPtr->IsSupportedArgument(argument_ptr.get())) + { + return false; + } + + ref_invoker.Run(ref_argument); + + bool res = ck::utils::check_err(c_host_tensors[i].mData, c_device_tensors[i].mData); + + std::cout << "group_id: " << i << (res ? " SUCCESS" : " FAILURE") << std::endl; + + if(!res) + return false; + } + + return true; +} + +} // anonymous namespace + +int main() +{ + std::vector groupedGemmPtrs; + ck::tensor_operation::device::device_grouped_gemm_instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(groupedGemmPtrs); + + bool res = true; + + for(auto& gemmPtr : groupedGemmPtrs) + { + res &= TestGroupedGemm(gemmPtr); + } + + std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + + return res ? 0 : 1; +} diff --git a/test/magic_number_division/CMakeLists.txt b/test/magic_number_division/CMakeLists.txt new file mode 100644 index 00000000000..c7d3f45cd42 --- /dev/null +++ b/test/magic_number_division/CMakeLists.txt @@ -0,0 +1,2 @@ +add_test_executable(test_magic_number_division magic_number_division.cpp) +target_link_libraries(test_magic_number_division PRIVATE host_tensor) diff --git a/test/magic_number_division/magic_number_division.cpp b/test/magic_number_division/magic_number_division.cpp new file mode 100644 index 00000000000..751a62be199 --- /dev/null +++ b/test/magic_number_division/magic_number_division.cpp @@ -0,0 +1,150 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "magic_division.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" + +__global__ void gpu_magic_number_division(uint32_t magic_multiplier, + uint32_t magic_shift, + const int32_t* p_dividend, + int32_t* p_result, + uint64_t num) +{ + uint64_t global_thread_num = blockDim.x * gridDim.x; + + uint64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + for(uint64_t data_id = global_thread_id; data_id < num; data_id += global_thread_num) + { + p_result[data_id] = + ck::MagicDivision::DoMagicDivision(p_dividend[data_id], magic_multiplier, magic_shift); + } +} + +__global__ void +gpu_naive_division(int32_t divisor, const int32_t* p_dividend, int32_t* p_result, uint64_t num) +{ + uint64_t global_thread_num = blockDim.x * gridDim.x; + + uint64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + for(uint64_t data_id = global_thread_id; data_id < num; data_id += global_thread_num) + { + p_result[data_id] = p_dividend[data_id] / divisor; + } +} + +__host__ void cpu_magic_number_division(uint32_t magic_multiplier, + uint32_t magic_shift, + const int32_t* p_dividend, + int32_t* p_result, + uint64_t num) +{ + for(uint64_t data_id = 0; data_id < num; ++data_id) + { + p_result[data_id] = + ck::MagicDivision::DoMagicDivision(p_dividend[data_id], magic_multiplier, magic_shift); + } +} + +int main(int, char*[]) +{ + uint64_t num_divisor = 4096; + uint64_t num_dividend = 1L << 16; + + std::vector divisors_host(num_divisor); + std::vector dividends_host(num_dividend); + + // generate divisor + for(uint64_t i = 0; i < num_divisor; ++i) + { + divisors_host[i] = i + 1; + } + + // generate dividend + for(uint64_t i = 0; i < num_divisor; ++i) + { + dividends_host[i] = i; + } + + DeviceMem dividends_dev_buf(sizeof(int32_t) * num_dividend); + DeviceMem naive_result_dev_buf(sizeof(int32_t) * num_dividend); + DeviceMem magic_result_dev_buf(sizeof(int32_t) * num_dividend); + + std::vector naive_result_host(num_dividend); + std::vector magic_result_host(num_dividend); + std::vector magic_result_host2(num_dividend); + + dividends_dev_buf.ToDevice(dividends_host.data()); + + bool pass = true; + + for(std::size_t i = 0; i < num_divisor; ++i) + { + // run naive division on GPU + gpu_naive_division<<<1024, 256>>>( + divisors_host[i], + static_cast(dividends_dev_buf.GetDeviceBuffer()), + static_cast(naive_result_dev_buf.GetDeviceBuffer()), + num_dividend); + + // calculate magic number + uint32_t magic_multiplier, magic_shift; + + ck::tie(magic_multiplier, magic_shift) = + ck::MagicDivision::CalculateMagicNumbers(divisors_host[i]); + + // run magic division on GPU + gpu_magic_number_division<<<1024, 256>>>( + magic_multiplier, + magic_shift, + static_cast(dividends_dev_buf.GetDeviceBuffer()), + static_cast(magic_result_dev_buf.GetDeviceBuffer()), + num_dividend); + + naive_result_dev_buf.FromDevice(naive_result_host.data()); + magic_result_dev_buf.FromDevice(magic_result_host.data()); + + bool res = ck::utils::check_err(magic_result_host, naive_result_host); + + if(!res) + { + pass = false; + continue; + } + + cpu_magic_number_division(magic_multiplier, + magic_shift, + dividends_host.data(), + magic_result_host2.data(), + num_dividend); + + res = ck::utils::check_err(magic_result_host2, naive_result_host); + + if(!res) + { + pass = false; + continue; + } + } + + if(pass) + { + std::cout << "test magic number division: Pass" << std::endl; + return 0; + } + else + { + std::cout << "test magic number division: Fail" << std::endl; + return -1; + } +} diff --git a/test/reduce/CMakeLists.txt b/test/reduce/CMakeLists.txt new file mode 100644 index 00000000000..4e11b049a8d --- /dev/null +++ b/test/reduce/CMakeLists.txt @@ -0,0 +1,7 @@ +add_test_executable(test_reduce_no_index reduce_no_index.cpp) +add_test_executable(test_reduce_with_index reduce_with_index.cpp) +target_link_libraries(test_reduce_no_index PRIVATE host_tensor) +target_link_libraries(test_reduce_no_index PRIVATE device_reduce_instance) +target_link_libraries(test_reduce_with_index PRIVATE host_tensor) +target_link_libraries(test_reduce_with_index PRIVATE device_reduce_instance) + diff --git a/test/reduce/reduce_no_index.cpp b/test/reduce/reduce_no_index.cpp new file mode 100644 index 00000000000..20030392b5a --- /dev/null +++ b/test/reduce/reduce_no_index.cpp @@ -0,0 +1,245 @@ +#include "getopt.h" + +#include "host_common_util.hpp" +#include "profile_reduce_impl.hpp" + +using namespace ck; + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"reduceDimensions", required_argument, nullptr, 'R'}, + {"scales", required_argument, nullptr, 'S'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class SimpleAppArgs +{ + private: + int option_index = 0; + + public: + std::vector inLengths; + std::vector reduceDims; + std::vector scales; + + int data_type; + int init_method = 1; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths " + "(only 4-d tensor supported)" + << std::endl; + std::cout << "--reduceDimensions or -R comma seperated list of dimension indexes to reduce " + "(only 1 or 3 or 4 dimensions supported)" + << std::endl; + std::cout << "--scales or -S, comma separated two float values for alpha and beta" + << std::endl; + std::cout << "Arg1 -- data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl; + std::cout << "Arg2 -- init method(0=no init, 1=single integer value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:R:S:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'R': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceDims = getTypeValuesFromString(optarg); + break; + case 'S': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + scales = getTypeValuesFromString(optarg); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 2 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + data_type = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind]); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + if(inLengths.size() != 4 || + (reduceDims.size() != 1 && reduceDims.size() != 3 && reduceDims.size() != 4)) + return (-1); + + if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) + return (-1); + + return (0); + }; +}; + +bool test_reduce_no_index(int data_type, + int init_method, + std::vector reduceDims, + std::vector inLengths, + ReduceTensorOp reduceOpId, + bool propagateNan, + float alpha, + float beta) +{ + using ck::profiler::profile_reduce_impl; + + bool result = true; + + if(data_type == 0) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + false, + alpha, + beta); + } + else if(data_type == 1) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + false, + alpha, + beta); + } + else if(data_type == 3) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + false, + alpha, + beta); + } + else if(data_type == 5) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + false, + alpha, + beta); + } + else if(data_type == 6) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + false, + alpha, + beta); + } + + return (result); +}; + +constexpr ReduceTensorOp reduceOpId = ReduceTensorOp::AVG; +constexpr bool propagateNan = false; + +int main(int argc, char* argv[]) +{ + SimpleAppArgs args; + + bool result = true; + + if(argc == 1) + { + int data_type = 1; + int init_method = 2; + std::vector inLengths{64, 4, 280, 80}; + std::vector> v_reduceDims{ + {0, 1, 2, 3}, {0, 1, 2}, {1, 2, 3}, {0, 1, 3}, {0, 2, 3}, {0}, {1}, {2}, {3}}; + + for(auto& reduceDims : v_reduceDims) + result = result && test_reduce_no_index(data_type, + init_method, + reduceDims, + inLengths, + reduceOpId, + propagateNan, + 1.0f, + 0.0f); + } + else + { + if(args.processArgs(argc, argv) < 0) + { + throw std::runtime_error( + "Invalid input arguments, test_reduce_no_index could not be executed!"); + }; + + result = test_reduce_no_index(args.data_type, + args.init_method, + args.reduceDims, + args.inLengths, + reduceOpId, + propagateNan, + args.scales[0], + args.scales[1]); + } + + std::cout << "test_reduce_no_index ..... " << (result ? "SUCCESS" : "FAILURE") << std::endl; + + return (result ? 0 : -1); +} diff --git a/test/reduce/reduce_with_index.cpp b/test/reduce/reduce_with_index.cpp new file mode 100644 index 00000000000..c1918bf3886 --- /dev/null +++ b/test/reduce/reduce_with_index.cpp @@ -0,0 +1,245 @@ +#include "getopt.h" + +#include "host_common_util.hpp" +#include "profile_reduce_impl.hpp" + +using namespace ck; + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"reduceDimensions", required_argument, nullptr, 'R'}, + {"scales", required_argument, nullptr, 'S'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class SimpleAppArgs +{ + private: + int option_index = 0; + + public: + std::vector inLengths; + std::vector reduceDims; + std::vector scales; + + int data_type; + int init_method = 1; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths " + "(only 4-d tensor supported)" + << std::endl; + std::cout << "--reduceDimensions or -R comma seperated list of dimension indexes to reduce " + "(only 1 or 3 or 4 dimensions supported)" + << std::endl; + std::cout << "--scales or -S, comma separated two float values for alpha and beta" + << std::endl; + std::cout << "Arg1 -- data type (1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl; + std::cout << "Arg2 -- init method(0=no init, 1=single integer value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:R:S:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'R': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceDims = getTypeValuesFromString(optarg); + break; + case 'S': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + scales = getTypeValuesFromString(optarg); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 2 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + data_type = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind]); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + if(inLengths.size() != 4 || + (reduceDims.size() != 1 && reduceDims.size() != 3 && reduceDims.size() != 4)) + return (-1); + + if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) + return (-1); + + return (0); + }; +}; + +bool test_reduce_with_index(int data_type, + int init_method, + std::vector reduceDims, + std::vector inLengths, + ReduceTensorOp reduceOpId, + bool propagateNan, + float alpha, + float beta) +{ + using ck::profiler::profile_reduce_impl; + + bool result = true; + + if(data_type == 0) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); + } + else if(data_type == 1) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); + } + else if(data_type == 3) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); + } + else if(data_type == 5) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); + } + else if(data_type == 6) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); + } + + return (result); +}; + +constexpr ReduceTensorOp reduceOpId = ReduceTensorOp::AMAX; +constexpr bool propagateNan = false; + +int main(int argc, char* argv[]) +{ + SimpleAppArgs args; + + bool result = true; + + if(argc == 1) + { + int data_type = 1; + int init_method = 2; + std::vector inLengths{64, 4, 280, 80}; + std::vector> v_reduceDims{ + {0, 1, 2, 3}, {0, 1, 2}, {1, 2, 3}, {0, 1, 3}, {0, 2, 3}, {0}, {1}, {2}, {3}}; + + for(auto& reduceDims : v_reduceDims) + result = result && test_reduce_with_index(data_type, + init_method, + reduceDims, + inLengths, + reduceOpId, + propagateNan, + 1.0f, + 0.0f); + } + else + { + if(args.processArgs(argc, argv) < 0) + { + throw std::runtime_error( + "Invalid input arguments, test_reduce_with_index could not be executed!"); + }; + + result = test_reduce_with_index(args.data_type, + args.init_method, + args.reduceDims, + args.inLengths, + reduceOpId, + propagateNan, + args.scales[0], + args.scales[1]); + } + + std::cout << "test_reduce_with_index ..... " << (result ? "SUCCESS" : "FAILURE") << std::endl; + + return (result ? 0 : -1); +} diff --git a/test/reference_conv_fwd/CMakeLists.txt b/test/reference_conv_fwd/CMakeLists.txt new file mode 100644 index 00000000000..04b720b169a --- /dev/null +++ b/test/reference_conv_fwd/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_reference_conv_fwd reference_conv_fwd.cpp) +target_link_libraries(test_reference_conv_fwd PRIVATE host_tensor conv_util) diff --git a/test/reference_conv_fwd/reference_conv_fwd.cpp b/test/reference_conv_fwd/reference_conv_fwd.cpp new file mode 100644 index 00000000000..69b223989fd --- /dev/null +++ b/test/reference_conv_fwd/reference_conv_fwd.cpp @@ -0,0 +1,389 @@ +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" + +#include "check_err.hpp" +#include "config.hpp" +#include "conv_util.hpp" +#include "element_wise_operation.hpp" +#include "fill.hpp" +#include "host_tensor.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" + +namespace { +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +template , + typename FillWeightsOp = ck::utils::FillConstant> +Tensor +run_reference_convolution_forward(const ck::utils::conv::ConvParams& params, + const FillInputOp& fill_input_op = FillInputOp{}, + const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f}) +{ + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(ck::utils::conv::get_host_tensor_descriptor(input_dims, InLayout{})); + Tensor weights( + ck::utils::conv::get_host_tensor_descriptor(filter_dims, WeiLayout{})); + Tensor host_output( + ck::utils::conv::get_host_tensor_descriptor(output_dims, OutLayout{})); + + fill_input_op(input.begin(), input.end()); + fill_weights_op(weights.begin(), weights.end()); + std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + return host_output; +} + +} // anonymous namespace + +TEST(ReferenceConvolutionFWD, Conv2DNHWC) +{ + ck::utils::conv::ConvParams params; + params.N_ = 1; + params.K_ = 1; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3}; + params.input_spatial_lengths_ = std::vector{6, 6}; + params.conv_filter_strides_ = std::vector{1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1}; + params.input_left_pads_ = std::vector{0, 0}; + params.input_right_pads_ = std::vector{0, 0}; + + auto out_tensor = run_reference_convolution_forward<2>(params); + std::vector ref_dims{1, 1, 4, 4}; + std::vector ref_data{130.5, + 148.5, + 166.5, + 184.5, + 238.5, + 256.5, + 274.5, + 292.5, + 346.5, + 364.5, + 382.5, + 400.5, + 454.5, + 472.5, + 490.5, + 508.5}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv2DNHWCStridesDilationsPadding) +{ + ck::utils::conv::ConvParams params; + params.N_ = 1; + params.K_ = 2; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3}; + params.input_spatial_lengths_ = std::vector{12, 12}; + params.conv_filter_strides_ = std::vector{2, 2}; + params.conv_filter_dilations_ = std::vector{2, 2}; + params.input_left_pads_ = std::vector{1, 1}; + params.input_right_pads_ = std::vector{1, 1}; + + auto out_tensor = run_reference_convolution_forward<2>(params); + std::vector ref_dims = std::vector{1, 2, 5, 5}; + std::vector ref_data{ + 210., 210., 327., 327., 351., 351., 375., 375., 399., 399., + 459., 459., 706.5, 706.5, 742.5, 742.5, 778.5, 778.5, 814.5, 814.5, + 747., 747., 1138.5, 1138.5, 1174.5, 1174.5, 1210.5, 1210.5, 1246.5, 1246.5, + 1035., 1035., 1570.5, 1570.5, 1606.5, 1606.5, 1642.5, 1642.5, 1678.5, 1678.5, + 1323., 1323., 2002.5, 2002.5, 2038.5, 2038.5, 2074.5, 2074.5, 2110.5, 2110.5}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv1DNWC) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 1; + params.N_ = 1; + params.K_ = 1; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{6}; + params.conv_filter_strides_ = std::vector{1}; + params.conv_filter_dilations_ = std::vector{1}; + params.input_left_pads_ = std::vector{0}; + params.input_right_pads_ = std::vector{0}; + + auto out_tensor = + run_reference_convolution_forward<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); + std::vector ref_dims{1, 1, 4}; + std::vector ref_data{7.5, 13.5, 19.5, 25.5}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv1DNWCStridesDilationsPadding) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 1; + params.N_ = 1; + params.K_ = 2; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{12}; + params.conv_filter_strides_ = std::vector{2}; + params.conv_filter_dilations_ = std::vector{2}; + params.input_left_pads_ = std::vector{1}; + params.input_right_pads_ = std::vector{1}; + + auto out_tensor = + run_reference_convolution_forward<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); + std::vector ref_dims{1, 2, 5}; + std::vector ref_data{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv1DNWCSameOutputSize) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 1; + params.N_ = 2; + params.K_ = 16; + params.C_ = 4; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{16}; + params.conv_filter_strides_ = std::vector{1}; + params.conv_filter_dilations_ = std::vector{1}; + params.input_left_pads_ = std::vector{1}; + params.input_right_pads_ = std::vector{1}; + + auto out_tensor2 = run_reference_convolution_forward<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + params, ck::utils::FillMonotonicSeq{0.f, 0.1f}); + + std::vector ref_dims{2, 16, 16}; + std::vector ref_data{ + 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, + 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, + 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, + 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, + 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, + 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, + 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, + 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, + 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, + 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, + 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, + 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, + 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, + 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, + 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, + 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, + 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, + 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, + 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, + 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, + 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, + 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, + 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, + 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, + 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, + 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, + 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, + 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, + 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, + 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, + 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, + 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, + 27., 27., 27., 27., 27., 27., 27., 27., + 27., 27., 27., 27., 27., 27., 27., 27., + 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, + 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, + 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, + 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, + 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, + 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, + 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, + 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, + 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, + 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, + 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, + 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, + 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, + 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, + 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, + 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, + 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, + 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, + 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, + 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, + 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, + 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, + 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, + 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, + 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, + 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, + 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, + 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, + 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, + 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor2.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor2.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv3DNCDHW) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 1; + params.K_ = 1; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{6, 6, 6}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{0, 0, 0}; + params.input_right_pads_ = std::vector{0, 0, 0}; + + auto out_tensor = run_reference_convolution_forward<3, + float, + float, + float, + ck::tensor_layout::convolution::NCDHW, + ck::tensor_layout::convolution::KCZYX, + ck::tensor_layout::convolution::NKDHW>( + params, ck::utils::FillMonotonicSeq{0.f, 0.1f}); + std::vector ref_dims{1, 1, 4, 4, 4}; + std::vector ref_data{ + 407.7, 410.40002, 413.09998, 415.80002, 423.90002, 426.6, 429.30002, 432., + 440.1, 442.80002, 445.5, 448.2, 456.30002, 459., 461.7, 464.40002, + 504.90002, 507.6, 510.30002, 513., 521.1, 523.8, 526.5, 529.2001, + 537.3, 540., 542.7001, 545.4, 553.5, 556.2001, 558.9, 561.6, + 602.10004, 604.8, 607.5, 610.2, 618.3, 621., 623.7, 626.4, + 634.5, 637.2, 639.9, 642.60004, 650.7, 653.4, 656.10004, 658.8, + 699.3, 702., 704.7, 707.4, 715.5, 718.2, 720.9, 723.60004, + 731.7, 734.4001, 737.10004, 739.8, 747.9001, 750.60004, 753.3, 756.}; + EXPECT_TRUE(ck::utils::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error [case 1]: wrong output tensor dimensions!")); + EXPECT_TRUE( + ck::utils::check_err(out_tensor.mData, ref_data, "Error [case 1]: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv3DNCDHWStridesDilations) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 1; + params.K_ = 2; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{12, 12, 12}; + params.conv_filter_strides_ = std::vector{3, 3, 3}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{0, 0, 0}; + params.input_right_pads_ = std::vector{0, 0, 0}; + + auto out_tensor = run_reference_convolution_forward<3, + float, + float, + float, + ck::tensor_layout::convolution::NCDHW, + ck::tensor_layout::convolution::KCZYX, + ck::tensor_layout::convolution::NKDHW>( + params, ck::utils::FillMonotonicSeq{0.f, 0.1f}); + std::vector ref_dims{1, 2, 4, 4, 4}; + std::vector ref_data{ + 2756.7002, 2764.7998, 2772.9001, 2781., 2853.9001, 2862., 2870.1, 2878.2002, + 2951.1, 2959.2002, 2967.2998, 2975.4001, 3048.2998, 3056.4001, 3064.5, 3072.6, + 3923.1, 3931.2, 3939.2998, 3947.4, 4020.2998, 4028.4001, 4036.5002, 4044.5999, + 4117.5, 4125.6, 4133.7, 4141.8, 4214.7, 4222.8, 4230.9004, 4239., + 5089.5, 5097.5996, 5105.7, 5113.8, 5186.7, 5194.8, 5202.9, 5211., + 5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004, + 6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4, + 6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801, + 2756.7002, 2764.7998, 2772.9001, 2781., 2853.9001, 2862., 2870.1, 2878.2002, + 2951.1, 2959.2002, 2967.2998, 2975.4001, 3048.2998, 3056.4001, 3064.5, 3072.6, + 3923.1, 3931.2, 3939.2998, 3947.4, 4020.2998, 4028.4001, 4036.5002, 4044.5999, + 4117.5, 4125.6, 4133.7, 4141.8, 4214.7, 4222.8, 4230.9004, 4239., + 5089.5, 5097.5996, 5105.7, 5113.8, 5186.7, 5194.8, 5202.9, 5211., + 5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004, + 6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4, + 6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801}; + EXPECT_TRUE(ck::utils::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error [case 2]: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mData, ref_data, "Error [case 2]: incorrect results!", 1e-4f, 1e-6f)); +} diff --git a/test/space_filling_curve/CMakeLists.txt b/test/space_filling_curve/CMakeLists.txt new file mode 100644 index 00000000000..a5272680428 --- /dev/null +++ b/test/space_filling_curve/CMakeLists.txt @@ -0,0 +1 @@ +add_test_executable(test_space_filling_curve space_filling_curve.cpp) diff --git a/test/space_filling_curve/space_filling_curve.cpp b/test/space_filling_curve/space_filling_curve.cpp new file mode 100644 index 00000000000..635d31d6830 --- /dev/null +++ b/test/space_filling_curve/space_filling_curve.cpp @@ -0,0 +1,127 @@ +#include +#include +#include +#include + +#include "tensor_space_filling_curve.hpp" + +using namespace ck; + +void traverse_using_space_filling_curve(); + +int main(int argc, char** argv) +{ + (void)argc; + (void)argv; + + traverse_using_space_filling_curve(); + + return 0; +} + +void traverse_using_space_filling_curve() +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + using TensorLengths = Sequence<16, 10, 9>; + using DimAccessOrder = Sequence<2, 0, 1>; + using ScalarsPerAccess = Sequence<4, 2, 3>; + using SpaceFillingCurve = SpaceFillingCurve; + + constexpr auto expected = make_tuple(make_tuple(0, 0, 0), + make_tuple(0, 2, 0), + make_tuple(0, 4, 0), + make_tuple(0, 6, 0), + make_tuple(0, 8, 0), + make_tuple(4, 8, 0), + make_tuple(4, 6, 0), + make_tuple(4, 4, 0), + make_tuple(4, 2, 0), + make_tuple(4, 0, 0), + make_tuple(8, 0, 0), + make_tuple(8, 2, 0), + make_tuple(8, 4, 0), + make_tuple(8, 6, 0), + make_tuple(8, 8, 0), + make_tuple(12, 8, 0), + make_tuple(12, 6, 0), + make_tuple(12, 4, 0), + make_tuple(12, 2, 0), + make_tuple(12, 0, 0), + make_tuple(12, 0, 3), + make_tuple(12, 2, 3), + make_tuple(12, 4, 3), + make_tuple(12, 6, 3), + make_tuple(12, 8, 3), + make_tuple(8, 8, 3), + make_tuple(8, 6, 3), + make_tuple(8, 4, 3), + make_tuple(8, 2, 3), + make_tuple(8, 0, 3), + make_tuple(4, 0, 3), + make_tuple(4, 2, 3), + make_tuple(4, 4, 3), + make_tuple(4, 6, 3), + make_tuple(4, 8, 3), + make_tuple(0, 8, 3), + make_tuple(0, 6, 3), + make_tuple(0, 4, 3), + make_tuple(0, 2, 3), + make_tuple(0, 0, 3), + make_tuple(0, 0, 6), + make_tuple(0, 2, 6), + make_tuple(0, 4, 6), + make_tuple(0, 6, 6), + make_tuple(0, 8, 6), + make_tuple(4, 8, 6), + make_tuple(4, 6, 6), + make_tuple(4, 4, 6), + make_tuple(4, 2, 6), + make_tuple(4, 0, 6), + make_tuple(8, 0, 6), + make_tuple(8, 2, 6), + make_tuple(8, 4, 6), + make_tuple(8, 6, 6), + make_tuple(8, 8, 6), + make_tuple(12, 8, 6), + make_tuple(12, 6, 6), + make_tuple(12, 4, 6), + make_tuple(12, 2, 6), + make_tuple(12, 0, 6)); + + constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_assert(num_access == reduce_on_sequence(TensorLengths{} / ScalarsPerAccess{}, + math::multiplies{}, + Number<1>{})); + + static_for<1, num_access, 1>{}([&](auto i) { + constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i); + + static_assert(idx_curr[I0] == expected[i][I0]); + static_assert(idx_curr[I1] == expected[i][I1]); + static_assert(idx_curr[I2] == expected[i][I2]); + + constexpr auto backward_step = SpaceFillingCurve::GetBackwardStep(i); + constexpr auto expected_step = expected[i - I1] - expected[i]; + static_assert(backward_step[I0] == expected_step[I0]); + static_assert(backward_step[I1] == expected_step[I1]); + static_assert(backward_step[I2] == expected_step[I2]); + }); + + static_for<0, num_access - 1, 1>{}([&](auto i) { + constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i); + + static_assert(idx_curr[I0] == expected[i][I0]); + static_assert(idx_curr[I1] == expected[i][I1]); + static_assert(idx_curr[I2] == expected[i][I2]); + + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(i); + constexpr auto expected_step = expected[i + I1] - expected[i]; + static_assert(forward_step[I0] == expected_step[I0]); + static_assert(forward_step[I1] == expected_step[I1]); + static_assert(forward_step[I2] == expected_step[I2]); + }); +}