diff --git a/.gitignore b/.gitignore index a0a924252..4355a9ab3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *~ *.o -build/ -*.pyc \ No newline at end of file +build*/ +*.pyc +.vscode/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index d7775e56a..2c84de00e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,6 @@ path = sample/tensorflow_bert/bert url = https://github.com/google-research/bert.git +[submodule "OpenNMT-tf"] + path = OpenNMT-tf + url = https://github.com/OpenNMT/OpenNMT-tf diff --git a/CMakeLists.txt b/CMakeLists.txt index 8be7fb0b5..00e3e2e31 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -cmake_minimum_required(VERSION 3.8 FATAL_ERROR) +cmake_minimum_required(VERSION 3.8 FATAL_ERROR) # for PyTorch extensions, version should be greater than 3.13 project(FasterTransformer LANGUAGES CXX CUDA) find_package(CUDA 10.0 REQUIRED) option(BUILD_TRT "Build in TensorRT mode" OFF) option(BUILD_TF "Build in TensorFlow mode" OFF) +option(BUILD_THE "Build in PyTorch eager mode" OFF) +option(BUILD_THS "Build in TorchScript class mode" OFF) +option(BUILD_THSOP "Build in TorchScript OP mode" OFF) + +set(CXX_STD "11" CACHE STRING "C++ standard") set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) @@ -53,6 +58,11 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM},code=\\\"s set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA") endif() +if(BUILD_THE OR BUILD_THS OR BUILD_THSOP) + string(SUBSTRING ${SM} 0 1 SM_MAJOR) + string(SUBSTRING ${SM} 1 1 SM_MINOR) + set(ENV{TORCH_CUDA_ARCH_LIST} "${SM_MAJOR}.${SM_MINOR}") +endif() message("-- Assign GPU architecture (sm=${SM})") else() @@ -65,22 +75,21 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA") +if(BUILD_THE OR BUILD_THS OR BUILD_THSOP) + set(ENV{TORCH_CUDA_ARCH_LIST} "6.0;6.1;7.0;7.5") +endif() message("-- Assign GPU architecture (sm=60,61,70,75)") endif() set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0") -set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall") - +set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall --ptxas-options=-v --resource-usage") -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD "${CXX_STD}") set(CMAKE_CXX_STANDARD_REQUIRED ON) - -if(CMAKE_CXX_STANDARD STREQUAL "11") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++11") -endif() +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3") @@ -108,6 +117,41 @@ if(BUILD_TRT) list(APPEND COMMON_LIB_DIRS ${TRT_PATH}/lib) endif() +set(PYTHON_PATH "python" CACHE STRING "Python path") +if(BUILD_THE OR BUILD_THS OR BUILD_THSOP) + execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import 
os; import torch; +print(os.path.dirname(torch.__file__),end='');" + RESULT_VARIABLE _PYTHON_SUCCESS + OUTPUT_VARIABLE TORCH_DIR) + if (NOT _PYTHON_SUCCESS MATCHES 0) + message(FATAL_ERROR "Torch config Error.") + endif() + list(APPEND CMAKE_PREFIX_PATH ${TORCH_DIR}) + find_package(Torch REQUIRED) + + execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; from distutils import sysconfig; +print(sysconfig.get_python_inc()); +print(sysconfig.get_config_var('SO'));" + RESULT_VARIABLE _PYTHON_SUCCESS + OUTPUT_VARIABLE _PYTHON_VALUES) + if (NOT _PYTHON_SUCCESS MATCHES 0) + message(FATAL_ERROR "Python config Error.") + endif() + string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES}) + string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES}) + list(GET _PYTHON_VALUES 0 PY_INCLUDE_DIR) + list(GET _PYTHON_VALUES 1 PY_SUFFIX) + list(APPEND COMMON_HEADER_DIRS ${PY_INCLUDE_DIR}) + + execute_process(COMMAND ${PYTHON_PATH} "-c" "from torch.utils import cpp_extension; print(' '.join(cpp_extension._prepare_ldflags([],True,False)),end='');" + RESULT_VARIABLE _PYTHON_SUCCESS + OUTPUT_VARIABLE TORCH_LINK) + if (NOT _PYTHON_SUCCESS MATCHES 0) + message(FATAL_ERROR "PyTorch link config Error.") + endif() +endif() + + include_directories( ${COMMON_HEADER_DIRS} ) @@ -124,10 +168,17 @@ if(BUILD_TF) add_custom_target(copy ALL COMMENT "Copying tensorflow test scripts") add_custom_command(TARGET copy POST_BUILD - COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/*.py ${PROJECT_SOURCE_DIR}/build/ - COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/utils ${PROJECT_SOURCE_DIR}/build/ -r - COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/scripts ${PROJECT_SOURCE_DIR}/build/ -r - COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow_bert ${PROJECT_SOURCE_DIR}/build/ -r + COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/ ${PROJECT_BINARY_DIR} -r + COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow_bert ${PROJECT_BINARY_DIR}/tensorflow -r ) endif() +if(BUILD_THE OR BUILD_THS OR BUILD_THSOP) + add_custom_target(copy ALL COMMENT "Copying pytorch test scripts") + add_custom_command(TARGET copy + POST_BUILD + COMMAND cp ${PROJECT_SOURCE_DIR}/sample/pytorch/ ${PROJECT_BINARY_DIR} -r + COMMAND mkdir -p ${PROJECT_BINARY_DIR}/pytorch/translation/data/ + COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/utils/translation/test.* ${PROJECT_BINARY_DIR}/pytorch/translation/data/ + ) +endif() diff --git a/README.md b/README.md index b36b49971..6e31b2f5f 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,13 @@ This repository provides a script and recipe to run the highly optimized transfo ## Table Of Contents +- [FasterTransformer](#fastertransformer) + - [Table Of Contents](#table-of-contents) - [Model overview](#model-overview) - [Configuration support matrix](#configuration-support-matrix) - [Model architecture](#model-architecture) - [Encoder](#encoder) + - [Effective Transformer](#effective-transformer) - [Decoder](#decoder) - [Decoding](#decoding) - [Decoder and Decoding](#decoder-and-decoding) @@ -16,7 +19,8 @@ This repository provides a script and recipe to run the highly optimized transfo - [Quick Start Guide](#quick-start-guide) - [Build the FasterTransformer](#build-the-fastertransformer) - [Execute the encoder demos](#execute-the-encoder-demos) - - [Execute the decoding demos](#execute-the-decoding-demos) + - [Execute the decoder/decoding demos](#execute-the-decoderdecoding-demos) + - [Translation demos](#translation-demos) - [Advanced](#advanced) - [Scripts and sample 
codes](#scripts-and-sample-codes) - [Command-line options](#command-line-options) @@ -26,12 +30,27 @@ This repository provides a script and recipe to run the highly optimized transfo - [Translation process](#translation-process) - [Performance](#performance) - [Encoder performance](#encoder-performance) - - [Decoder performance on T4](#decoder-performance-on-t4) - - [Decoding performance on T4](#decoding-performance-on-t4) - - [Decoding performance on V100](#decoding-performance-on-v100) + - [Encoder performance on T4 and TensorFlow](#encoder-performance-on-t4-and-tensorflow) + - [Encoder performance on V100 and TensorFlow](#encoder-performance-on-v100-and-tensorflow) + - [Effective Transformer performance on V100 and TensorFlow](#effective-transformer-performance-on-v100-and-tensorflow) + - [Encoder performance on T4 and PyTorch](#encoder-performance-on-t4-and-pytorch) + - [Encoder performance on V100 and PyTorch](#encoder-performance-on-v100-and-pytorch) + - [Performance on application codes of TensorFlow](#performance-on-application-codes-of-tensorflow) + - [Performance on application codes of PyTorch](#performance-on-application-codes-of-pytorch) + - [Decoder performance](#decoder-performance) + - [Decoder performance on T4 and TensorFlow](#decoder-performance-on-t4-and-tensorflow) + - [Decoder performance on V100 and TensorFlow](#decoder-performance-on-v100-and-tensorflow) + - [Decoding performance](#decoding-performance) + - [Decoding performance on T4 and TensorFlow](#decoding-performance-on-t4-and-tensorflow) + - [Decoding performance on V100 and TensorFlow](#decoding-performance-on-v100-and-tensorflow) + - [Decoder and decoding performance on T4 and PyTorch](#decoder-and-decoding-performance-on-t4-and-pytorch) + - [Decoder and decoding performance on V100 and PyTorch](#decoder-and-decoding-performance-on-v100-and-pytorch) + - [TensorFlow performance on translation](#tensorflow-performance-on-translation) + - [PyTorch performance on translation](#pytorch-performance-on-translation) - [Release notes](#release-notes) - [Changelog](#changelog) - [Known issues](#known-issues) + - [TODO](#todo) ## Model overview @@ -42,104 +61,166 @@ In FasterTransformer 1.0, we implemented a highly optimized BERT transformer lay In FasterTransformer 2.0, we have added a highly optimized decoder and decoding models based on OpenNMT-TF, an open-source library. Here, the decoder is the model that contains some transformer layers. On the other hand, decoding refers to the whole translating process, including the lookup embedding table, position encoding, a decoder and beam search. +In FasterTransformer 2.1, we add some important features. First one is the supporting on PyTorch. Recently, there are more and more PyTorch users. We hope the users of PyTorch can also use the FasterTransformer in their application and researches. The second feature is the supporting of [effective transformer](https://github.com/bytedance/effective_transformer). This idea is proposed by ByteDance. It removes the useless padding of encoder input to reduce the computing cost. Third, in addition to decoding with beam search, we also provide the decoding with sampling module. Finally, we optimize many kernels of encoder, decoder and beam search to improve the speed of FasterTransformer. + The following graph demonstrates the model architecture. -![](images/encoder-decoding.png) +![](images/encoder-decoding-2.png) -FasterTransformer is built on top of CUDA and cuBLAS, providing the C++ API and TensorFlow OP. 
Users can integrate them into TensorFlow or other inference service codes that are built in native C++. We also provide some simple sample code to demonstrate how to use the encoder, decoder and to carry out decoding in C++ and TensorFlow. +FasterTransformer is built on top of CUDA and cuBLAS, providing the C++ API and TensorFlow/PyTorch OPs. Users can integrate them into TensorFlow, PyTorch, or other inference service codes that are built in native C++. We also provide some simple sample code to demonstrate how to use the encoder, decoder and to carry out decoding in C++, TensorFlow and PyTorch. ### Configuration support matrix The following configurations are supported in the FasterTransformer encoder. - Batch size (B1): smaller or equal to 512 -- Sequence length (S): larger than 3 and smaller or equal to 1024 +- Sequence length (S): smaller or equal to 1024 - Head number (H) and size per head (N): + - 16 heads * 64 per heads - 12 heads * 64 per heads - 4 heads * 32 per heads - 8 heads * 96 per heads - Data type: FP32 and FP16 +- Any number layer (N1) if the memory is enough The following configurations are supported in the FasterTransformer decoder and decoding. - Batch size (B1) * beam width (B2): smaller than 1024 - Sequence length (S): smaller than 1024 - Head number (H): 8 and 12 - Size per head (N): 64 -- Vocabulary size (V): from 64 to 30000 +- Vocabulary size (V): from 64 to 40000 - Data type: FP32 and FP16 - -Note: For Encoder-Decoding structure, the sequence length of Encoder and Decoding must be the same. +- Any number layer (N2) if the memory is enough ### Model architecture #### Encoder -The encoder requires the following inputs: + +The arguments, inputs, and outputs of encoder: + +* Arguments: + 1. Head number (H) + 2. Size per head (N) + 3. Remove padding flag: A bool value to determine using the effective transformer or not. +* Inputs: 1. An input tensor. The shape is \[ B1, S, H x N\]. 2. An attention mask. 3. The weights of all parameters. - -The encoder will return the following outputs: + 4. Sequence id offset vector, using to compute the offset of sentence for effective transformer. +* Outputs: 1. The encoder output feature. The shape is \[ B1, S, H x N \]. +#### Effective Transformer + +Effective Transformer is proposed by [here](https://github.com/bytedance/effective_transformer). It is based on the encoder of FasterTransformer. + +The main idea is: removing the padding of sentence to prevent computing the useless tokens. This method can save lots of time when the ratio of the average sequence length of one batch and the maximum sequence length. The smaller ratio, the higher speedup. + +Using the Effective Transformer requires to add some additional kernels, the details are demonstrated in the sample codes. + +![](images/effective_transformer.png) + #### Decoder -The decoder requires the following inputs: + +The arguments, inputs, and outputs of decoder: + +* Arguments: + 1. Head number (H) + 2. size per head (N) +* Inputs: 1. The features vector obtained by looking up the embedding table, or the previous result of the decoder. The shape is \[ B1 x B2, 1, H x N \]. 2. The output of the encoder. 3. The sequence length of the source sentence. Note that the lengths should be expanded by beam width times. 4. A memory cache space to store the K, V of masked multi-head attention. The size will grow for each step. 5. A memory cache space to store the K, V of cross attention. 
Since K, V is computed by the encoder result, we only compute them in the first step, storing them into the cache, and then reuse in the other steps. 6. The weights of all parameters. - 7. In order to prevent the parallel computing of TensorFlow decoder and FasterTransformer Decoder, we put the TensorFlow result as a pseudo input in the TensorFlow OP. Otherwise, the results of FasterTransformer Decoder will incorrect. This input is useless for computing. Users can remove it when applying Decoder into a real application. - -The decoder will return the following outputs: + 7. To prevent the parallel computing of TensorFlow decoder and FasterTransformer Decoder, we put the TensorFlow result as a pseudo input in the TensorFlow OP. Otherwise, the results of FasterTransformer Decoder will incorrect. This input is useless for computing. Users can remove it when applying Decoder into a real application. +* Outputs: 1. Memory cache of masked multi-head attention. 2. Memory cache of cross attention. 3. The decoder output feature. The shape is \[ B1 x B2, 1, H x N \]. #### Decoding -Decoding refers to the whole translating process, including position encoding, embedding lookup, and a simple beam search kernel. -Decoding requires the following inputs: +Decoding refers to the whole translating process, including position encoding, embedding lookup, and beam search or sampling method to choose the token. + +The arguments, inputs, and outputs of decoding with beam search: + +* Arguments: + 1. Beam width (B2) + 2. Maximum sequence length (S) + 3. Head number (H) + 4. Size per head (N) + 5. Number of decoder layers + 6. Start id of the vocabulary + 7. End id of the vocabulary + 8. Beam search diversity rate of [simple diverse decoding](https://arxiv.org/pdf/1611.08562.pdf) +* Inputs: 1. The output of the encoder. The shape is \[ B1, memory sequence length, H x N \]. 2. The sequence length of the source sentence. Note that the lengths should be expanded by beam width times. 3. The table for embedding lookup. The shape is \[ V, H x N \]. - 4. The start id and end id for the vocabulary. - 5. The weights of all parameters. - -Decoding returns the following outputs: + 4. The weights of all parameters. + 5. Position encoding table. The shape is \[ S, H x N \]. +* Outputs: 1. The output ids. The shape is \[ B1 x B2 \]. 2. The parent ids, which are the chosen beam ids. 3. The sequence lengths of each sentence. Note that these results are required to be finalized by TensorFlow's `tf.contrib.seq2seq.gather_tree` or other progress. +The arguments, inputs, and outputs of decoding with sampling: + +* Arguments: + 1. Maximum sequence length (S) + 2. Top k value (K) + 3. Top p value (P) + 4. Head number (H) + 5. Size per head (N) + 6. Number of decoder layers + 7. Start id of the vocabulary + 8. End id of the vocabulary +* Inputs: + 1. The output of the encoder. The shape is \[ B1, memory sequence length, H x N \]. + 2. The sequence length of the source sentence. Note that the lengths should be expanded by beam width times. + 3. The table for embedding lookup. The shape is \[ V, H x N \]. + 4. The weights of all parameters. + 5. Position encoding table. The shape is \[ S, H x N \]. +* Outputs: + 1. The output ids. The shape is \[ B1 x B2 \]. + 2. The sequence lengths of each sentence. + +Note that K and P cannot be zero or non-zero value in the same time. FasterTransformer chooses the non-zero one to determine to use top k sampling or top p sampling. 
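To make the rule above concrete, below is a minimal NumPy sketch of drawing one token under either sampling mode; it only illustrates how the non-zero argument selects top k versus top p selection, and is not the CUDA sampling kernel used by FasterTransformer.

```python
import numpy as np

def sample_token(probs, top_k=0, top_p=0.0):
    """Illustrative top-k / top-p sampling: exactly one of top_k, top_p is non-zero."""
    assert (top_k > 0) != (top_p > 0.0), "set either top_k or top_p, not both"
    order = np.argsort(probs)[::-1]                 # token ids sorted by descending probability
    sorted_probs = probs[order]
    if top_k > 0:
        keep = top_k                                # top-k: keep the k most probable tokens
    else:
        # top-p: keep the smallest prefix whose cumulative probability reaches p
        keep = int(np.searchsorted(np.cumsum(sorted_probs), top_p)) + 1
    kept = sorted_probs[:keep] / sorted_probs[:keep].sum()
    return int(np.random.choice(order[:keep], p=kept))

# Example over a 5-token vocabulary
probs = np.array([0.5, 0.2, 0.15, 0.1, 0.05])
print(sample_token(probs, top_k=4))                 # sample among the 4 most probable tokens
print(sample_token(probs, top_p=0.01))              # effectively picks the most probable token
```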
+ #### Decoder and Decoding + Although the decoding process of most methods is similar, we find that there are lots of different kinds to compute the probability and implement the beam search. Therefore, if your chosen beam search algorithm is different from our implementation and it is hard for you to modify the beam search kernel, TensorFlow decoding with FasterTransformer Decoder is the recommended choice. However, the performance of the TensorFlow decoding with the FasterTransformer Decoder is worse than the performance of the FasterTransformer Decoding, especially for small batch sizes. ## Setup -The following section lists the requirements in order to use FasterTransformer. +The following section lists the requirements to use FasterTransformer. ### Requirements -- CMake >= 3.8 -- CUDA 10.1 -- Python 2.7 -- Tensorflow 1.14 -- TensorRT 5.1.5.0 +- CMake >= 3.8 for Tensorflow, CMake >= 3.13 for PyTorch +- CUDA 10.1 or newer version +- Python 3 is recommended because some features are not supported in python 2 +- Tensorflow 1.13 or 1.14 or 1.15 +- PyTorch >= 1.4.0 +- TensorRT 5 or newer version -These components are readily available within the NGC TensorFlow Docker image below, except TensorRT. +These components are readily available within the NGC TensorFlow Docker image below. Ensure you have the following components: -- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker) -- [TensorFlow 19.07-py2+](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow) NGC container -- [NVIDIA Pascal](https://www.nvidia.com/en-us/data-center/pascal-gpu-architecture/) or [Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) or [Turing](https://www.nvidia.com/en-us/geforce/turing/) based GPU +- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker) and NGC container are recommended +- [NVIDIA Pascal](https://www.nvidia.com/en-us/data-center/pascal-gpu-architecture/) or [Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) or [Turing](https://www.nvidia.com/en-us/geforce/turing/) or [Ampere](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/) based GPU For more information about how to get started with NGC containers, see the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation: + - [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html) - [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry) - [Running TensorFlow](https://docs.nvidia.com/deeplearning/frameworks/tensorflow-release-notes/running.html#running) +- [Running PyTorch](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) -For those unable to use the TensorFlow NGC container, to set up the required environment or create your own container, see the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html). +For those unable to use the NGC container, to set up the required environment or create your own container, see the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html). ## Quick Start Guide @@ -149,251 +230,487 @@ The following section shows how to use FasterTransformer on the NGC container. 1. Run the container. 
-```bash -nvidia-docker run -ti nvcr.io/nvidia/tensorflow:19.07-py2 bash -``` + You can choose the tensorflow version and python version you want. Here, we list some possible images: + + - `nvcr.io/nvidia/tensorflow:19.06-py3` contains the TensorFlow 1.13 and python 3.5. + - `nvcr.io/nvidia/tensorflow:19.07-py2` contains the TensorFlow 1.14 and python 2.7. + - `nvcr.io/nvidia/tensorflow:20.03-tf1-py3` contains the TensorFlow 1.15 and python 3.6. + - `nvcr.io/nvidia/tensorrt:20.03-py3` contains the TensorRT 7.0.0 and python 3.6. + - `nvcr.io/nvidia/pytorch:20.01-py3` contains the PyTorch 1.4.0 and python 3.6 + - `nvcr.io/nvidia/pytorch:20.03-py3` contains the PyTorch 1.5.0 and python 3.6 + + For example, running image `nvcr.io/nvidia/tensorflow:19.07-py2` by + + ```bash + nvidia-docker run -ti nvcr.io/nvidia/tensorflow:19.07-py2 bash + ``` 2. Clone the repository. -```bash -git clone https://github.com/NVIDIA/DeepLearningExamples -cd DeepLearningExamples/FasterTransformer/v2 -git submodule init -git submodule update -``` + ```bash + git clone https://github.com/NVIDIA/DeepLearningExamples + cd DeepLearningExamples/FasterTransformer/v2.1 + git submodule init + git submodule update + mkdir -p build + cd build + ``` 3. Build the project. -```bash -ln -s /usr/local/lib/python2.7/dist-packages/tensorflow/libtensorflow_framework.so.1 /usr/local/lib/python2.7/dist-packages/tensorflow/libtensorflow_framework.so -mkdir -p build -cd build -cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release .. # C++ only -cmake -DSM=xx -DCMAKE_BUILD_TYPE=Debug .. # C++ debug only -cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python2.7/dist-packages/tensorflow .. # Tensorflow mode -cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TRT=ON -DTRT_PATH=/usr/include/x86_64-linux-gnu .. # TensorRT mode -cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TRT=ON -DTRT_PATH= .. # TensorRT mode if you put TensorRT in -make -``` + 3.1 build with c++ -Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). + ```bash + cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release .. + make + ``` -Note: If you use the image we recommand, then the tensorrt related libraries are in the `/usr/include/x86_64-linux-gnu`. + Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). -### Execute the encoder demos + 3.2 build with TensorFlow -1. Generate the `gemm_config.in` file. + * `nvcr.io/nvidia/tensorflow:19.06-py3` -```bash -./bin/encoder_gemm -./bin/encoder_gemm 1 32 12 64 0 -``` + First, update the cmake to cmake 3.8 or later version, and then build the project by the following scripts. -2. Run the encoder. + ```bash + cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python3.5/dist-packages/tensorflow .. + make + ``` -a. Run the encoder in C++ by running the following scripts: - -```bash -./bin/encoder_sample -./bin/encoder_sample 1 12 32 12 64 0 -``` + Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). -b. Run the encoder in TensorFlow by running the following scripts: + * `nvcr.io/nvidia/tensorflow:19.07-py2` -```bash -python encoder_sample.py \ - --batch_size 1 \ - --seq_len 32 \ - --head_number 12 \ - --size_per_head 64 \ - --num_layer 12 \ - --data_type fp32 \ - --test_time 1 -``` + First, link the `libtensorflow_framework.so`, and then build the project by the following scripts. -c. 
Run the encoder in FP16: + ```bash + ln -s /usr/local/lib/python2.7/dist-packages/tensorflow/libtensorflow_framework.so.1 /usr/local/lib/python2.7/dist-packages/tensorflow/libtensorflow_framework.so + cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python2.7/dist-packages/tensorflow .. + make + ``` -Note that the configuration of FP32 and FP16 are different, so it is necessary to generate the configuration again. + Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). -```bash -./bin/encoder_gemm 1 32 12 64 1 -./bin/encoder_sample 1 12 32 12 64 1 -python encoder_sample.py \ - --batch_size 1 \ - --seq_len 32 \ - --head_number 12 \ - --size_per_head 64 \ - --num_layer 12 \ - --data_type fp16 \ - --test_time 1 -``` + * `nvcr.io/nvidia/tensorflow:20.03-tf1-py3` -d. Run the encoder in TensorRT by tensorrt sample. + First, link the `libtensorflow_framework.so`, and then build the project by the following scripts. -```bash -./bin/encoder_gemm 1 32 12 64 0 -./bin/transformer_trt fp16(fp32) -./bin/transformer_trt 1 12 32 12 64 fp32 -``` + ```bash + ln -s /usr/local/lib/python3.6/dist-packages/tensorflow_core/libtensorflow_framework.so.1 /usr/local/lib/python3.6/dist-packages/tensorflow_core/libtensorflow_framework.so + cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python3.6/dist-packages/tensorflow_core/ .. + make + ``` -3. Run the FasterTransformer in BERT. + Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). -The following script demonstrates how to integrate the FasterTransformer into a BERT model. This requires the repo of [BERT](https://github.com/google-research/bert). + 3.3 build with TensorRT -a. Prepare the BERT codes, Download the BERT pretrained model. + * `nvcr.io/nvidia/tensorrt:20.03-py3` -```bash -cd tensorflow_bert -git clone https://github.com/google-research/bert.git -wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip -unzip uncased_L-12_H-768_A-12.zip -``` + ```bash + cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TRT=ON -DTRT_PATH=/opt/tensorrt/ .. + make + ``` -b. Download the GLUE MRPC dataset. Note that the file `download_glue_data.py` can only executed under python3. + Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). -```bash -wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py -python download_glue_data.py --tasks MRPC -``` + 3.4 build with PyTorch -c. Finetune the pretrained model on MRPC datasets. This takes some minutes. The accuracy would be better or worse because the MRPC dataset is very small. + * `nvcr.io/nvidia/pytorch:20.03-py3` -```bash -export BERT_BASE_DIR=${PWD}/uncased_L-12_H-768_A-12 -export GLUE_DIR=${PWD}/glue_data/ - -python bert/run_classifier.py \ - --task_name=MRPC \ - --do_train=true \ - --do_eval=true \ - --data_dir=$GLUE_DIR/MRPC \ - --vocab_file=$BERT_BASE_DIR/vocab.txt \ - --bert_config_file=$BERT_BASE_DIR/bert_config.json \ - --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ - --max_seq_length=128 \ - --train_batch_size=32 \ - --learning_rate=2e-5 \ - --num_train_epochs=3.0 \ - --output_dir=mrpc_output/ -``` + ```bash + cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_THE=ON -DBUILD_THS=ON -DBUILD_THSOP=ON -DCXX_STD=14 .. 
+ make + ``` -The results would be like: -```bash -I0403 08:52:49.721482 140547349206848 estimator.py:2039] Saving dict for global step 343: eval_accuracy = 0.87009805, eval_loss = 0.44462326, global_step = 343, loss = 0.44462326 -I0403 08:52:50.128525 140547349206848 estimator.py:2099] Saving 'checkpoint_path' summary for global step 343: mrpc_output/model.ckpt-343 -I0403 08:52:50.129132 140547349206848 error_handling.py:96] evaluation_loop marked as finished -I0403 08:52:50.129281 140547349206848 run_classifier.py:923] ***** Eval results ***** -I0403 08:52:50.129338 140547349206848 run_classifier.py:925] eval_accuracy = 0.87009805 -I0403 08:52:50.129695 140547349206848 run_classifier.py:925] eval_loss = 0.44462326 -I0403 08:52:50.129786 140547349206848 run_classifier.py:925] global_step = 343 -I0403 08:52:50.129833 140547349206848 run_classifier.py:925] loss = 0.44462326 -``` + Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). (You can ignore this variable.) -d. Conver the finetuned checkpoint to FP16, check the accuracy of Fastertransformer under FP16. + `-DBUILD_THE=ON` is to build the regular PyTorch extension for eager mode. If you do not use TorchScript, please use this. It may be compatible with more PyTorch versions. -```bash -python ckpt_type_convert.py --init_checkpoint=mrpc_output/model.ckpt-343 --fp16_checkpoint=mrpc_output/fp16_model.ckpt -python run_classifier_wrap.py --floatx=float16 --task_name=MRPC --do_eval=true --data_dir=$GLUE_DIR/MRPC --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config.json --init_checkpoint=mrpc_output/fp16_model.ckpt --max_seq_length=128 --eval_batch_size=8 --output_dir=mrpc_output -``` + `-DBUILD_THS=ON` is to build the TorchScript custom class. If you want to use this custom class, please make sure that the `PyTorch >= 1.5.0`. -Because we do not generate the `gemm_config.ini` file, you can see many warning messages like: + `-DBUILD_THSOP=ON` is to build the TorchScript custom op (function type). This is only for compatibility with older PyTorch, so we only have encoder supported. -```bash -gemm_config.in is not found -loading GEMM algorithms error, using default GEMM algorithms -gemm_config.in is not found -loading GEMM algorithms error, using default GEMM algorithms! -I0403 08:55:07.053885 140260684429120 evaluation.py:275] Finished evaluation at 2020-04-03-08:55:07 -I0403 08:55:07.054126 140260684429120 estimator.py:2039] Saving dict for global step 343: eval_accuracy = 0.86764705, eval_loss = 0.45615184, global_step = 343, loss = 0.4561844 -I0403 08:55:07.422543 140260684429120 estimator.py:2099] Saving 'checkpoint_path' summary for global step 343: mrpc_output/fp16_model.ckpt -I0403 08:55:07.423089 140260684429120 error_handling.py:96] evaluation_loop marked as finished -I0403 08:55:07.423257 140260684429120 run_classifier.py:923] ***** Eval results ***** -I0403 08:55:07.423315 140260684429120 run_classifier.py:925] eval_accuracy = 0.86764705 -I0403 08:55:07.423553 140260684429120 run_classifier.py:925] eval_loss = 0.45615184 -I0403 08:55:07.423635 140260684429120 run_classifier.py:925] global_step = 343 -I0403 08:55:07.423686 140260684429120 run_classifier.py:925] loss = 0.4561844 -``` + ***You can choose one of them or all. No need to add all options.*** -This shows that we use the FasterTransformer to run the inference successfully. In this case, using FP16 to do inference will reduce the accuracy with about 0.3%. 
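As a rough sketch of how the three build flavors are consumed from the Python side, an eager-mode extension is imported like an ordinary compiled module, a TorchScript custom op is loaded through `torch.ops.load_library`, and a TorchScript custom class through `torch.classes.load_library`. The module name and library paths below are placeholders, not necessarily the artifacts produced by this build:

```python
import torch

# Eager-mode extension (-DBUILD_THE=ON): imported like any compiled Python module.
# import th_fastertransformer                                     # placeholder module name

# TorchScript custom op (-DBUILD_THSOP=ON): registers free functions under torch.ops.
torch.ops.load_library("./lib/libths_fastertransformer_op.so")    # placeholder path

# TorchScript custom class (-DBUILD_THS=ON): registers classes under torch.classes,
# which requires PyTorch >= 1.5.0.
torch.classes.load_library("./lib/libths_fastertransformer.so")   # placeholder path
```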
+ For `PyTorch == 1.4.0`, please use c++11, that is, `-DCXX_STD=11` or just ignore this variable. -e. Compare the speed of BERT of TensorFlow and FasterTransformer under both FP32 and FP16. + For `PyTorch >= 1.5.0`, please use c++14, that is, `-DCXX_STD=14`. -```bash -../bin/encoder_gemm 1 32 12 64 0 -python profile_transformer_inference.py --init_checkpoint=mrpc_output/model.ckpt-343 --tf_profile=false --output_dir=mrpc_output --profiling_output_file=time_elapsed --xla=false --floatx=float32 -../bin/encoder_gemm 1 32 12 64 1 -python profile_transformer_inference.py --init_checkpoint=mrpc_output/fp16_model.ckpt --tf_profile=false --output_dir=mrpc_output --profiling_output_file=time_elapsed --xla=false --floatx=float16 -``` +### Execute the encoder demos -The results of FP16 under V100 would be like: +1. Run FasterTransformer encoder on c++ + + ```bash + ./bin/encoder_gemm + ./bin/encoder_sample + ``` + + 1.1 Run FasterTransformer encoder under FP32 on c++ + + ```bash + ./bin/encoder_gemm 32 32 12 64 0 + ./bin/encoder_sample 32 12 32 12 64 0 0 + ``` -```bash -average time (seconds) elasped original tensorflow: 0.011663460731506347 -average time (seconds) elasped fast transformer: 0.007064676284790039 -``` + 1.2 Run FasterTransformer encoder under FP16 on c++ + + ```bash + ./bin/encoder_gemm 32 32 12 64 1 + ./bin/encoder_sample 32 12 32 12 64 1 0 + ``` -### Execute the decoding demos + 1.3 Run Effective Transformer under FP32 on c++ -1. Generate the `decoding_gemm_config.in` file. + ```bash + ./bin/encoder_gemm 32 32 12 64 0 + ./bin/encoder_sample 32 12 32 12 64 0 1 + ``` -```bash -./bin/decoding_gemm -./bin/decoding_gemm 32 4 8 64 30000 32 768 0 -``` +2. Run FasterTransformer encoder on TensorFlow -2. Run the decoder and decoding. - -a. Run the decoding in C++ by running the following script: + 2.1 Run FasterTransformer encoder under FP32 on TensorFlow -```bash -./bin/decoding_sample -./bin/decoding_sample 32 4 8 64 30000 32 6 768 0 -``` + ```bash + ./bin/encoder_gemm 32 32 12 64 0 + python tensorflow/encoder_sample.py \ + --batch_size 32 \ + --max_seq_len 32 \ + --head_number 12 \ + --size_per_head 64 \ + --num_layer 12 \ + --data_type fp32 \ + --test_time 1 + ``` -b. Run the decoder in TensorFlow by running the following script: + 2.2 Run FasterTransformer encoder under FP16 on TensorFlow -```bash -python decoder_sample.py \ - --batch_size 32 \ - --beam_width 4 \ - --max_seq_len 32 \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp32 \ - --decoder_type 2 -``` + ```bash + ./bin/encoder_gemm 32 32 12 64 1 + python tensorflow/encoder_sample.py \ + --batch_size 32 \ + --max_seq_len 32 \ + --head_number 12 \ + --size_per_head 64 \ + --num_layer 12 \ + --data_type fp16 \ + --test_time 1 + ``` -c. Run the decoding in TensorFlow by running the following script: - -```bash -python decoding_sample.py \ - --batch_size 32 \ - --beam_width 4 \ - --max_seq_len 32 \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp32 -``` + 2.3 Run Effective Transformer under FP32 on TensorFlow -3. Run the encoder and decoding at the same time. 
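The `--remove_padding`/Effective Transformer path exercised above can be pictured with a small NumPy sketch: valid tokens are compacted into one contiguous buffer and a sequence id offset vector remembers where each of them sits in the padded layout, so the encoder output can be scattered back afterwards. This is only a conceptual illustration, not the CUDA kernels used in the samples:

```python
import numpy as np

def remove_padding(padded, seq_lens):
    """padded: [batch, max_seq_len, hidden]; seq_lens: valid length of each sentence."""
    batch, max_len, hidden = padded.shape
    # Flat positions of the valid tokens inside the padded [batch * max_seq_len] layout.
    offsets = np.concatenate([np.arange(b * max_len, b * max_len + l)
                              for b, l in enumerate(seq_lens)])
    compact = padded.reshape(batch * max_len, hidden)[offsets]    # [sum(seq_lens), hidden]
    return compact, offsets

def restore_padding(compact, offsets, batch, max_len):
    hidden = compact.shape[1]
    padded = np.zeros((batch * max_len, hidden), dtype=compact.dtype)
    padded[offsets] = compact                                     # scatter back; padding stays zero
    return padded.reshape(batch, max_len, hidden)

x = np.random.rand(2, 4, 8).astype(np.float32)                    # batch 2, max_seq_len 4, hidden 8
compact, offsets = remove_padding(x, seq_lens=[4, 2])             # 6 valid rows instead of 8
y = restore_padding(compact, offsets, batch=2, max_len=4)
```

The fewer valid tokens a batch contains relative to `batch * max_seq_len`, the more work this saves, which matches the note above that a smaller average-to-maximum length ratio gives a higher speedup.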
+ ```bash + ./bin/encoder_gemm 32 32 12 64 0 + python tensorflow/encoder_sample.py \ + --batch_size 32 \ + --max_seq_len 32 \ + --head_number 12 \ + --size_per_head 64 \ + --num_layer 12 \ + --data_type fp32 \ + --test_time 1 \ + --remove_padding True + ``` -```bash -python encoder_decoding_sample.py \ - --batch_size 32 \ - --beam_width 4 \ - --max_seq_len 32 \ - --encoder_head_number 12 \ - --encoder_size_per_head 64 \ - --decoder_head_number 8 \ - --decoder_size_per_head 64 \ - --encoder_num_layer 6 \ - --decoder_num_layer 6 \ - --data_type fp32 -``` +3. Run FasterTransformer on PyTorch + + Please install HuggingFace's transformers first before run the demos by + ```bash + pip install transformers==2.5.1 + ``` + + 3.1 Run FasterTransformer encoder under FP32 on PyTorch + + ```bash + ./bin/encoder_gemm 32 32 12 64 0 + python pytorch/encoder_sample.py 32 12 32 12 64 --time + ``` + + 3.2 Run FasterTransformer encoder under FP16 on PyTorch + + ```bash + ./bin/encoder_gemm 32 32 12 64 1 + python pytorch/encoder_sample.py 32 12 32 12 64 --fp16 --time + ``` + + 3.3 Run Effective Transformer under FP32 on PyTorch + + ```bash + ./bin/encoder_gemm 32 32 12 64 0 + python pytorch/encoder_sample.py 32 12 32 12 64 --time --remove_padding + ``` + +4. Run FasterTransformer on TensorRT + + 4.1 Run FasterTransformer under FP32 on TensorRT + + ```bash + ./bin/encoder_gemm 32 32 12 64 0 + ./bin/transformer_trt 32 12 32 12 64 fp32 + ``` + + 4.2 Run FasterTransformer under FP16 on TensorRT + + ```bash + ./bin/encoder_gemm 32 32 12 64 1 + ./bin/transformer_trt 32 12 32 12 64 fp16 + ``` + +### Execute the decoder/decoding demos + +1. Run FasterTransformer decoding on c++ + + ```bash + ./bin/decoding_gemm + ./bin/decoding_beamsearch_sample + ./bin/decoding_sampling_sample + ``` + + 1.1 Run decoding under FP32 on c++ + + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 0 + ./bin/decoding_beamsearch_sample 32 4 8 64 30000 32 6 512 0 # beam search + + ./bin/decoding_gemm 32 1 8 64 30000 32 512 0 + ./bin/decoding_sampling_sample 32 4 0.0 8 64 30000 32 6 512 0 # top k sampling + ./bin/decoding_sampling_sample 32 0 0.01 8 64 30000 32 6 512 0 # top p sampling + ``` + + 1.2 Run decoding under FP16 on c++ + + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 1 + ./bin/decoding_beamsearch_sample 32 4 8 64 30000 32 6 512 1 # beam search + + ./bin/decoding_gemm 32 1 8 64 30000 32 512 1 + ./bin/decoding_sampling_sample 32 4 0.0 8 64 30000 32 6 512 1 # top k sampling + ./bin/decoding_sampling_sample 32 0 0.01 8 64 30000 32 6 512 1 # top p sampling + ``` + +2. 
Run FasterTransformer decoder/decoding on TensorFlow + + 2.1 Run FasterTransformer decoder under FP32 on TensorFlow + + 2.1.1 Verify the correctness + + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 0 + python tensorflow/decoder_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp32 \ + --decoder_type 2 + ``` + + 2.1.2 Test time of TensorFlow decoder + + ```bash + python tensorflow/decoder_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp32 \ + --decoder_type 0 \ + --test_time 1 + ``` + + 2.1.3 Test time of FasterTransformer decoder + + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 0 + python tensorflow/decoder_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp32 \ + --decoder_type 1 \ + --test_time 1 + ``` + + 2.2 Run FasterTransformer decoder under FP16 on TensorFlow + + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 1 + python tensorflow/decoder_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp16 \ + --decoder_type 2 + ``` + + 2.3 Run FasterTransformer decoding under FP32 on TensorFlow + + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 0 + python tensorflow/decoding_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp32 \ + --beam_search_diversity_rate -1.3 \ + --sampling_topk 0 \ + --sampling_topp 0.01 \ + --test_time 0123 + ``` + + 2.4 Run FasterTransformer decoding under FP16 on TensorFlow + + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 1 + python tensorflow/decoding_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp16 \ + --beam_search_diversity_rate -1.3 \ + --sampling_topk 0 \ + --sampling_topp 0.01 \ + --test_time 0123 + ``` + +3. Run FasterTransformer decoder/decoding on PyTorch + + Please install OpenNMT-py first before run the demos by + ```bash + pip install opennmt-py==1.1.1 + ``` + + 3.1 Run FasterTransformer decoder under FP32 on PyTorch + + ```bash + ./bin/decoding_gemm 8 4 8 64 31538 32 512 0 + python pytorch/decoder_sample.py 8 6 32 8 64 --time + ``` + + 3.2 Run FasterTransformer decoder under FP16 on PyTorch + + ```bash + ./bin/decoding_gemm 8 4 8 64 31538 32 512 1 + python pytorch/decoder_sample.py 8 6 32 8 64 --fp16 --time + ``` + + 3.3 Run FasterTransformer decoding under FP32 on PyTorch + + ```bash + ./bin/decoding_gemm 8 4 8 64 31538 32 512 0 + python pytorch/decoding_sample.py 8 6 32 8 64 4 31538 --time + ``` + + 3.4 Run FasterTransformer decoding under FP16 on PyTorch + + ```bash + ./bin/decoding_gemm 8 4 8 64 31538 32 512 1 + python pytorch/decoding_sample.py 8 6 32 8 64 4 31538 --fp16 --time + ``` + +### Translation demos + +1. 
Translation with FasterTransformer on TensorFlow + + 1.1 Prepare data and model + + ```bash + bash tensorflow/utils/translation/download_model_data.sh + ``` + + 1.2 Run under FP32 + + ```bash + ./bin/decoding_gemm 128 4 8 64 32001 100 512 0 + python tensorflow/translate_sample.py \ + --batch_size 128 \ + --beam_width 4 \ + --encoder_head_number 8 \ + --encoder_size_per_head 64 \ + --decoder_head_number 8 \ + --decoder_size_per_head 64 \ + --max_seq_len 32 \ + --encoder_num_layer 6 \ + --decoder_num_layer 6 \ + --data_type fp32 \ + --beam_search_diversity_rate 0.0 \ + --sampling_topk 1 \ + --sampling_topp 0.00 \ + --test_time 012345 + ``` + + 1.3 Run under FP16 + + ```bash + python tensorflow/tensorflow_bert/ckpt_type_convert.py --init_checkpoint=translation/ckpt/model.ckpt-500000 --fp16_checkpoint=translation/ckpt/fp16_model.ckpt-500000 + ./bin/decoding_gemm 128 4 8 64 32001 100 512 1 + python tensorflow/translate_sample.py \ + --batch_size 128 \ + --beam_width 4 \ + --encoder_head_number 8 \ + --encoder_size_per_head 64 \ + --decoder_head_number 8 \ + --decoder_size_per_head 64 \ + --max_seq_len 32 \ + --encoder_num_layer 6 \ + --decoder_num_layer 6 \ + --data_type fp16 \ + --beam_search_diversity_rate 0.0 \ + --sampling_topk 1 \ + --sampling_topp 0.00 \ + --test_time 012345 + ``` + +2. Translation with FasterTransformer on PyTorch + + 2.1 Prepare model and data + + ```bash + bash pytorch/scripts/download_translation_model.sh + ``` + + 2.2 Run under FP32 + + ```bash + ./bin/decoding_gemm 128 4 8 64 31538 100 512 0 + python pytorch/run_translation.py --batch_size 128 --beam_size 4 --model_type decoding_ext --data_type fp32 + ``` + + 2.3 Run under FP16 + + ```bash + ./bin/decoding_gemm 128 4 8 64 31538 100 512 1 + python pytorch/run_translation.py --batch_size 128 --beam_size 4 --model_type decoding_ext --data_type fp16 + ``` ## Advanced @@ -407,11 +724,13 @@ The following code lists the directory structure of FasterTransformer: /fastertransformer: source code of transformer |--/cuda: some CUDA kernels and multi-head attention implementation, both are compiled with cuda/cuBLAS. 
|--/tf_op: custom Tensorflow OP implementation + |--/th_op: custom PyTorch OP implementation |--/trt_plugin: TensorRT plugin implementation /sample: c++ and tensorflow transformer interface samples |--/cpp: c++ interface samples + |--/pytorch: PyTorch OP samples |--/tensorflow_bert: samples that show of how to integrate our Tensorflow OP into the open source BERT model for sentence (and sentence-pair) classification tasks (GLUE), the samples support both FP16 and FP32, see readme file within this folder more details - |--/tensorflow:TensorFlow OP samples + |--/tensorflow: TensorFlow OP samples |--/tensorRT: both FP16 and FP32 tensorRT plugin samples /tools/gemm_test: loop over all GEMM algorithms to pick the best one ``` @@ -423,11 +742,12 @@ In the root directory of FasterTransformer, the most important directories are: The `fastertransformer/` folder encapsulates all the source codes of FasterTransformer: * `tf_op/` - Contains the TensorFlow Op source files of encoder, decoder and decoding +* `th_op/` - Contains the PyTorch Op source files of encoder, decoder and decoding * `cuda/` - Contains all CUDA kernels of FasterTransformer * `bert_encoder_transformer.h` - Contains the encoder transformer layer * `open_decoder.h` - Contains the decoder transformer layer -* `beam_search_opennmt.h` - Contains the beam search progress for decoding -* `decoding_opennmt.h` - Contains the decoding progress +* `decoding_beamsearch.h` - Contains the progress of decoding with beam search +* `decoding_sampling.h` - Contains the progress of decoding with beam search The `tools/` folder contains the tools to generate the GEMM configuration of FasterTransformer for different settings: * `tools/gemm_test/encoder_gemm.cc` - Encoder GEMM config @@ -435,26 +755,33 @@ The `tools/` folder contains the tools to generate the GEMM configuration of Fas The `sample/` folder contains useful sample codes for FasterTransformer: * `sample/cpp/encoder_sample.cc` - C encoder sample codes -* `sample/cpp/decoding_sample.cc` - C decoding sample codes +* `sample/cpp/decoding_beamsearch_sample.cc` - C decoding with beam search sample codes +* `sample/cpp/decoding_sampling_sample.cc` - C decoding with sampling sample codes * `sample/tensorflow/encoder_sample.py` - TensorFlow encoder sample codes * `sample/tensorflow/decoder_sample.py` - TensorFlow decoder sample codes * `sample/tensorflow/decoding_sample.py` - TensorFlow decoding sample codes +* `sample/tensorflow/tensorflow_bert/` - TensorFlow using FasterTransformer in BERT sample codes * `sample/tensorflow/encoder_decoder_sample.py` - TensorFlow `encoder_decoder` sample codes * `sample/tensorflow/encoder_decoding_sample.py` - TensorFlow `encoder_decoding` sample codes * `sample/tensorflow/translate_sample.py` - TensorFlow translation sample codes -* `sample/tensorRT/transformer_trt.cc` - Transformer layer tensorRT sample codes +* `sample/pytorch/encoder_sample.py` - PyTorch encoder sample codes +* `sample/pytorch/decoder_sample.py` - PyTorch decoder sample codes +* `sample/pytorch/decoding_sample.py` - PyTorch decoding sample codes +* `sample/pytorch/run_glue.py` - PyTorch BERT on glue dataset sample codes +* `sample/pytorch/run_squad.py` - PyTorch BERT on squad dataset sample codes +* `sample/pytorch/run_translation.py` - PyTorch decoding for translation sample codes ### Command-line options To see the full list of available options and their descriptions, use the `-h` or `--help` command-line option with the Python file, for example: ```bash -python encoder_sample.py --help 
-python decoder_sample.py --help -python decoding_sample.py --help -python encoder_decoder_sample.py --help -python encoder_decoding_sample.py --help -python translate_sample.py --help +python tensorflow/encoder_sample.py --help +python tensorflow/decoder_sample.py --help +python tensorflow/decoding_sample.py --help +python tensorflow/encoder_decoder_sample.py --help +python tensorflow/encoder_decoding_sample.py --help +python tensorflow/translate_sample.py --help ``` ### Inference process @@ -463,414 +790,1991 @@ This subsection provides the details about how to use the encoder, the decoder a #### Encoder process -1. Generate the `gemm_config.in` file. +1. Run FasterTransformer encoder on c++ -`./bin/encoder_gemm` can generate the best GEMM configuration. The arguments of `encoder_gemm` is: + 1.1 Generate the `gemm_config.in` file -```bash -./bin/encoder_gemm -``` + `./bin/encoder_gemm` can generate the best GEMM configuration. The arguments of `encoder_gemm` is: -Assume the settings of the encoder are as follows: -- `batch_size`=1 -- `sequence_length`=32 -- `head_number`=12 -- `size_per_head`=64 -- `data_type`=FP32 + ```bash + ./bin/encoder_gemm + ``` -Then the following scripts can generate the best GEMM configuration under such settings, and record the configuration into the `gemm_config.in.in` file. + This step is necessary no matter what platform we use when we use FasterTransformer. If we do not generate the configure file, the FasterTransformer will use the default configuration and the inference speed may be slower. -```bash -./bin/encoder_gemm 1 32 12 64 0 -``` + Assume the settings of the encoder are as follows: -2. Run the encoder. + - `batch_size`=32 + - `sequence_length`=32 + - `head_number`=12 + - `size_per_head`=64 + - `data_type`=FP32 -Assume the settings are the same as above, and the encoder contains 12 transformer layers. + Then the following scripts can generate the best GEMM configuration under such settings and record the configuration into the `gemm_config.in.in` file. -a. Run the encoder in C++ by running the following scripts: - -`./bin/encoder_sample` runs the encoder in the `cpp`. The arguments of `encoder_sample` is: + ```bash + ./bin/encoder_gemm 32 32 12 64 0 + ``` -```bash -./bin/encoder_sample -``` + In the following subsection, we use the same settings and 12 transformer layers unless specified. -Then the following scripts can run the encoder under the above settings. + 1.2 Run FasterTransformer encoder under FP32 on c++ -```bash -./bin/encoder_sample 1 12 32 12 64 0 -``` + `./bin/encoder_sample` runs the encoder in the `c++`. The arguments of `encoder_sample` is: -The outputs should be similar to the following: - -```bash -Device Tesla T4 -before allocate free 14.65 GB total 14.76 GB -After allocate free 14.61 GB used 0.14 GB total 14.76 GB -[batch_size 1 seq_len 32 12 transformer layers] costs 3.08 ms -``` + ```bash + ./bin/encoder_sample + ``` -b. Run the encoder in TensorFlow by running the following scripts: + Then the following scripts can run the encoder under the above settings. -The following script demonstrates the cross check between the encoder of TensorFlow and the encoder of FasterTransformer, and the execution time of them. 
+ ```bash + ./bin/encoder_sample 32 12 32 12 64 0 0 + ``` -```bash -python encoder_sample.py \ - --batch_size 1 \ - --seq_len 32 \ - --head_number 12 \ - --size_per_head 64 \ - --num_layer 12 \ - --data_type fp32 \ - --test_time 1 -``` + The outputs should be like to the following: -The outputs should be similar to the following: + ```bash + Device Tesla V100-PCIE-32GB + before allocate free 29.46 GB total 31.75 GB + After allocate free 29.41 GB used 2.34 GB total 31.75 GB + [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 16.51 ms + ``` -```bash -[INFO] Encoder Cross check True -[INFO] Max diff 3.57627868652e-06 -[INFO] min diff 0.0 -[INFO] TF decoder time costs: 6.63149 ms -[INFO] OP decoder time costs: 4.64135 ms -``` + 1.3 Run FasterTransformer encoder under FP16 on c++ -c. Run the encoder in FP16: + So far, we use the FP32 to run the FasterTransformer. If we use the volta or newer NVIDIA gpu, we can use tensor core when we use the FP16. -Note that the configuration of FP32 and FP16 are different, so it is necessary to generate the configuration again. + To use the FP16, we only need to set the `` flag to 1 like following: -For C, users only need to set the `` flag as 1. + ```bash + ./bin/encoder_gemm 32 32 12 64 1 + ./bin/encoder_sample 32 12 32 12 64 1 0 + ``` -For TensorFlow, users can use the arguments `--data_type fp16` to change the computing mode. + Note that the configuration of FP32 and FP16 are different, so we need to generate the configuration again. -```bash -./bin/encoder_gemm 1 32 12 64 1 -./bin/encoder_sample 1 12 32 12 64 1 -python encoder_sample.py \ - --batch_size 1 \ - --seq_len 32 \ - --head_number 12 \ - --size_per_head 64 \ - --num_layer 12 \ - --data_type fp16 \ - --test_time 1 -``` + The outputs should be like to the following: -#### Decoder and decoding process + ```bash + Device Tesla V100-PCIE-32GB + before allocate free 29.46 GB total 31.75 GB + After allocate free 29.43 GB used 2.32 GB total 31.75 GB + [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 4.00 ms + ``` -1. Generate the `decoding_gemm_config.in` file. + 1.4 Run Effective Transformer on c++ -`./bin/decoding_gemm` can generate the best GEMM configuration. The arguments of `decoding_gemm` are: + To use the effective transformer, we only need to set the `` flag to 1 like following: -```bash -./bin/decoding_gemm -``` + ```bash + ./bin/encoder_gemm 32 32 12 64 0 + ./bin/encoder_sample 32 12 32 12 64 0 1 + ``` -Assume the settings of decoding are as follows. + The outputs should be like to the following: -- `batch_size`=32 -- `beam_width`=4 -- `head_number`=8 -- `size_per_head`=64 -- `vocabulary_size`=30000 -- `sequence_length`=32 -- `encoder's hidden dimension`=768 -- `data_type`=FP32 + ```bash + Device Tesla V100-PCIE-32GB + before allocate free 29.46 GB total 31.75 GB + After allocate free 29.40 GB used 2.35 GB total 31.75 GB + [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 9.77 ms + ``` -Then the following scripts can generate the best GEMM configuration under such settings, and record the configuration into the `decoding_gemm_config.in` file. +2. Run FasterTransformer on TensorFlow -```bash -./bin/decoding_gemm 32 4 8 64 30000 32 768 0 -``` + 2.1 Run FasterTransformer encoder under FP32 on TensorFlow -2. Run the decoder and decoding. 
+ ```bash + ./bin/encoder_gemm 32 32 12 64 0 + python tensorflow/encoder_sample.py \ + --batch_size 32 \ + --max_seq_len 32 \ + --head_number 12 \ + --size_per_head 64 \ + --num_layer 12 \ + --data_type fp32 \ + --test_time 1 + ``` + + The outputs should be like to the following: + + ```bash + [INFO] Encoder TF v.s. FT with tensor input Cross check True + [INFO] Max diff 5.4836273193359375e-06 + [INFO] min diff 0.0 + [INFO] batch_size 32 max_seq_len 32 12 layer TF-time 20.01 ms + [INFO] batch_size 32 max_seq_len 32 12 layer FT-OP-tensor-time 18.42 ms + ``` + + 2.2 Run FasterTransformer encoder under FP16 on TensorFlow + + To use the FP16 in TensorFlow, we only need to set the `--data_type fp16` like following: + + ```bash + ./bin/encoder_gemm 32 32 12 64 1 + python tensorflow/encoder_sample.py \ + --batch_size 32 \ + --max_seq_len 32 \ + --head_number 12 \ + --size_per_head 64 \ + --num_layer 12 \ + --data_type fp16 \ + --test_time 1 + ``` + + The outputs should be like to the following: + + ```bash + [INFO] Encoder TF v.s. FT with tensor input Cross check True + [INFO] Max diff 0.0234375 + [INFO] min diff 0.0 + [INFO] batch_size 32 max_seq_len 32 12 layer TF-time 8.19 ms + [INFO] batch_size 32 max_seq_len 32 12 layer FT-OP-tensor-time 6.22 ms + ``` + + 2.3 Run Effective Transformer on TensorFlow + + To use the Effective Transformer in TensorFlow, we only need to set the `--remove_padding True` like following: + + ```bash + ./bin/encoder_gemm 32 32 12 64 0 + python tensorflow/encoder_sample.py \ + --batch_size 32 \ + --max_seq_len 32 \ + --head_number 12 \ + --size_per_head 64 \ + --num_layer 12 \ + --data_type fp32 \ + --test_time 1 \ + --remove_padding True + ``` + + The outputs should be like to the following: + + ```bash + [INFO] Encoder TF v.s. FT with tensor input Cross check True + [INFO] Max diff 5.9604644775390625e-06 + [INFO] min diff 0.0 + [INFO] batch_size 32 max_seq_len 32 12 layer TF-time 19.99 ms + [INFO] batch_size 32 max_seq_len 32 12 layer FT-OP-tensor-time 11.49 ms + ``` + + 2.4 Run FasterTransformer for GLUE dataset + + This subsection demonstrates how to integrate the FasterTransformer in TensorFlow, and evaluate the accuracy of FasterTransformer on GLUE dataset. To evaluate on GLUE dataset, it requires the repo of [BERT](https://github.com/google-research/bert). + + 2.4.1 Prepare the BERT codes, Download the BERT pretrained model. + + ```bash + git clone https://github.com/google-research/bert.git tensorflow/tensorflow_bert/bert + wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip + unzip uncased_L-12_H-768_A-12.zip + ``` + + 2.4.2 Download the GLUE MRPC dataset. Note that the file `download_glue_data.py` can only executed under python3. + + ```bash + wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py + python download_glue_data.py --tasks MRPC + ``` + + 2.4.3 Finetune the pretrained model on MRPC datasets. This takes some minutes. 
+ + ```bash + export BERT_BASE_DIR=${PWD}/uncased_L-12_H-768_A-12 + export GLUE_DIR=${PWD}/glue_data/ + + python tensorflow/tensorflow_bert/bert/run_classifier.py \ + --task_name=MRPC \ + --do_train=true \ + --do_eval=true \ + --data_dir=$GLUE_DIR/MRPC \ + --vocab_file=$BERT_BASE_DIR/vocab.txt \ + --bert_config_file=$BERT_BASE_DIR/bert_config.json \ + --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ + --max_seq_length=128 \ + --train_batch_size=32 \ + --learning_rate=2e-5 \ + --num_train_epochs=3.0 \ + --output_dir=mrpc_output/ + ``` + + The results would be like: + + ```bash + INFO:tensorflow:***** Eval results ***** + I0623 12:11:12.009732 140165910435648 run_classifier.py:923] ***** Eval results ***** + INFO:tensorflow: eval_accuracy = 0.8627451 + I0623 12:11:12.009793 140165910435648 run_classifier.py:925] eval_accuracy = 0.8627451 + INFO:tensorflow: eval_loss = 0.5118897 + I0623 12:11:12.010092 140165910435648 run_classifier.py:925] eval_loss = 0.5118897 + INFO:tensorflow: global_step = 343 + I0623 12:11:12.010174 140165910435648 run_classifier.py:925] global_step = 343 + INFO:tensorflow: loss = 0.5118897 + I0623 12:11:12.010224 140165910435648 run_classifier.py:925] loss = 0.5118897 + ``` + + 2.4.4 Evaluate the accuracy of FasterTransformer under FP32 + + To evaluate the accuracy of FasterTransformer, we can use `tensorflow/tensorflow_bert/run_classifier_wrap.py`. This file uses `run_classifier.py` of bert repo, replacing the transformer model by FasterTransformer and add some additional arguments like `--floatx`. + + ```bash + ../bin/encoder_gemm 8 128 12 64 0 + python tensorflow/tensorflow_bert/run_classifier_wrap.py \ + --floatx=float32 \ + --task_name=MRPC \ + --do_eval=true \ + --data_dir=$GLUE_DIR/MRPC \ + --vocab_file=$BERT_BASE_DIR/vocab.txt \ + --bert_config_file=$BERT_BASE_DIR/bert_config.json \ + --init_checkpoint=mrpc_output/model.ckpt-343 \ + --max_seq_length=128 \ + --eval_batch_size=8 \ + --output_dir=mrpc_output + ``` + + The results would be like: + + ```bash + INFO:tensorflow:***** Eval results ***** + I0623 12:12:20.931746 140250133423936 run_classifier.py:923] ***** Eval results ***** + INFO:tensorflow: eval_accuracy = 0.8627451 + I0623 12:12:20.931810 140250133423936 run_classifier.py:925] eval_accuracy = 0.8627451 + INFO:tensorflow: eval_loss = 0.5118897 + I0623 12:12:20.931997 140250133423936 run_classifier.py:925] eval_loss = 0.5118897 + INFO:tensorflow: global_step = 343 + I0623 12:12:20.932071 140250133423936 run_classifier.py:925] global_step = 343 + INFO:tensorflow: loss = 0.5118897 + I0623 12:12:20.932122 140250133423936 run_classifier.py:925] loss = 0.5118897 + ``` + + 2.4.5 Convert the finetuned checkpoint to FP16, and evaluate the accuracy of Fastertransformer under FP16. + + To convert the checkpoint from FP32 to FP16, we can use `tensorflow/tensorflow_bert/ckpt_type_convert.py` to convert the checkpoint. This file requires two arguments, the location of FP32 checkpoint, and the location putting the FP16 checkpoint. 
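Conceptually, the conversion only reads every variable from the FP32 checkpoint, casts floating-point tensors to half precision, and saves a new checkpoint. The following is a minimal TensorFlow 1.x sketch of that idea, not the actual `ckpt_type_convert.py` script:

```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x

def convert_checkpoint_to_fp16(src_ckpt, dst_ckpt):
    """Cast every float32 variable of a checkpoint to float16 and save it again."""
    reader = tf.train.load_checkpoint(src_ckpt)
    with tf.Graph().as_default():
        new_vars = []
        for name, _ in tf.train.list_variables(src_ckpt):
            value = reader.get_tensor(name)
            if value.dtype == np.float32:
                value = value.astype(np.float16)
            new_vars.append(tf.Variable(value, name=name))
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            tf.train.Saver(new_vars).save(sess, dst_ckpt)
```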
+ + ```bash + python tensorflow/tensorflow_bert/ckpt_type_convert.py \ + --init_checkpoint=mrpc_output/model.ckpt-343 \ + --fp16_checkpoint=mrpc_output_fp16/fp16_model.ckpt + ./bin/encoder_gemm 8 128 12 64 1 + python tensorflow/tensorflow_bert/run_classifier_wrap.py \ + --floatx=float16 \ + --task_name=MRPC \ + --do_eval=true \ + --data_dir=$GLUE_DIR/MRPC \ + --vocab_file=$BERT_BASE_DIR/vocab.txt \ + --bert_config_file=$BERT_BASE_DIR/bert_config.json \ + --init_checkpoint=mrpc_output_fp16/fp16_model.ckpt \ + --max_seq_length=128 \ + --eval_batch_size=8 \ + --output_dir=mrpc_output_fp16 + ``` + + The results would be like: + + ```bash + INFO:tensorflow:***** Eval results ***** + I0623 12:14:45.001711 139685820454720 run_classifier.py:923] ***** Eval results ***** + INFO:tensorflow: eval_accuracy = 0.86519605 + I0623 12:14:45.001776 139685820454720 run_classifier.py:925] eval_accuracy = 0.86519605 + INFO:tensorflow: eval_loss = 0.5089564 + I0623 12:14:45.001986 139685820454720 run_classifier.py:925] eval_loss = 0.5089564 + INFO:tensorflow: global_step = 343 + I0623 12:14:45.002063 139685820454720 run_classifier.py:925] global_step = 343 + INFO:tensorflow: loss = 0.5089728 + I0623 12:14:45.002117 139685820454720 run_classifier.py:925] loss = 0.5089728 + ``` + + 2.4.6 Compare the speed of BERT of TensorFlow and FasterTransformer under both FP32 and FP16. + + To compare the speed of TensorFlow and FasterTransformer on BERT model directly, we can use `tensorflow/tensorflow_bert/profile_transformer_inferece.py`. + + ```bash + ./bin/encoder_gemm 8 128 12 64 0 + python tensorflow/tensorflow_bert/profile_transformer_inference.py \ + --init_checkpoint=mrpc_output/model.ckpt-343 \ + --tf_profile=false \ + --output_dir=mrpc_output \ + --profiling_output_file=time_elapsed \ + --xla=false \ + --floatx=float32 + ./bin/encoder_gemm 8 128 12 64 1 + python tensorflow/tensorflow_bert/profile_transformer_inference.py \ + --init_checkpoint=mrpc_output_fp16/fp16_model.ckpt \ + --tf_profile=false \ + --output_dir=mrpc_output_fp16 \ + --profiling_output_file=time_elapsed \ + --xla=false \ + --floatx=float16 + ``` + + The results of FP32 would be like: + + ```bash + average time (seconds) elapsed original tensorflow: 0.02553061246871948 + average time (seconds) elapsed fast transformer: 0.018373918533325196 + ``` + + The results of FP16 would be like: + + ```bash + average time (seconds) elapsed original tensorflow: 0.012212872505187988 + average time (seconds) elapsed fast transformer: 0.005685007572174073 + ``` + + 2.5 Run FasterTransformer for SQuAD 1.1 dataset + + This subsection demonstrates how to integrate the FasterTransformer in TensorFlow and evaluates the accuracy of FasterTransformer on SQuAD 1.1 dataset. To evaluate on SQuAD 1.1 dataset, it requires the repo of [BERT](https://github.com/google-research/bert). + + 2.5.1 Prepare the BERT codes and download the fine-tuned model of SQuAD 1.1 from NGC + + Because the training time of SQuAD is longer, and the NVIDIA NGC has provided the fine-tuned BERT model, we download the fine-tuned model directly. + + ```bash + git clone https://github.com/google-research/bert.git tensorflow/tensorflow_bert/bert + wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_v1_1_base_fp32_128/versions/2/zip -O bert_tf_v1_1_base_fp32_128_2.zip + unzip bert_tf_v1_1_base_fp32_128_2.zip -d squad_model + ``` + + 2.5.2 Download the SQuAD dataset. 
+
+ ```bash
+ mkdir squad_data
+ wget -P squad_data https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
+ wget -P squad_data https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
+ ```
+
+ 2.5.3 Evaluate the accuracy of TensorFlow under FP32
+
+ ```bash
+ python tensorflow/tensorflow_bert/bert/run_squad.py \
+ --predict_batch_size=8 \
+ --vocab_file=squad_model/vocab.txt \
+ --bert_config_file=squad_model/bert_config.json \
+ --init_checkpoint=squad_model/model.ckpt-5474 \
+ --train_file=squad_data/train-v1.1.json \
+ --do_predict=True \
+ --predict_file=squad_data/dev-v1.1.json \
+ --max_seq_length=384 \
+ --output_dir=./squad_tf_output/fp_32/
+
+ python tensorflow/tensorflow_bert/squad_evaluate-v1.1.py squad_data/dev-v1.1.json squad_tf_output/fp_32/predictions.json
+ ```
+
+ The results of TensorFlow would be similar to the following:
+
+ ```bash
+ {"exact_match": 78.13623462630085, "f1": 85.84460577952547}
+ ```
+
+ 2.5.4 Evaluate the accuracy of FasterTransformer under FP32
+
+ To evaluate the accuracy of FasterTransformer, we can use `tensorflow/tensorflow_bert/run_squad_wrap.py`. This file uses `run_squad.py` of the bert repo, replacing the transformer model with FasterTransformer and adding some additional arguments such as `--floatx`.
+
+ ```bash
+ ./bin/encoder_gemm 8 384 12 64 0
+ python tensorflow/tensorflow_bert/run_squad_wrap.py \
+ --floatx=float32 \
+ --predict_batch_size=8 \
+ --vocab_file=squad_model/vocab.txt \
+ --bert_config_file=squad_model/bert_config.json \
+ --init_checkpoint=squad_model/model.ckpt-5474 \
+ --train_file=squad_data/train-v1.1.json \
+ --do_predict=True \
+ --predict_file=squad_data/dev-v1.1.json \
+ --max_seq_length=384 \
+ --output_dir=./squad_ft_output/fp_32/
+
+ python tensorflow/tensorflow_bert/squad_evaluate-v1.1.py squad_data/dev-v1.1.json squad_ft_output/fp_32/predictions.json
+ ```
+
+ The results of FasterTransformer would be similar to the following:
+
+ ```bash
+ {"exact_match": 78.13623462630085, "f1": 85.84460577952547}
+ ```
+
+ 2.5.5 Convert the checkpoint to FP16 and evaluate the accuracy of TensorFlow and FasterTransformer under FP16
+
+ To convert the checkpoint from FP32 to FP16, we can use `tensorflow/tensorflow_bert/ckpt_type_convert.py`. This script takes two arguments: the location of the FP32 checkpoint and the location where the FP16 checkpoint should be written.
+
+ ```bash
+ python tensorflow/tensorflow_bert/ckpt_type_convert.py --init_checkpoint=squad_model/model.ckpt-5474 --fp16_checkpoint=squad_fp16_model/model.ckpt
+
+ ./bin/encoder_gemm 8 384 12 64 1
+ python tensorflow/tensorflow_bert/run_squad_wrap.py \
+ --floatx=float16 \
+ --predict_batch_size=8 \
+ --vocab_file=squad_model/vocab.txt \
+ --bert_config_file=squad_model/bert_config.json \
+ --init_checkpoint=squad_fp16_model/model.ckpt \
+ --train_file=squad_data/train-v1.1.json \
+ --do_predict=True \
+ --predict_file=squad_data/dev-v1.1.json \
+ --max_seq_length=384 \
+ --output_dir=./squad_ft_output/fp_16/
+
+ python tensorflow/tensorflow_bert/squad_evaluate-v1.1.py squad_data/dev-v1.1.json squad_ft_output/fp_16/predictions.json
+ ```
+
+ The results of FasterTransformer under FP16 would be similar to the following:
+
+ ```bash
+ {"exact_match": 78.0321665089877, "f1": 85.77861816524597}
+ ```
+
+ 2.5.6 Compare the speed of TensorFlow and FasterTransformer on BERT under both FP32 and FP16.
+ ```bash
+ ./bin/encoder_gemm 8 128 12 64 0
+ python tensorflow/tensorflow_bert/profile_transformer_inference.py \
+ --init_checkpoint=mrpc_output/model.ckpt-343 \
+ --tf_profile=false \
+ --output_dir=mrpc_output \
+ --profiling_output_file=time_elapsed \
+ --xla=false \
+ --floatx=float32
+ ./bin/encoder_gemm 8 128 12 64 1
+ python tensorflow/tensorflow_bert/profile_transformer_inference.py \
+ --init_checkpoint=mrpc_output_fp16/fp16_model.ckpt \
+ --tf_profile=false \
+ --output_dir=mrpc_output_fp16 \
+ --profiling_output_file=time_elapsed \
+ --xla=false \
+ --floatx=float16
+ ```
+
+ The results of FP32 would be similar to the following:
+
+ ```bash
+ average time (seconds) elapsed original tensorflow: 0.02553061246871948
+ average time (seconds) elapsed fast transformer: 0.018373918533325196
+ ```
+
+ The results of FP16 would be similar to the following:
+
+ ```bash
+ average time (seconds) elapsed original tensorflow: 0.012212872505187988
+ average time (seconds) elapsed fast transformer: 0.005685007572174073
+ ```
+
+3. Run FasterTransformer on PyTorch
+
+ Please install HuggingFace's transformers before running the demos:
+ ```bash
+ pip install transformers==2.5.1
+ ```
+
+ 3.1 Generate the `gemm_config.in` file:
+
+ ```bash
+ ./bin/encoder_gemm <batch_size> <sequence_length> <head_number> <size_per_head> <is_fp16>
+ ./bin/encoder_gemm 1 32 12 64 1
+ ```
+ If you want to use the library in another directory, please generate this file according to your settings and copy it to your working directory.
+
+ 3.2 Run the PyTorch encoder sample:
+
+ ```bash
+ python pytorch/encoder_sample.py <batch_size> <layer_num> <sequence_length> <head_number> <size_per_head> <--fp16> <--time> <--ths> <--ths_type> <--remove_padding> <--use_pretrained>
+ python pytorch/encoder_sample.py 1 12 32 12 64 --fp16 --time
+ ```
+
+ Remove `--fp16` for FP32 mode. `--ths` runs the sample in TorchScript mode. `--ths_type 0` uses the custom TorchScript class (built with `-DBUILD_THS=ON`); other values use the TorchScript function op (built with `-DBUILD_THSOP=ON`). `--remove_padding` removes the padding of the sentences, which brings a speedup when the average sequence length is smaller than the maximum sequence length. `--remove_padding` combined with `--fp16` may lead to `nan` outputs; this is caused by the random weight initialization, and using pretrained weights (`--use_pretrained`, with `--weight_path` setting the weight path) avoids this issue.
-Assume the settings are the same as above, and the decoder contains 6 transformer layers.
+ The outputs should be similar to the following:
-a. Run the decoding in C++ by running the following script:
+ ```bash
+ Mean diff: 0.0009646415710449219
+ Max diff: 0.00830078125
+ Min diff: 0.0
+ [INFO] HuggingFaceEnocder time costs: 8.32 ms
+ [INFO] FasterTransformer time costs: 1.40 ms
+ ```
-`./bin/decoding_sample` runs the decoding in the `cpp`. The arguments of `encoder_sample` is:
+ 3.3 Run the BERT application code:
-```bash
-./bin/decoding_sample
-```
+ We have two BERT application code samples, SQuAD and MRPC. The `thsext` option of `run_squad.sh` uses the custom TorchScript class (built with `-DBUILD_THS=ON`), and the `thsext` option of `run_mrpc.sh` uses the custom TorchScript op (built with `-DBUILD_THSOP=ON`).
-Then the following scripts can run the decoding under the above settings.
+ ```bash + bash pytorch/script/run_squad.sh + bash pytorch/script/run_mrpc.sh + ``` + the `` can be: + - `ori`: original HuggingFace's BERT encoder + - `ext`: our PyTorch eager extension + - `ths`: original HuggingFace's BERT encoder in TorchScript mode + - `thsext`: our TorchScript custom class/op -```bash -./bin/decoding_sample 32 4 8 64 30000 32 6 768 0 -``` + the `` can be `fp32` or `fp16` -The outputs should be similar to the following: - -```bash -Device Tesla T4 -[batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000] costs 191.21 ms -done -``` + For example, run HuggingFace's BERT under FP32 by following scripts: -b. Run the decoder in TensorFlow by running the following script: + ```bash + bash pytorch/scripts/run_mrpc.sh ori fp32 + ``` -```bash -python decoder_sample.py \ - --batch_size 32 \ - --beam_width 4 \ - --max_seq_len 32 \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp32 \ - --decoder_type 2 -``` + The outputs should be like to the following: -The outputs should be similar to the following: + ```bash + 06/28/2020 07:29:59 - INFO - __main__ - Evaluation for mrpc done in total 4.646116 secs (0.011388 sec per example) + 06/28/2020 07:29:59 - INFO - __main__ - ***** Eval results ***** + 06/28/2020 07:29:59 - INFO - __main__ - acc = 0.8284313725490197 + 06/28/2020 07:29:59 - INFO - __main__ - acc_and_f1 = 0.8556872581808643 + 06/28/2020 07:29:59 - INFO - __main__ - f1 = 0.8829431438127091 + ``` -```bash -[[INFO][PYTHON] step:][0][max diff: ][5.00679e-06][ op val: ][2.3735888][ tf val: ][2.37359381][True] -[[INFO][PYTHON] step:][1][max diff: ][4.64916229e-06][ op val: ][-0.588810563][ tf val: ][-0.588815212][True] -[[INFO][PYTHON] step:][2][max diff: ][5.36441803e-06][ op val: ][-1.46514082][ tf val: ][-1.46514618][True] -... -[[INFO][PYTHON] step:][29][max diff: ][4.529953e-06][ op val: ][2.88768935][ tf val: ][2.88769388][True] -[[INFO][PYTHON] step:][30][max diff: ][4.17232513e-06][ op val: ][-1.28717053][ tf val: ][-1.2871747][True] -[[INFO][PYTHON] step:][31][max diff: ][4.05311584e-06][ op val: ][-1.01830876][ tf val: ][-1.01831281][True] -``` + For example, run our PyTorch custom op under FP16 by following scripts: -The results show that the differences between the decoder of TensorFlow and decoder are smaller than threshold. Note that the differences are absolute differences, so the differences may be large when the op val is large. In this case, the differences are larger than the threshold and the checking will return "False", but it may be not affect the final results. + ```bash + bash pytorch/scripts/run_mrpc.sh thsext fp16 + ``` -The option `decoder_type` decides to use the decoder of TensorFlow or decoder of FasterTransformer. `decoder_type 2` uses both decoders and compares their results. + The outputs should be like to the following: -The following script demonstrates the execution time of the FasterTransformer decoder. 
+ ```bash + 06/28/2020 07:30:19 - INFO - __main__ - Evaluation for mrpc done in total 1.725153 secs (0.004228 sec per example) + 06/28/2020 07:30:19 - INFO - __main__ - ***** Eval results ***** + 06/28/2020 07:30:19 - INFO - __main__ - acc = 0.8284313725490197 + 06/28/2020 07:30:19 - INFO - __main__ - acc_and_f1 = 0.8556872581808643 + 06/28/2020 07:30:19 - INFO - __main__ - f1 = 0.8829431438127091 + ``` -```bash -python decoder_sample.py \ - --batch_size 32 \ - --beam_width 4 \ - --max_seq_len 32 \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp32 \ - --decoder_type 1 \ - --test_time 1 -``` +#### Decoder and decoding process -The outputs should be similar to the following: +1. Run FasterTransformer decoding on c++ -```bash -[INFO] time costs of OP decoder: 248.046 ms. -``` + 1.1 Generate the `decoding_gemm_config.in` file. -The following script demonstrates the execution time of the TensorFlow decoder. - -```bash -python decoder_sample.py \ - --batch_size 32 \ - --beam_width 4 \ - --max_seq_len 32 \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp32 \ - --decoder_type 0 \ - --test_time 1 -``` + `./bin/decoding_gemm` can generate the best GEMM configuration. The arguments of `decoding_gemm` are: -c. Run the decoding in TensorFlow by running the following script: - -```bash -python decoding_sample.py \ - --batch_size 32 \ - --beam_width 4 \ - --max_seq_len 32 \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp32 -``` + ```bash + ./bin/decoding_gemm + ``` -The outputs should be similar to the following: + Assume the settings of decoding are as follows. -```bash - Output ids cross-check: True - - Parent ids cross-check: True - - Sequence lengths cross-check: True - - Finalized output ids cross-check: True -``` + - `batch_size`=32 + - `beam_width`=4 + - `head_number`=8 + - `size_per_head`=64 + - `vocabulary_size`=30000 + - `sequence_length`=32 + - `encoder's hidden dimension`=512 + - `data_type`=FP32 -Note that the results of OP and the results of TensorFlow are often different in the random inputs and weights. + Then the following scripts can generate the best GEMM configuration under such settings, and record the configuration into the `decoding_gemm_config.in` file. -3. Run the encoder and decoding at the same time. + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 0 + ``` -```bash -python encoder_decoding_sample.py \ - --batch_size 32 \ - --beam_width 4 \ - --max_seq_len 32 \ - --encoder_head_number 12 \ - --encoder_size_per_head 64 \ - --decoder_head_number 8 \ - --decoder_size_per_head 64 \ - --encoder_num_layer 6 \ - --decoder_num_layer 6 \ - --data_type fp32 -``` + 1.2 Run decoding under FP32 on c++ -#### Translation process + Assume the settings are the same as above, and the decoder contains 6 transformer layers. -This subsection demonstrates how to use FasterTansformer decoding to translate a sentence. We use the pretrained model and testing data in [OpenNMT-tf](https://opennmt.net/Models-tf/), which translate from English to German. + In the decoding, we provide two kinds of methods to choose the tokens from the candidates. The first kind of method is the beam search algorithm. The second kind of method is sampling algorithm. -Because the FasterTransformer Encoder is based on BERT, we cannot restore the model of encoder of OpenNMT to FasterTransformer Encoder. 
Therefore, we use OpenNMT-tf to build the encoder and preprocess the source sentence. + For beam search, we provide a simple diverse decoding of [link](https://arxiv.org/pdf/1611.08562.pdf). When the diversity rate is set to 0, it is equivalent to the naive beam search. -Another problem is that the implementation of FasterTransformer Decoder and decoder of OpenNMT-tf is a little different. For example, the decoder of OpenNMT-tf uses one convolution to compute query, key and value in masked-multihead-attention; but FasterTransformer Decoder splits them into three gemms. The tool `utils/dump_model.py` will convert the pretrained model to fit the model structure of FasterTransformer Decoder. + For sampling, we provide the top k sampling and top p sampling. Here, k is an integer number and p is a float point number. Note that we cannot use both of them in the same time. So, only one of both can be non-zero value. -`download_model_data.sh` will install the OpenNMT-tf v1, downloads the pretrained model into the `translation` folder, and convert the model. + `./bin/decoding_beamsearch_sample` runs the decoding with beam search in the `c++`. The arguments of `decoding_beamsearch_sample` is: -```bash -bash utils/translation/download_model_data.sh -``` + ```bash + ./bin/decoding_beamsearch_sample + ``` -Then run the translation sample by the following script: + Then the following scripts can run the decoding with beam search under the above settings. -```bash -./bin/decoding_gemm 1 4 8 64 32001 100 512 0 -python translate_sample.py -``` + ```bash + ./bin/decoding_beamsearch_sample 32 4 8 64 30000 32 6 512 0 + ``` -The outputs should be similar to the following: + The outputs should be like to the following: -```bash -[INFO] opennmt: ▁28 - jährige r ▁Chef koch ▁to t ▁in ▁San ▁Francisco -[INFO] tf : ▁28 - jährige r ▁Chef koch ▁to t ▁in ▁San ▁Francisco -[INFO] op : ▁28 - jährige r ▁Chef koch ▁to t ▁in ▁San ▁Francisco -``` + ```bash + Device Tesla V100-PCIE-32GB + [INFO] batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 FT-CPP-decoding-beamsearch-time 73.36 ms + ``` -## Performance + `./bin/decoding_sampling_sample` runs the decoding with sampling in the `c++`. The arguments of `decoding_sampling_sample` is: -Hardware settings: -* CPU: Intel(R) Xeon(R) Gold 6132 CPU @ 2.60GHz -* T4 (with mclk 5000MHz, pclk 1590MHz) -* P4 (with mclk 3003MHz, pclk 1531MHz) -* V100 (with mclk 877MHz, pclk 1380MHz) + ```bash + ./bin/decoding_sampling_sample + ``` -In the following experiments, we updated the following parameters: -* head_num = 8 -* size_per_head = 64 -* transformer layers = 6 -* vocabulary_size = 30000 + where `candidate_num` is the k value of top k, while `probability_threshold` is the p value of top p. -For Encoder, the reported time is the average inference time for 100 iterations with 100 warm-up iterations. + Note that the beam width of sampling algorithm is always 1, so we need to generate the new configuration. -For Decoder and Decoding, the reported time the is average inference time for 50 iterations with 50 warm-up iterations. + The following scripts can run the decoding with top k sampling with under the above settings. -### Encoder performance + ```bash + ./bin/decoding_gemm 32 1 8 64 30000 32 512 0 + ./bin/decoding_sampling_sample 32 4 0.0 8 64 30000 32 6 512 0 + ``` -We demonstrate the inference time of FasterTransformer in C++ and compare it to the inference time of TensorFlow in Python. 
+ The outputs should be like to the following: -| | P4 FP32 (in ms) | T4 FP32 (in ms) | T4 FP16 (in ms) | -|:--------------------:|:----:|:---------:|:-----------:| -| (1, 12, 32, 12, 64) | 3.43 | 2.74 | 1.56 | -| (1, 12, 64, 12, 64) | 4.04 | 3.64 | 1.77 | -| (1, 12, 128, 12, 64) | 6.22 | 5.93 | 2.23 | + ```bash + Device Tesla V100-PCIE-32GB + [INFO] batch_size 32 topk 4 topp 0.000000 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 FT-CPP-decoding-sampling-time 41.65 ms + ``` -For large batch size cases, we report both TensorFlow XLA and faster transformer's performance. + The following scripts can run the decoding with top p sampling with under the above settings. -| | TensorFlow XLA on V100 FP16 (in ms) | FasterTransformer V100 FP16 (in ms) | Speedup | -|:-------------:|:-------------:|:---------:|:-----------:| -| (100, 12, 32, 12, 64) | 13.96 | 9.57 | 1.459 | -| (200, 12, 32, 12, 64) | 26.47 | 18.37 | 1.44 | -| (300, 12, 32, 12, 64) | 38.4 | 27.41 | 1.401 | -| (400, 12, 32, 12, 64) | 49.65 | 35.63 | 1.393 | -| (500, 12, 32, 12, 64) | 62.2 | 44.57 | 1.396 | + ```bash + ./bin/decoding_gemm 32 1 8 64 30000 32 512 0 + ./bin/decoding_sampling_sample 32 0 0.01 8 64 30000 32 6 512 0 + ``` -| | TensorFlow XLA on V100 FP16 (in ms) | FasterTransformer V100 FP16 (in ms) | Speedup | -|:-------------:|:-------------:|:---------:|:-----------:| -| (100, 12, 32, 4, 32) | 3.49 | 1.73 | 2.017 | -| (200, 12, 32, 4, 32) | 4.9 | 2.55 | 1.922 | -| (300, 12, 32, 4, 32) | 6.35 | 3.356 | 1.892 | -| (400, 12, 32, 4, 32) | 8 | 4.31 | 1.856 | -| (500, 12, 32, 4, 32) | 9.93 | 5.13 | 1.936 | + The outputs should be like to the following: -### Decoder performance on T4 + ```bash + Device Tesla V100-PCIE-32GB + [INFO] batch_size 32 topk 0 topp 0.010000 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 FT-CPP-decoding-sampling-time 61.63 ms + ``` -We do not demonstrate the performance of TensorFlow with XLA since we did not find that using XLA has obvious speedup. + 1.3 Run decoding under FP16 on c++ -The following results of FasterTransformer are generated by + So far, we use the FP32 to run the FasterTransformer. If we use the volta or newer NVIDIA gpu, we can use tensor core to accelerate when we use the FP16. -```bash -bash scripts/profile_decoder_op_performance.sh -``` + To use the FP16, we only need to set the `` flag to 1 like following: -* We set beam_width = 1 -* We replace the decoder of tensorflow with our decoder op. + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 1 + ./bin/decoding_beamsearch_sample 32 4 8 64 30000 32 6 512 1 + ``` -| | TensorFlow FP32 (in ms) | Decoder FP32 (in ms) | FP32 Speedup | TensorFlow FP16 (in ms) | Decoder FP16 (in ms) | FP16 Speedup | -|:---------:|:-------:|:------:|:----:|:-------:|:------:|:----:| -| (1, 32) | 441.68 | 111.14 | 3.97 | 508.81 | 165.88 | 3.06 | -| (1, 64) | 872.39 | 207.37 | 4.20 | 1038.71 | 326.69 | 3.18 | -| (1, 128) | 1714.01 | 457.62 | 3.74 | 2082.92 | 661.00 | 3.41 | -| (32, 32) | 470.93 | 119.87 | 3.92 | 568.83 | 167.42 | 3.39 | -| (64, 32) | 503.57 | 153.62 | 3.27 | 579.21 | 183.74 | 3.15 | -| (128, 32) | 614.59 | 245.94 | 2.50 | 641.98 | 238.27 | 2.69 | -| (256, 32) | 802.18 | 439.33 | 2.01 | 735.67 | 348.74 | 2.11 | + Note that the configuration of FP32 and FP16 are different, so we need to generate the configuration again. 
-### Decoding performance on T4 + The outputs should be like to the following: -We do not demonstrate the performance of TensorFlow with XLA since we did not find that using XLA has obvious speedup. + ```bash + Device Tesla V100-PCIE-32GB + [INFO] batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 FT-CPP-decoding-beamsearch-time 47.89 ms + ``` + +2. Run FasterTransformer decoder/decoding on TensorFlow + + 2.1 Run FasterTransformer decoder under FP32 on TensorFlow + + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 0 + python tensorflow/decoder_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp32 \ + --decoder_type 2 + ``` + + The outputs should be like to the following: + + ```bash + [[INFO][PYTHON] step:][29][True][max abs diff: ][4.17232513e-06][ op val: ][1.23598516][ tf val: ][1.23598933] + [[INFO][PYTHON] step:][30][True][max abs diff: ][4.05311584e-06][ op val: ][-2.40530682][ tf val: ][-2.40531087] + [[INFO][PYTHON] step:][31][False][max abs diff: ][3.7997961e-06][ op val: ][-0.120998174][ tf val: ][-0.121001974] + ``` + + The results show that the differences between the decoder of TensorFlow and decoder are smaller than threshold. Sometimes, the differences are larger than the threshold and the checking will return "False", but it does not affect the results. + + The argument `decoder_type` decides to use the decoder of TensorFlow or decoder of FasterTransformer. `decoder_type 2` uses both decoders and compares their results. + + The following script demonstrates the execution time of the FasterTransformer decoder. + + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 0 + python tensorflow/decoder_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp32 \ + --decoder_type 1 \ + --test_time 1 + ``` + + The outputs should be like to the following: + + ```bash + [INFO] batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 FT-OP-decoder-time 138.90 ms. + ``` + + The following script demonstrates the execution time of the TensorFlow decoder. + + ```bash + python tensorflow/decoder_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp32 \ + --decoder_type 0 \ + --test_time 1 + ``` + + The outputs should be like to the following: + + ```bash + [INFO] batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 TF-decoding-beamsearch-time 564.37 ms. 
+ ``` + + 2.2 Run FasterTransformer decoder under FP16 on TensorFlow + + To use the FP16 in TensorFlow, we only need to set the `--data_type fp16` like following: + + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 1 + python tensorflow/decoder_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp16 \ + --decoder_type 2 + ``` + + The outputs should be like to the following: + + ```bash + [[INFO][PYTHON] step:][29][True][max abs diff: ][0.01171875][ op val: ][2.03125][ tf val: ][2.04296875] + [[INFO][PYTHON] step:][30][True][max abs diff: ][0.01171875][ op val: ][2.3671875][ tf val: ][2.35546875] + [[INFO][PYTHON] step:][31][True][max abs diff: ][0.01171875][ op val: ][2.33398438][ tf val: ][2.32226562] + ``` + + The following script demonstrates the execution time of the FasterTransformer decoder. + + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 1 + python tensorflow/decoder_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp16 \ + --decoder_type 1 \ + --test_time 1 + ``` + + The outputs should be like to the following: + + ```bash + [INFO] batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 FT-OP-decoder-time 132.48 ms. + ``` + + The following script demonstrates the execution time of the TensorFlow decoder. + + ```bash + python tensorflow/decoder_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp16 \ + --decoder_type 0 \ + --test_time 1 + ``` + + The outputs should be like to the following: + + ```bash + [INFO] batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 TF-decoding-beamsearch-time 503.52 ms. + ``` + + Note that when the batch size is small, using FP16 may cause the inference speed to become slower. This is because that decoding is not computing bound and using FP16 in TensorFlow leads to some additional operation and casting. + + 2.3 Run FasterTransformer decoding under FP32 on TensorFlow + + In the decoding, we provide two kinds of methods to choose the tokens from the candidates. The first kind of method is the beam search algorithm. The second kind of method is sampling algorithm. + + For beam search, we provide a simple diverse decoding of [link](https://arxiv.org/pdf/1611.08562.pdf). When the `--beam_search_diversity_rate` is set to 0, it is equivalent to the naive beam search. + + For sampling, we provide the top k sampling and top p sampling, which are set by the arguments `--sampling_topk` and `--sampling_topp`. Here, k is an integer number and p is a float point number. Note that we cannot use both in the same time. So, only one of both can be non-zero value. 
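+
+ The intuition behind the two sampling strategies can be written down in a few lines of NumPy. This is only an illustrative sketch of the selection rule, not the CUDA kernels used inside FasterTransformer:
+
+ ```python
+ import numpy as np
+
+ def sample_next_token(logits, topk=0, topp=0.0):
+     """Illustrative top-k / top-p sampling over a 1-D logits vector."""
+     probs = np.exp(logits - logits.max())
+     probs /= probs.sum()
+     if topk > 0:                      # keep only the k most probable tokens
+         keep = np.argsort(probs)[-topk:]
+     elif topp > 0.0:                  # keep the smallest set whose mass reaches p
+         order = np.argsort(probs)[::-1]
+         cumulative = np.cumsum(probs[order])
+         keep = order[:np.searchsorted(cumulative, topp) + 1]
+     else:                             # no restriction: plain multinomial sampling
+         keep = np.arange(len(probs))
+     restricted = np.zeros_like(probs)
+     restricted[keep] = probs[keep]
+     restricted /= restricted.sum()
+     return np.random.choice(len(probs), p=restricted)
+ ```
+
+ Because the two rules would conflict, only one of `--sampling_topk` and `--sampling_topp` may be non-zero, as noted above.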
+
+ The following script uses diverse decoding with diversity rate 0 and top k sampling with k = 4:
+
+ ```bash
+ ./bin/decoding_gemm 32 4 8 64 30000 32 512 0
+ python tensorflow/decoding_sample.py \
+ --batch_size 32 \
+ --beam_width 4 \
+ --head_number 8 \
+ --size_per_head 64 \
+ --vocab_size 30000 \
+ --max_seq_len 32 \
+ --num_layer 6 \
+ --memory_hidden_dim 512 \
+ --data_type fp32 \
+ --beam_search_diversity_rate 0 \
+ --sampling_topk 4 \
+ --sampling_topp 0.0 \
+ --test_time 0123
+ ```
+
+ The outputs should be similar to the following:
+
+ ```bash
+ [INFO] batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 TF-decoding-beamsearch-time 555.87 ms.
+ [INFO] batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 FT-OP-decoding-beamsearch-time 75.80 ms.
+ [INFO] batch_size 32 topk 4 topp 0.0 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 TF-decoding-sampling-time 432.40 ms.
+ [INFO] batch_size 32 topk 4 topp 0.0 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 FT-OP-decoding-sampling-time 46.68 ms.
+ ```
+
+ Note that the results of FasterTransformer may be different, especially when the batch size is larger.
+
+ Here, we use the same GEMM configuration to run the decoding with beam search and with sampling at the same time. Strictly speaking, this is not correct because the beam width of decoding with sampling is always 1, so the two configurations only match when the beam width is 1. However, this only slightly reduces the speed of decoding with sampling, so we ignore this issue here.
+
+ Here, the meaning of the `--test_time` argument is different from before: 0 tests TensorFlow with beam search; 1 tests FasterTransformer with beam search; 2 tests TensorFlow with sampling; 3 tests FasterTransformer with sampling.
+
+ The following script uses diverse decoding with diversity rate -1.3 and top p sampling with p = 0.01:
+
+ ```bash
+ ./bin/decoding_gemm 32 4 8 64 30000 32 512 0
+ python tensorflow/decoding_sample.py \
+ --batch_size 32 \
+ --beam_width 4 \
+ --head_number 8 \
+ --size_per_head 64 \
+ --vocab_size 30000 \
+ --max_seq_len 32 \
+ --num_layer 6 \
+ --memory_hidden_dim 512 \
+ --data_type fp32 \
+ --beam_search_diversity_rate -1.3 \
+ --sampling_topk 0 \
+ --sampling_topp 0.01 \
+ --test_time 0123
+ ```
+
+ The outputs should be similar to the following:
+
+ ```bash
+ [INFO] batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 TF-decoding-beamsearch-time 525.55 ms.
+ [INFO] batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 FT-OP-decoding-beamsearch-time 76.79 ms.
+ [INFO] batch_size 32 topk 4 topp 0.0 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 TF-decoding-sampling-time 420.98 ms.
+ [INFO] batch_size 32 topk 4 topp 0.0 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 FT-OP-decoding-sampling-time 46.37 ms.
+ ```
+
+ For the sampling algorithm, the results of TensorFlow and FasterTransformer are often different.
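+
+ As a rough mental model of the diverse decoding used above (following the linked paper), each candidate expanded from a beam is penalized according to its rank among the candidates proposed by the same parent beam; with a rate of 0 the penalty vanishes and the procedure reduces to naive beam search. Below is a small sketch of that score adjustment; the sign convention is an assumption for illustration, not taken from the FasterTransformer kernels:
+
+ ```python
+ import numpy as np
+
+ def adjust_sibling_scores(parent_score, child_log_probs, diversity_rate=0.0):
+     """Rank-penalized scores for the candidates expanded from one parent beam.
+
+     diversity_rate = 0.0 reproduces naive beam search; a non-zero rate
+     (e.g. -1.3 in the sample above) offsets each sibling by rate * rank.
+     """
+     order = np.argsort(child_log_probs)[::-1]    # best sibling first
+     ranks = np.empty_like(order)
+     ranks[order] = np.arange(1, len(order) + 1)  # rank 1 = most probable sibling
+     return parent_score + child_log_probs + diversity_rate * ranks
+ ```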
+ + 2.4 Run FasterTransformer decoding under FP16 on TensorFlow + + ```bash + ./bin/decoding_gemm 32 4 8 64 30000 32 512 1 + python tensorflow/decoding_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --head_number 8 \ + --size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --num_layer 6 \ + --memory_hidden_dim 512 \ + --data_type fp16 \ + --beam_search_diversity_rate 0.0 \ + --sampling_topk 4 \ + --sampling_topp 0.00 \ + --test_time 0123 + ``` + + The outputs should be like to the following: + + ```bash + [INFO] batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 TF-decoding-beamsearch-time 494.23 ms. + [INFO] batch_size 32 beam_width 4 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 FT-OP-decoding-beamsearch-time 50.43 ms. + [INFO] batch_size 32 topk 4 topp 0.0 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 TF-decoding-sampling-time 382.34 ms. + [INFO] batch_size 32 topk 4 topp 0.0 head_num 8 size_per_head 64 seq_len 32 decoder_layers 6 vocab_size 30000 FT-OP-decoding-sampling-time 33.19 ms. + ``` + + Note that the results of FasterTransformer may be different, especially when the batch size is larger. + + 2.5 Run FasterTransformer encoder and decoder/decoding on TensorFlow in the same time + + In this subsection, we demonstrate how to use the FasterTransformer encoder and decoder/decoding in the same time. + + ```bash + ./bin/encoder_gemm 32 32 8 64 0 + ./bin/decoding_gemm 32 4 8 64 30000 32 512 0 + python tensorflow/encoder_decoder_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --encoder_head_number 8 \ + --encoder_size_per_head 64 \ + --decoder_head_number 8 \ + --decoder_size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --encoder_num_layer 6 \ + --decoder_num_layer 6 \ + --data_type fp32 + ``` + + The `encoder_decoder_sample.py` files show the results of "TensorFlow encoder + FasterTransformer decoder" and the results of "FasterTransformer encoder + FasterTransformer decoder. The usage is similar to `decoder_sample.py`. + + ```bash + ./bin/encoder_gemm 32 32 8 64 0 + ./bin/decoding_gemm 32 4 8 64 30000 32 512 0 + python tensorflow/encoder_decoding_sample.py \ + --batch_size 32 \ + --beam_width 4 \ + --encoder_head_number 8 \ + --encoder_size_per_head 64 \ + --decoder_head_number 8 \ + --decoder_size_per_head 64 \ + --vocab_size 30000 \ + --max_seq_len 32 \ + --encoder_num_layer 6 \ + --decoder_num_layer 6 \ + --data_type fp32 + ``` + + For convenience, we only show how to use the FasterTransformer encoder and decoding with beam search in the `encoder_decoding_sample.py`. The usage is similar to `decoding_sample.py`. + +3. Run FasterTransformer decoder/decoding on PyTorch + + Please install OpenNMT-py first before run the demos by + ```bash + pip install opennmt-py==1.1.1 + ``` + + 3.1 Generate the `decoding_gemm_config.in` file: + + ```bash + ./bin/decoding_gemm + ./bin/decoding_gemm 8 4 8 64 31538 32 512 1 + ``` + If you want to use the library in other directory, please generate this file according to your setting and copy it to your working directory. + + 3.2 Run the PyTorch decoder sample: + + ```bash + python pytorch/decoder_sample.py <--fp16> <--time> + python pytorch/decoder_sample.py 8 6 32 8 64 --fp16 --time + ``` + Remove `--fp16` for fp32 mode. `--ths` will use the TorchScript custom class. 
+ + The outputs should be like to the following: + + ```bash + step: 30 Mean relative diff: 0.01395416259765625 Max relative diff: 1.38671875 Min relative diff: 0.0 + step: 31 Mean relative diff: 0.0148468017578125 Max relative diff: 2.880859375 Min relative diff: 0.0 + [INFO] ONMTDecoder time costs: 218.37 ms + [INFO] FTDecoder time costs: 25.15 ms + ``` + + Note that the relative diff is very large. It is caused by the random initial weights and inputs, and it does not affect the result of translation. + + 3.3 Run the PyTorch decoding sample: + + ```bash + python pytorch/decoding_sample.py <--fp16> <--time> + python pytorch/decoding_sample.py 8 6 32 8 64 4 31538 --fp16 --time + ``` + Remove `--fp16` for fp32 mode. `--ths` will use the TorchScript custom class. + + The outputs should be like to the following: + + ```bash + [INFO] TorchDecoding time costs: 289.08 ms + [INFO] TorchDecoding (with FTDecoder) time costs: 104.15 ms + [INFO] FTDecoding time costs: 30.57 ms + ``` + + Random initialized parameters may lead to different results. You can download the pretrained model following the instruction in the next part, and add `--use_pretrained`, then you can get the same results. -The following results are generated by + +#### Translation process + +1. Translation with FasterTransformer on TensorFlow + + This subsection demonstrates how to use FasterTansformer decoding to translate a sentence. We use the pretrained model and testing data in [OpenNMT-tf](https://opennmt.net/Models-tf/), which translates from English to German. + + Because the FasterTransformer Encoder is based on BERT, we cannot restore the model of encoder of OpenNMT to FasterTransformer Encoder. Therefore, we use OpenNMT-tf to build the encoder and preprocess the source sentence. + + Another problem is that the implementation of FasterTransformer Decoder and decoder of OpenNMT-tf is a little different. For example, the decoder of OpenNMT-tf uses one convolution to compute query, key and value in masked-multihead-attention; but FasterTransformer Decoder splits them into three gemms. One method is using the tool `utils/dump_model.py` to convert the pretrained model to fit the model structure of FasterTransformer Decoder. Another method is Splitting the weights during inference. + + `download_model_data.sh` will install the OpenNMT-tf v1, downloading the pretrained model into the `translation` folder, and convert the model. + + ```bash + bash tensorflow/utils/translation/download_model_data.sh + ``` + + Then run the translation sample by the following script: + + ```bash + ./bin/decoding_gemm 128 4 8 64 32001 100 512 0 + python tensorflow/translate_sample.py \ + --batch_size 128 \ + --beam_width 4 \ + --encoder_head_number 8 \ + --encoder_size_per_head 64 \ + --decoder_head_number 8 \ + --decoder_size_per_head 64 \ + --max_seq_len 32 \ + --encoder_num_layer 6 \ + --decoder_num_layer 6 \ + --data_type fp32 \ + --beam_search_diversity_rate 0.0 \ + --sampling_topk 1 \ + --sampling_topp 0.00 \ + --test_time 012345 + ``` + + The outputs of should be similar to the following: + + ```bash + [INFO] tf-decoding-beamsearch translates 24 batches taking 31.39 ms to translate 67092 tokens, BLEU score: 26.29, 2137 tokens/sec. + [INFO] op-decoder-beamsearch translates 24 batches taking 10.37 ms to translate 67092 tokens, BLEU score: 26.29, 6473 tokens/sec. + [INFO] op-decoding-beamsearch translates 24 batches taking 7.88 ms to translate 67124 tokens, BLEU score: 26.31, 8513 tokens/sec. 
+ [INFO] tf-decoding-sampling translates 24 batches taking 16.23 ms to translate 67813 tokens, BLEU score: 25.79, 4178 tokens/sec. + [INFO] op-decoder-sampling translates 24 batches taking 6.29 ms to translate 67813 tokens, BLEU score: 25.79, 10781 tokens/sec. + [INFO] op-decoding-sampling translates 24 batches taking 4.10 ms to translate 67813 tokens, BLEU score: 25.79, 16524 tokens/sec. + ``` + + The scripts of running under FP16 is following: + + ```bash + python tensorflow/tensorflow_bert/ckpt_type_convert.py --init_checkpoint=translation/ckpt/model.ckpt-500000 --fp16_checkpoint=translation/ckpt/fp16_model.ckpt-500000 + ./bin/decoding_gemm 128 4 8 64 32001 100 512 1 + python tensorflow/translate_sample.py \ + --batch_size 128 \ + --beam_width 4 \ + --encoder_head_number 8 \ + --encoder_size_per_head 64 \ + --decoder_head_number 8 \ + --decoder_size_per_head 64 \ + --max_seq_len 32 \ + --encoder_num_layer 6 \ + --decoder_num_layer 6 \ + --data_type fp16 \ + --beam_search_diversity_rate 0.0 \ + --sampling_topk 1 \ + --sampling_topp 0.00 \ + --test_time 012345 + ``` + + The outputs of should be similar to the following: + + ```bash + [INFO] tf-decoding-beamsearch translates 24 batches taking 22.75 ms to translate 67094 tokens, BLEU score: 26.31, 2949 tokens/sec. + [INFO] op-decoder-beamsearch translates 24 batches taking 7.73 ms to translate 67089 tokens, BLEU score: 26.30, 8682 tokens/sec. + [INFO] op-decoding-beamsearch translates 24 batches taking 5.27 ms to translate 67130 tokens, BLEU score: 26.33, 12746 tokens/sec. + [INFO] tf-decoding-sampling translates 24 batches taking 13.65 ms to translate 67828 tokens, BLEU score: 25.83, 4968 tokens/sec. + [INFO] op-decoder-sampling translates 24 batches taking 4.92 ms to translate 67831 tokens, BLEU score: 25.80, 13773 tokens/sec. + [INFO] op-decoding-sampling translates 24 batches taking 2.54 ms to translate 67844 tokens, BLEU score: 25.82, 26718 tokens/sec. + ``` + +2. Translation with FasterTransformer on PyTorch + + We have a translation demo for En-De translation. + + You need to download the pretrained_model first by: + + ```bash + bash pytorch/scripts/download_translation_model.sh + ``` + + Then you can run the demo by: + + ```bash + python pytorch/run_translation.py --batch_size --beam_size --model_type --data_type --output_file + ``` + you can also use `--module_path` to set the FasterTransformer module `.so` file path, and use `--input_file` to set the input file to be translated. + + the `` can be: + - `ori`: original OpenNMT model + - `decoder_ext`: replace the decoder in OpenNMT model with our FasterTransformer decoder + - `decoding_ext`: using our FasterTransformer decoding module + - `torch_decoding`: PyTorch version decoding with the method FasterTransformer decoding uses + - `torch_decoding_with_decoder_ext`: PyTorch version decoding with the method FasterTransformer decoding uses but replace the decoder with the FasterTransformer decoder + + the `` can be `fp32` or `fp16` + + if you do not specify the output file, it only print to the stdout. + + If you want to evaluate the BLEU score, please recover the BPE first by: + ```bash + python pytorch/utils/recover_bpe.py + python pytorch/utils/recover_bpe.py + ``` + the `` for our demo is `pytorch/translation/data/test.de`, the `` is the output from `run_translation.py`. 
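+
+ Recovering the BPE simply merges the subword pieces back into whole words before scoring. A minimal sketch of that idea, assuming the common `@@ ` joiner convention (the bundled `recover_bpe.py` may differ in details):
+
+ ```python
+ import sys
+
+ def recover_bpe(line):
+     # "auf@@ merk@@ sam" -> "aufmerksam": drop the joiner and glue the pieces.
+     return line.replace("@@ ", "").replace("@@", "")
+
+ if __name__ == "__main__":
+     src, dst = sys.argv[1], sys.argv[2]
+     with open(src, encoding="utf-8") as fin, open(dst, "w", encoding="utf-8") as fout:
+         for line in fin:
+             fout.write(recover_bpe(line))
+ ```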
+ + Then you can evalute the BLEU score, for example, through `sacrebleu`: + ```bash + pip install sacrebleu + cat | sacrebleu + ``` + + The following scripts run translation under FP32 and get the bleu score: + + ```bash + ./bin/decoding_gemm 128 4 8 64 31538 100 512 0 + python pytorch/run_translation.py --batch_size 128 --beam_size 4 --model_type decoding_ext --data_type fp32 --output_file output.txt + python pytorch/utils/recover_bpe.py pytorch/translation/data/test.de debpe_ref.txt + python pytorch/utils/recover_bpe.py output.txt debpe_output.txt + pip install sacrebleu + cat debpe_output.txt | sacrebleu debpe_ref.txt + ``` + +## Performance + +Hardware settings: +* CPU: Intel(R) Xeon(R) Gold 6132 CPU @ 2.60GHz +* T4 (with mclk 5000MHz, pclk 1590MHz) with Intel(R) Xeon(R) CPU E5-2603 v4 @ 1.70GHz +* V100 (with mclk 877MHz, pclk 1380MHz) with Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz (dgx-1 server) + +In order to run the following benchmark, we need to install the unix computing tool "bc" by ```bash -bash scripts/profile_decoding_op_performance.sh +apt-get install bc ``` -* We set beam_width = 4 +### Encoder performance + +We demonstrate the inference time of FasterTransformer in C++, TensorFlow and PyTorch, and compare to the performance of pure TensorFlow and PyTorch on T4 and V100. + +For the benchmark of TensorFlow, we compare the performance of TensorFlow with XLA (TF), the performance of TensorFlow with FasterTransformer OP (FT-OP) and the performance of FasterTransformer on C++ (TF-CPP), and show the speedup of FT-OP and FT-CPP compare to the TensorFlow. + +For the benchmark of PyTorch, we compare the performance of PyTorch, and performance of TorchScript and the performance of PyTorch with FasterTransformer custom extension (CustomExt), and show the speedup of CustomExt compare to the PyTorch and TorchScript. Because CustomExt has no obvious overhead compare to the FasterTransformer on C++, we skip the comparison with the C++ implementation. -| | TensorFlow FP32 (in ms) | Decoding FP32 (in ms) | FP32 Speedup | TensorFlow FP16 (in ms) | Decoding FP16 (in ms) | FP16 Speedup | -|:------------:|:-------:|:-------:|:----:|:-------:|:------:|:-----:| -| (1, 32) | 430.39 | 64.16 | 6.70 | 537.95 | 49.07 | 10.96 | -| (1, 64) | 876.24 | 135.42 | 6.47 | 1056.78 | 97.45 | 10.84 | -| (1, 128) | 1799.16 | 318.65 | 5.64 | 2145.74 | 240.85 | 8.91 | -| (32, 32) | 597.42 | 217.61 | 2.74 | 646.07 | 128.39 | 5.03 | -| (64, 32) | 789.22 | 395.85 | 1.99 | 769.17 | 246.89 | 3.11 | -| (128, 32) | 1223.72 | 726.43 | 1.68 | 996.03 | 424.53 | 2.34 | -| (256, 32) | 2188.00 | 1385.60 | 1.58 | 1599.58 | 781.38 | 2.04 | +The results of c++ and TensorFlow were obtained by running the `sample/tensorflow/scripts/profile_encoder_performance.sh`. -### Decoding performance on V100 +The results of PyTorch were obtained by running the `sample/pytorch/scripts/profile_encoder.sh`. 
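+
+The numbers reported below are end-to-end inference times in milliseconds. A typical way to collect such numbers is to run a few warm-up iterations and then average over many timed iterations; the sketch below shows the general recipe (it is not the exact contents of the profiling scripts):
+
+```python
+import time
+import torch
+
+def benchmark(run_once, warmup=10, iters=100):
+    """Average wall-clock time (ms) of `run_once`, a no-argument callable."""
+    for _ in range(warmup):          # warm-up: exclude one-time setup costs
+        run_once()
+    torch.cuda.synchronize()         # wait for queued GPU kernels to finish
+    start = time.time()
+    for _ in range(iters):
+        run_once()
+    torch.cuda.synchronize()
+    return (time.time() - start) * 1000.0 / iters
+```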
+ +In the experiments of encoder, we updated the following parameters: + +* head_num = 12 +* size_per_head = 64 +* num_layers = 12 + +#### Encoder performance on T4 and TensorFlow + +* Performance on FP32 + +| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | +|:---------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| +| <1, 32> | 6.39 | 4.43 | 1.44 | 2.54 | 2.51 | +| <1, 64> | 6.41 | 4.84 | 1.32 | 3.60 | 1.78 | +| <1, 128> | 8.97 | 7.66 | 1.17 | 6.34 | 1.41 | +| <8, 32> | 14.52 | 13.28 | 1.09 | 11.71 | 1.23 | +| <8, 64> | 24.88 | 24.43 | 1.01 | 23.03 | 1.08 | +| <8, 128> | 50.66 | 49.55 | 1.02 | 47.04 | 1.07 | +| <32, 32> | 47.91 | 48.04 | .99 | 46.04 | 1.04 | +| <32, 64> | 103.95 | 95.93 | 1.08 | 92.31 | 1.12 | +| <32, 128> | 201.42 | 184.32 | 1.09 | 176.14 | 1.14 | +| <64, 32> | 97.49 | 96.23 | 1.01 | 93.57 | 1.04 | +| <64, 64> | 187.60 | 180.49 | 1.03 | 173.42 | 1.08 | +| <64, 128> | 392.96 | 363.74 | 1.08 | 345.40 | 1.13 | +| <128, 32> | 208.60 | 178.55 | 1.16 | 171.43 | 1.21 | +| <128, 64> | 400.00 | 353.95 | 1.13 | 338.34 | 1.18 | +| <128, 128> | 844.07 | 729.22 | 1.15 | 692.58 | 1.21 | + +* Performance on FP16 + +| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | +|:---------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| +| <1, 32> | 6.53 | 4.19 | 1.55 | 1.72 | 3.79 | +| <1, 64> | 6.93 | 4.96 | 1.39 | 1.86 | 3.72 | +| <1, 128> | 6.32 | 4.12 | 1.53 | 2.12 | 2.98 | +| <8, 32> | 6.89 | 4.58 | 1.50 | 2.93 | 2.35 | +| <8, 64> | 8.33 | 6.43 | 1.29 | 4.80 | 1.73 | +| <8, 128> | 15.33 | 11.46 | 1.33 | 9.40 | 1.63 | +| <32, 32> | 14.64 | 11.45 | 1.27 | 9.20 | 1.59 | +| <32, 64> | 26.50 | 21.03 | 1.26 | 18.56 | 1.42 | +| <32, 128> | 54.28 | 41.44 | 1.30 | 38.23 | 1.41 | +| <64, 32> | 26.53 | 20.99 | 1.26 | 18.84 | 1.40 | +| <64, 64> | 49.99 | 40.41 | 1.23 | 36.99 | 1.35 | +| <64, 128> | 101.39 | 83.46 | 1.21 | 77.41 | 1.30 | +| <128, 32> | 51.67 | 40.58 | 1.27 | 37.39 | 1.38 | +| <128, 64> | 98.07 | 80.91 | 1.21 | 72.67 | 1.34 | +| <128, 128> | 202.76 | 166.32 | 1.21 | 153.19 | 1.32 | + +#### Encoder performance on V100 and TensorFlow + +* Performance on FP32 + +| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | +|:---------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| +| <1, 32> | 3.78 | 2.99 | 1.26 | 1.76 | 2.14 | +| <1, 64> | 4.55 | 3.29 | 1.38 | 2.16 | 2.10 | +| <1, 128> | 5.23 | 4.15 | 1.26 | 2.94 | 1.77 | +| <8, 32> | 7.42 | 6.14 | 1.20 | 4.66 | 1.59 | +| <8, 64> | 10.80 | 9.98 | 1.08 | 8.48 | 1.27 | +| <8, 128> | 18.73 | 17.63 | 1.06 | 15.50 | 1.20 | +| <32, 32> | 18.16 | 16.97 | 1.07 | 15.34 | 1.18 | +| <32, 64> | 33.87 | 32.69 | 1.03 | 30.01 | 1.12 | +| <32, 128> | 66.11 | 64.31 | 1.02 | 59.46 | 1.11 | +| <64, 32> | 34.17 | 32.56 | 1.04 | 29.91 | 1.14 | +| <64, 64> | 66.21 | 63.51 | 1.04 | 58.84 | 1.12 | +| <64, 128> | 133.61 | 126.58 | 1.05 | 119.08 | 1.12 | +| <128, 32> | 65.36 | 62.72 | 1.04 | 58.22 | 1.12 | +| <128, 64> | 131.12 | 123.94 | 1.05 | 117.80 | 1.11 | +| <128, 128> | 253.90 | 251.03 | 1.01 | 234.30 | 1.08 | + +* Performance on FP16 + +| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | +|:---------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| +| <1, 32> | 3.44 | 3.05 | 1.12 | 1.24 | 2.77 | +| <1, 64> | 4.96 | 2.88 | 1.72 | 1.45 | 3.42 | +| <1, 128> | 3.59 | 2.79 | 1.28 | 1.57 | 2.28 | +| <8, 32> | 3.94 | 3.00 | 1.31 | 1.80 | 2.18 
| +| <8, 64> | 5.12 | 3.86 | 1.32 | 2.45 | 2.08 | +| <8, 128> | 7.16 | 5.21 | 1.37 | 3.79 | 1.88 | +| <32, 32> | 7.27 | 5.25 | 1.38 | 3.60 | 2.01 | +| <32, 64> | 11.26 | 8.47 | 1.32 | 6.61 | 1.70 | +| <32, 128> | 20.62 | 15.52 | 1.32 | 12.52 | 1.64 | +| <64, 32> | 11.31 | 8.57 | 1.31 | 6.59 | 1.71 | +| <64, 64> | 19.94 | 15.63 | 1.27 | 12.22 | 1.63 | +| <64, 128> | 36.25 | 28.86 | 1.25 | 23.73 | 1.52 | +| <128, 32> | 20.15 | 15.27 | 1.31 | 12.24 | 1.64 | +| <128, 64> | 35.67 | 28.73 | 1.24 | 23.40 | 1.52 | +| <128, 128> | 68.84 | 54.53 | 1.26 | 46.11 | 1.49 | + + + +#### Effective Transformer performance on V100 and TensorFlow + +In this benchmark, we compare the performance of TensorFlow with XLA (TF), the performance of TensorFlow with FasterTransformer OP (FT-OP) and the performance of TensorFlow with FasterTransformer OP without padding (Effective FT), and show the speedup of Effecitve FT compare to the TF and FT-OP. + +The results of c++ and TensorFlow were obtained by running the `sample/tensorflow/scripts/profile_effective_transformer_performance.sh`. + +In the experiments of encoder, we updated the following parameters: + +* head_num = 12 +* size_per_head = 64 +* num_layers = 12 + +* Performance on FP32 + +| | TF (ms) | FT-OP (ms) | Effective FT (ms) | TF Speedup (ms) | FT-OP Speedup | +|:---------------------:|:-------:|:----------:|:-----------------:|:---------------:|:-------------:| +| <1, 32, 16> | 3.94 | 2.83 | 2.80 | 1.40 | 1.01 | +| <1, 64, 32> | 4.13 | 3.23 | 2.86 | 1.44 | 1.12 | +| <1, 128, 64> | 5.31 | 4.08 | 3.57 | 1.48 | 1.14 | +| <8, 32, 16> | 6.99 | 5.95 | 4.34 | 1.61 | 1.37 | +| <8, 64, 32> | 10.77 | 9.92 | 6.50 | 1.65 | 1.52 | +| <8, 128, 64> | 18.55 | 17.45 | 11.01 | 1.68 | 1.58 | +| <32, 32, 16> | 18.31 | 17.16 | 10.76 | 1.70 | 1.59 | +| <32, 64, 32> | 34.51 | 32.97 | 19.61 | 1.75 | 1.68 | +| <32, 128, 64> | 66.97 | 65.11 | 36.94 | 1.81 | 1.76 | +| <64, 32, 16> | 34.64 | 32.84 | 19.47 | 1.77 | 1.68 | +| <64, 64, 32> | 66.38 | 64.17 | 36.26 | 1.83 | 1.76 | +| <64, 128, 64> | 131.90 | 128.20 | 71.79 | 1.83 | 1.78 | +| <128, 32, 16> | 66.98 | 63.54 | 35.62 | 1.88 | 1.78 | +| <128, 64, 32> | 129.40 | 126.09 | 69.98 | 1.84 | 1.80 | +| <128, 128, 64> | 258.44 | 254.00 | 139.94 | 1.84 | 1.81 | + +* Performance on FP16 + +| | TF (ms) | FT-OP (ms) | Effective FT (ms) | TF Speedup (ms) | FT-OP Speedup | +|:---------------------:|:-------:|:----------:|:-----------------:|:---------------:|:-------------:| +| <1, 32, 16> | 3.49 | 2.74 | 2.64 | 1.32 | 1.03 | +| <1, 64, 32> | 3.27 | 2.63 | 2.77 | 1.18 | .94 | +| <1, 128, 64> | 3.49 | 2.69 | 2.74 | 1.27 | .98 | +| <8, 32, 16> | 3.87 | 2.93 | 2.83 | 1.36 | 1.03 | +| <8, 64, 32> | 5.04 | 3.77 | 3.42 | 1.47 | 1.10 | +| <8, 128, 64> | 7.11 | 5.23 | 4.44 | 1.60 | 1.17 | +| <32, 32, 16> | 7.00 | 5.08 | 4.37 | 1.60 | 1.16 | +| <32, 64, 32> | 10.99 | 8.58 | 6.03 | 1.82 | 1.42 | +| <32, 128, 64> | 19.89 | 15.42 | 10.71 | 1.85 | 1.43 | +| <64, 32, 16> | 11.06 | 8.56 | 5.98 | 1.84 | 1.43 | +| <64, 64, 32> | 19.81 | 15.18 | 10.42 | 1.90 | 1.45 | +| <64, 128, 64> | 36.47 | 28.76 | 19.21 | 1.89 | 1.49 | +| <128, 32, 16> | 19.67 | 15.08 | 10.37 | 1.89 | 1.45 | +| <128, 64, 32> | 35.34 | 27.93 | 18.58 | 1.90 | 1.50 | +| <128, 128, 64> | 69.08 | 54.86 | 36.76 | 1.87 | 1.49 | + +#### Encoder performance on T4 and PyTorch + +* Performance on FP32 + +| | PyTorch (ms) | TorchScript (ms) | CustomExt (ms) | Speedup (w/ PyTorch) | Speedup (w/ TorchScript) | +|:---------------------:|:------:|:------:|:------:|:--------:|:--------:| +| <1, 32> | 
16.15 | 12.48 | 2.60 | 6.21 | 4.80 | +| <1, 64> | 20.15 | 12.51 | 3.64 | 5.53 | 3.43 | +| <1, 128> | 16.50 | 9.24 | 6.38 | 2.58 | 1.44 | +| <8, 32> | 16.60 | 14.99 | 11.71 | 1.41 | 1.28 | +| <8, 64> | 26.21 | 26.17 | 22.58 | 1.16 | 1.15 | +| <8, 128> | 52.66 | 52.29 | 43.92 | 1.19 | 1.19 | +| <32, 32> | 51.69 | 51.55 | 42.72 | 1.20 | 1.20 | +| <32, 64> | 103.17 | 102.94 | 88.71 | 1.16 | 1.16 | +| <32, 128> | 194.06 | 192.19 | 169.70 | 1.14 | 1.13 | +| <64, 32> | 103.70 | 103.35 | 88.32 | 1.17 | 1.17 | +| <64, 64> | 188.35 | 187.46 | 166.31 | 1.13 | 1.12 | +| <64, 128> | 387.92 | 384.27 | 334.53 | 1.15 | 1.14 | +| <128, 32> | 188.86 | 188.17 | 164.80 | 1.14 | 1.14 | +| <128, 64> | 376.21 | 374.40 | 326.30 | 1.15 | 1.14 | +| <128, 128> | 866.38 | 862.86 | 669.55 | 1.29 | 1.28 | + +* Performance on FP16 + +| | PyTorch (ms) | TorchScript (ms) | CustomExt (ms) | Speedup (w/ PyTorch) | Speedup (w/ TorchScript) | +|:---------------------:|:------:|:------:|:------:|:--------:|:--------:| +| <1, 32> | 20.40 | 9.75 | 2.46 | 8.29 | 3.96 | +| <1, 64> | 16.55 | 9.70 | 2.06 | 8.03 | 4.70 | +| <1, 128> | 16.29 | 12.39 | 2.36 | 6.90 | 5.25 | +| <8, 32> | 20.43 | 9.37 | 2.97 | 6.87 | 3.15 | +| <8, 64> | 15.47 | 8.58 | 4.84 | 3.19 | 1.77 | +| <8, 128> | 20.60 | 13.80 | 9.34 | 2.20 | 1.47 | +| <32, 32> | 16.63 | 12.91 | 9.07 | 1.83 | 1.42 | +| <32, 64> | 25.61 | 25.31 | 18.24 | 1.40 | 1.38 | +| <32, 128> | 54.19 | 53.28 | 36.21 | 1.49 | 1.47 | +| <64, 32> | 25.31 | 25.11 | 18.32 | 1.38 | 1.37 | +| <64, 64> | 50.91 | 50.38 | 34.88 | 1.45 | 1.44 | +| <64, 128> | 105.75 | 104.10 | 70.88 | 1.49 | 1.46 | +| <128, 32> | 50.64 | 50.21 | 35.21 | 1.43 | 1.42 | +| <128, 64> | 99.19 | 98.18 | 68.13 | 1.45 | 1.44 | +| <128, 128> | 218.95 | 215.79 | 142.66 | 1.53 | 1.51 | + +#### Encoder performance on V100 and PyTorch + +* Performance on FP32 + +| | PyTorch (ms) | TorchScript (ms) | CustomExt (ms) | Speedup (w/ PyTorch) | Speedup (w/ TorchScript) | +|:---------------------:|:------:|:------:|:------:|:--------:|:--------:| +| <1, 32> | 12.25 | 6.39 | 1.80 | 6.80 | 3.55 | +| <1, 64> | 11.59 | 8.63 | 2.20 | 5.26 | 3.92 | +| <1, 128> | 17.26 | 6.76 | 3.03 | 5.69 | 2.23 | +| <8, 32> | 11.57 | 6.71 | 4.74 | 2.44 | 1.41 | +| <8, 64> | 12.03 | 9.52 | 8.34 | 1.44 | 1.14 | +| <8, 128> | 18.60 | 18.80 | 15.34 | 1.21 | 1.22 | +| <32, 32> | 18.10 | 18.24 | 15.08 | 1.20 | 1.20 | +| <32, 64> | 34.33 | 34.39 | 29.60 | 1.15 | 1.16 | +| <32, 128> | 66.40 | 65.60 | 58.64 | 1.13 | 1.11 | +| <64, 32> | 34.86 | 34.24 | 29.60 | 1.17 | 1.15 | +| <64, 64> | 63.58 | 63.26 | 58.85 | 1.08 | 1.07 | +| <64, 128> | 130.69 | 130.51 | 117.66 | 1.11 | 1.10 | +| <128, 32> | 63.65 | 63.47 | 57.86 | 1.10 | 1.09 | +| <128, 64> | 126.79 | 126.92 | 115.19 | 1.10 | 1.10 | +| <128, 128> | 257.29 | 254.07 | 230.81 | 1.11 | 1.10 | + +* Performance on FP16 + +| | PyTorch (ms) | TorchScript (ms) | CustomExt (ms) | Speedup (w/ PyTorch) | Speedup (w/ TorchScript) | +|:---------------------:|:------:|:------:|:------:|:--------:|:--------:| +| <1, 32> | 12.30 | 8.50 | 1.69 | 7.27 | 5.02 | +| <1, 64> | 12.33 | 8.66 | 1.71 | 7.21 | 5.06 | +| <1, 128> | 14.29 | 6.74 | 1.91 | 7.48 | 3.52 | +| <8, 32> | 11.86 | 7.72 | 1.84 | 6.44 | 4.19 | +| <8, 64> | 12.76 | 6.74 | 2.51 | 5.08 | 2.68 | +| <8, 128> | 11.61 | 6.67 | 3.73 | 3.11 | 1.78 | +| <32, 32> | 12.00 | 6.19 | 3.70 | 3.24 | 1.67 | +| <32, 64> | 12.27 | 9.36 | 6.78 | 1.80 | 1.38 | +| <32, 128> | 18.61 | 18.41 | 12.63 | 1.47 | 1.45 | +| <64, 32> | 12.01 | 9.20 | 6.63 | 1.81 | 1.38 | +| <64, 64> | 17.72 | 17.35 | 12.36 
| 1.43 | 1.40 |
+| <64, 128> | 35.18 | 34.14 | 23.90 | 1.47 | 1.42 |
+| <128, 32> | 17.35 | 17.09 | 12.32 | 1.40 | 1.38 |
+| <128, 64> | 33.05 | 33.28 | 23.44 | 1.40 | 1.41 |
+| <128, 128> | 67.42 | 66.03 | 46.83 | 1.43 | 1.40 |
+
+#### Performance on application codes of TensorFlow
+
+* [BERT-base-SQuAD-1.1 model](https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_v1_1_base_fp32_128/versions/2/zip), batch size 8, seq len 384, on V100
+
+| Type | Exact match | F1 score | inference time (ms/example) |
+|:----:|:-----------:|:--------:|:---------------------------:|
+| TensorFlow FP32 | 78.13% | 85.84% | 25.53 |
+| FasterTransformer OP FP32 | 78.13% | 85.84% | 18.30 |
+| TensorFlow FP16 | x | x | 12.21 |
+| FasterTransformer OP FP16 | 78.03% | 85.77% | 5.6 |
+
+#### Performance on application codes of PyTorch
+
+* BERT-large-SQuAD-1.1, dev set: batch size 8, seq len 384, on T4 (not TorchScript)
+
+| Type | Exact match | F1 score | inference time (ms/example) |
+|:----:|:-----------:|:--------:|:---------------------------:|
+| PyTorch FP32 | 86.92% | 93.15% | 78.92 |
+| FasterTransformer OP FP32 | 86.93% | 93.17% | 66.68 |
+| PyTorch FP16 | 86.92% | 93.16% | 22.36 |
+| FasterTransformer OP FP16 | 86.98% | 93.17% | 15.48 |
+
+* BERT-base-MRPC, dev set: batch size 8, seq len 128, on T4 (not TorchScript)
+
+| Type | Accuracy | F1 score | inference time (ms/example) |
+|:----:|:--------:|:--------:|:---------------------------:|
+| PyTorch FP32 | 82.84% | 88.29% | 8.16 |
+| FasterTransformer OP FP32 | 82.84% | 88.29% | 5.82 |
+| PyTorch FP16 | 82.84% | 88.29% | 2.62 |
+| FasterTransformer OP FP16 | 82.84% | 88.29% | 1.27 |
+
+### Decoder performance
+
+We demonstrate the inference time of FasterTransformer in C++ and TensorFlow, and compare it to the performance of pure TensorFlow on T4 and V100. The PyTorch results are presented in the "Decoding performance" subsection.
+
+In this benchmark, we compare the performance of TensorFlow decoding with the beam search method (TF) against the performance of replacing the TensorFlow decoder with FasterTransformer (FT-OP), and show the speedup of FT-OP compared to TF. We do not demonstrate the performance of TensorFlow with XLA since we did not find that using XLA gives an obvious speedup.
-The following results are generated by +Our results of c++ and TensorFlow were obtained by running the `sample/tensorflow/scripts/profile_decoder_performance.sh` -```bash -bash scripts/profile_decoding_op_performance.sh -``` +In the experiments of decoding, we updated the following parameters: -* We set beam_width = 4 +* head_num = 8 +* size_per_head = 64 +* num_layers = 6 +* vocabulary_size = 30000 for TensorFlow sample codes, 31538 for PyTorch sample codes +* memory_hidden_dim = 512 + +#### Decoder performance on T4 and TensorFlow + +* Performance on FP32 + +| | TF (ms) | FT-OP (ms) | FT-OP Speedup | +|:---------------------------------:|:-------:|:----------:|:-------------:| +| <1, 1, 32> | 509.16 | 107.98 | 4.71 | +| <1, 1, 64> | 951.49 | 223.69 | 4.25 | +| <1, 1, 128> | 1943.97 | 425.28 | 4.57 | +| <1, 4, 32> | 497.88 | 126.70 | 3.92 | +| <1, 4, 64> | 1050.92 | 243.64 | 4.31 | +| <1, 4, 128> | 2064.92 | 508.16 | 4.06 | +| <8, 1, 32> | 510.90 | 125.96 | 4.05 | +| <8, 1, 64> | 995.81 | 244.18 | 4.07 | +| <8, 1, 128> | 2041.21 | 479.02 | 4.26 | +| <8, 4, 32> | 539.70 | 129.21 | 4.17 | +| <8, 4, 64> | 1100.77 | 267.75 | 4.11 | +| <8, 4, 128> | 2100.58 | 558.91 | 3.75 | +| <32, 1, 32> | 575.80 | 123.16 | 4.67 | +| <32, 1, 64> | 1070.51 | 251.52 | 4.25 | +| <32, 1, 128> | 2172.67 | 554.32 | 3.91 | +| <32, 4, 32> | 673.70 | 204.51 | 3.29 | +| <32, 4, 64> | 1335.84 | 492.47 | 2.71 | +| <32, 4, 128> | 3136.18 | 1331.35 | 2.35 | +| <64, 1, 32> | 582.22 | 142.49 | 4.08 | +| <64, 1, 64> | 1243.74 | 312.54 | 3.97 | +| <64, 1, 128> | 2420.20 | 791.30 | 3.05 | +| <64, 4, 32> | 850.54 | 350.63 | 2.42 | +| <64, 4, 64> | 1833.49 | 874.46 | 2.09 | +| <64, 4, 128> | 4586.01 | 2450.19 | 1.87 | +| <128, 1, 32> | 656.85 | 208.91 | 3.14 | +| <128, 1, 64> | 1461.70 | 499.76 | 2.92 | +| <128, 1, 128> | 3209.60 | 1361.95 | 2.35 | +| <128, 4, 32> | 1260.55 | 656.29 | 1.92 | +| <128, 4, 64> | 2875.73 | 1663.91 | 1.72 | +| <128, 4, 128> | 8018.63 | 4718.32 | 1.69 | + +* Performance on FP16 + +| | TF (ms) | FT-OP (ms) | FT-OP Speedup | +|:---------------------------------:|:-------:|:----------:|:-------------:| +| <1, 1, 32> | 400.02 | 121.19 | 3.30 | +| <1, 1, 64> | 823.41 | 233.93 | 3.51 | +| <1, 1, 128> | 1616.38 | 422.73 | 3.82 | +| <1, 4, 32> | 476.33 | 128.45 | 3.70 | +| <1, 4, 64> | 868.67 | 261.18 | 3.32 | +| <1, 4, 128> | 1857.95 | 464.51 | 3.99 | +| <8, 1, 32> | 452.70 | 119.73 | 3.78 | +| <8, 1, 64> | 906.15 | 222.74 | 4.06 | +| <8, 1, 128> | 1789.19 | 428.80 | 4.17 | +| <8, 4, 32> | 484.09 | 127.14 | 3.80 | +| <8, 4, 64> | 973.28 | 252.81 | 3.84 | +| <8, 4, 128> | 1907.93 | 527.98 | 3.61 | +| <32, 1, 32> | 476.66 | 124.72 | 3.82 | +| <32, 1, 64> | 933.16 | 240.70 | 3.87 | +| <32, 1, 128> | 1953.02 | 518.10 | 3.76 | +| <32, 4, 32> | 607.62 | 159.24 | 3.81 | +| <32, 4, 64> | 1280.93 | 352.51 | 3.63 | +| <32, 4, 128> | 2511.20 | 882.21 | 2.84 | +| <64, 1, 32> | 501.07 | 135.40 | 3.70 | +| <64, 1, 64> | 1020.40 | 281.34 | 3.62 | +| <64, 1, 128> | 2243.14 | 627.33 | 3.57 | +| <64, 4, 32> | 692.42 | 213.80 | 3.23 | +| <64, 4, 64> | 1517.27 | 542.75 | 2.79 | +| <64, 4, 128> | 3351.21 | 1554.97 | 2.15 | +| <128, 1, 32> | 593.39 | 163.73 | 3.62 | +| <128, 1, 64> | 1258.93 | 358.26 | 3.51 | +| <128, 1, 128> | 2672.11 | 910.34 | 2.93 | +| <128, 4, 32> | 989.35 | 364.63 | 2.71 | +| <128, 4, 64> | 2216.00 | 962.84 | 2.30 | +| <128, 4, 128> | 5515.29 | 2913.02 | 1.89 | + +#### Decoder performance on V100 and TensorFlow + +* Performance on FP32 + +| | TF (ms) | FT-OP (ms) | FT-OP Speedup | 
+|:---------------------------------:|:-------:|:----------:|:-------------:| +| <1, 1, 32> | 239.38 | 68.88 | 3.47 | +| <1, 1, 64> | 500.20 | 133.88 | 3.73 | +| <1, 1, 128> | 1021.87 | 261.55 | 3.90 | +| <1, 4, 32> | 242.70 | 74.93 | 3.23 | +| <1, 4, 64> | 509.43 | 145.60 | 3.49 | +| <1, 4, 128> | 893.73 | 296.82 | 3.01 | +| <8, 1, 32> | 241.06 | 68.85 | 3.50 | +| <8, 1, 64> | 494.16 | 145.88 | 3.38 | +| <8, 1, 128> | 1028.89 | 285.51 | 3.60 | +| <8, 4, 32> | 274.33 | 73.38 | 3.73 | +| <8, 4, 64> | 534.15 | 152.04 | 3.51 | +| <8, 4, 128> | 1090.66 | 321.77 | 3.38 | +| <32, 1, 32> | 249.78 | 71.74 | 3.48 | +| <32, 1, 64> | 527.18 | 150.84 | 3.49 | +| <32, 1, 128> | 1053.79 | 313.93 | 3.35 | +| <32, 4, 32> | 313.01 | 114.31 | 2.73 | +| <32, 4, 64> | 666.00 | 252.23 | 2.64 | +| <32, 4, 128> | 1376.10 | 593.28 | 2.31 | +| <64, 1, 32> | 288.73 | 86.66 | 3.33 | +| <64, 1, 64> | 553.34 | 177.65 | 3.11 | +| <64, 1, 128> | 1125.72 | 404.00 | 2.78 | +| <64, 4, 32> | 377.06 | 156.55 | 2.40 | +| <64, 4, 64> | 806.34 | 373.36 | 2.15 | +| <64, 4, 128> | 1913.47 | 974.17 | 1.96 | +| <128, 1, 32> | 319.11 | 110.49 | 2.88 | +| <128, 1, 64> | 666.36 | 243.54 | 2.73 | +| <128, 1, 128> | 1426.32 | 591.99 | 2.40 | +| <128, 4, 32> | 528.52 | 256.18 | 2.06 | +| <128, 4, 64> | 1215.82 | 620.55 | 1.95 | +| <128, 4, 128> | 3167.89 | 1733.38 | 1.82 | + +* Performance on FP16 + +| | TF (ms) | FT-OP (ms) | FT-OP Speedup | +|:---------------------------------:|:-------:|:----------:|:-------------:| +| <1, 1, 32> | 209.70 | 70.37 | 2.97 | +| <1, 1, 64> | 423.41 | 141.34 | 2.99 | +| <1, 1, 128> | 775.10 | 287.64 | 2.69 | +| <1, 4, 32> | 215.05 | 81.37 | 2.64 | +| <1, 4, 64> | 449.72 | 146.28 | 3.07 | +| <1, 4, 128> | 910.03 | 291.50 | 3.12 | +| <8, 1, 32> | 226.01 | 68.60 | 3.29 | +| <8, 1, 64> | 437.30 | 153.32 | 2.85 | +| <8, 1, 128> | 915.96 | 286.39 | 3.19 | +| <8, 4, 32> | 248.44 | 75.81 | 3.27 | +| <8, 4, 64> | 463.51 | 154.71 | 2.99 | +| <8, 4, 128> | 960.88 | 293.46 | 3.27 | +| <32, 1, 32> | 233.93 | 69.80 | 3.35 | +| <32, 1, 64> | 482.73 | 147.54 | 3.27 | +| <32, 1, 128> | 922.02 | 294.40 | 3.13 | +| <32, 4, 32> | 279.34 | 88.29 | 3.16 | +| <32, 4, 64> | 582.95 | 193.42 | 3.01 | +| <32, 4, 128> | 1198.26 | 454.66 | 2.63 | +| <64, 1, 32> | 245.73 | 76.29 | 3.22 | +| <64, 1, 64> | 463.44 | 158.65 | 2.92 | +| <64, 1, 128> | 1007.24 | 332.69 | 3.02 | +| <64, 4, 32> | 331.58 | 114.84 | 2.88 | +| <64, 4, 64> | 699.38 | 262.69 | 2.66 | +| <64, 4, 128> | 1618.15 | 695.07 | 2.32 | +| <128, 1, 32> | 270.86 | 82.38 | 3.28 | +| <128, 1, 64> | 537.55 | 181.03 | 2.96 | +| <128, 1, 128> | 1183.11 | 442.73 | 2.67 | +| <128, 4, 32> | 433.38 | 165.23 | 2.62 | +| <128, 4, 64> | 928.87 | 410.96 | 2.26 | +| <128, 4, 128> | 2297.10 | 1175.40 | 1.95 | + + + +### Decoding performance + +We demonstrate the inference time of FasterTransformer in C++, TensorFlow and PyTorch, and compare to the performance of pure TensorFlow and PyTorch on T4 and V100. + +For the benchmark of TensorFlow, we compare the performance of TensorFlow (TF), the performance of TensorFlow with FasterTransformer OP (FT-OP) and the performance of FasterTransformer on C++ (TF-CPP), and show the speedup of FT-OP and FT-CPP compare to the TensorFlow. 
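+
+As a minimal sketch, the decoding benchmarks in this section can be rerun with the profiling scripts referenced below; the working directory and the prerequisite builds of the TensorFlow op and the PyTorch extension are assumptions here:
+
+```bash
+# Assumption: run from the repository root after building the corresponding ops.
+bash sample/tensorflow/scripts/profile_decoding_performance.sh   # TF / FT-OP / FT-CPP results
+bash sample/pytorch/scripts/profile_decoder_decoding.sh          # PyTorch / Decoder / Decoding results
+```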
-| | TensorFlow FP32 (in ms) | Decoding FP32 (in ms) | FP32 Speedup | TensorFlow FP16 (in ms) | Decoding FP16 (in ms) | FP16 Speedup | -|:------------:|:-------:|:------:|:----:|:-------:|:------:|:-----:| -| (1, 32) | 440.46 | 58.70 | 7.50 | 531.70 | 46.18 | 11.51 | -| (1, 64) | 888.19 | 122.50 | 7.25 | 1065.76 | 93.84 | 11.35 | -| (1, 128) | 1821.76 | 293.21 | 6.21 | 2076.63 | 293.21 | 7.08 | -| (32, 32) | 543.27 | 101.35 | 5.36 | 630.55 | 73.37 | 8.59 | -| (64, 32) | 648.27 | 157.54 | 4.11 | 793.83 | 106.77 | 7.43 | -| (128, 32) | 838.43 | 277.77 | 3.02 | 867.71 | 169.04 | 5.13 | -| (256, 32) | 1221.30 | 493.85 | 2.47 | 1101.36 | 290.44 | 3.79 | +We do not demonstrate the performance of TensorFlow with XLA since we did not find that using XLA has obvious speedup. + +For the benchmark of PyTorch, we compare the performance of PyTorch decoding with beam search (PyTorch), the performance of replacing the decoder of PyTorch by FasterTransformer (Decoder) and performance of FasterTransformer Decoding with beam search (Decoding), and show the speedup Decoder and Decoding compare to the PyTorch. Due to the dynamic property, it is hard to trace/script the PyTorch decoder/decoding model, so we only test on plain PyTorch. + +The results of c++ and TensorFlow were obtained by running the `sample/tensorflow/scripts/profile_decoding_performance.sh`. + +The results of PyTorch were obtained by running the `../sample/pytorch/scripts/profile_decoder_decoding.sh`. + +In the experiments of decoding, we updated the following parameters: + +* head_num = 8 +* size_per_head = 64 +* num_layers = 6 +* vocabulary_size = 30000 for TensorFlow sample codes, 31538 for PyTorch sample codes +* memory_hidden_dim = 512 + +#### Decoding performance on T4 and TensorFlow + +* Performance on FP32 + +| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | +|:---------------------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| +| <1, 1, 32> | 453.10 | 31.84 | 14.23 | 28.00 | 16.18 | +| <1, 1, 64> | 882.08 | 61.51 | 14.34 | 57.33 | 15.38 | +| <1, 1, 128> | 1843.03 | 126.54 | 14.56 | 122.76 | 15.01 | +| <1, 4, 32> | 471.63 | 40.71 | 11.58 | 36.44 | 12.94 | +| <1, 4, 64> | 937.28 | 79.41 | 11.80 | 75.54 | 12.40 | +| <1, 4, 128> | 1926.79 | 166.26 | 11.58 | 160.75 | 11.98 | +| <8, 1, 32> | 482.82 | 43.48 | 11.10 | 39.85 | 12.11 | +| <8, 1, 64> | 921.57 | 87.21 | 10.56 | 83.39 | 11.05 | +| <8, 1, 128> | 1894.78 | 184.38 | 10.27 | 183.43 | 10.32 | +| <8, 4, 32> | 515.76 | 56.47 | 9.13 | 53.63 | 9.61 | +| <8, 4, 64> | 1014.02 | 119.61 | 8.47 | 120.85 | 8.39 | +| <8, 4, 128> | 2020.41 | 277.44 | 7.28 | 300.16 | 6.73 | +| <32, 1, 32> | 534.25 | 56.06 | 9.52 | 53.65 | 9.95 | +| <32, 1, 64> | 1034.65 | 121.27 | 8.53 | 121.52 | 8.51 | +| <32, 1, 128> | 1966.53 | 285.25 | 6.89 | 300.35 | 6.54 | +| <32, 4, 32> | 640.24 | 154.65 | 4.13 | 154.34 | 4.14 | +| <32, 4, 64> | 1354.65 | 350.07 | 3.86 | 367.81 | 3.68 | +| <32, 4, 128> | 3027.38 | 859.86 | 3.52 | 947.46 | 3.19 | +| <64, 1, 32> | 553.85 | 86.66 | 6.39 | 85.61 | 6.46 | +| <64, 1, 64> | 1114.51 | 192.89 | 5.77 | 198.66 | 5.61 | +| <64, 1, 128> | 2318.32 | 472.83 | 4.90 | 512.98 | 4.51 | +| <64, 4, 32> | 825.52 | 285.46 | 2.89 | 289.26 | 2.85 | +| <64, 4, 64> | 1752.80 | 653.98 | 2.68 | 685.59 | 2.55 | +| <64, 4, 128> | 4390.23 | 1631.13 | 2.69 | 1798.83 | 2.44 | +| <128, 1, 32> | 620.29 | 151.94 | 4.08 | 153.28 | 4.04 | +| <128, 1, 64> | 1366.14 | 342.94 | 3.98 | 358.99 | 3.80 | +| <128, 1, 128> | 2987.18 | 868.05 | 3.44 | 945.11 | 
3.16 | +| <128, 4, 32> | 1170.25 | 542.47 | 2.15 | 552.39 | 2.11 | +| <128, 4, 64> | 2760.15 | 1257.03 | 2.19 | 1334.39 | 2.06 | +| <128, 4, 128> | 7774.93 | 3155.91 | 2.46 | 3445.01 | 2.25 | + +* Performance on FP16 + +| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | +|:---------------------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| +| <1, 1, 32> | 396.28 | 34.38 | 11.52 | 26.66 | 14.86 | +| <1, 1, 64> | 768.43 | 63.88 | 12.02 | 56.44 | 13.61 | +| <1, 1, 128> | 1543.99 | 129.90 | 11.88 | 123.63 | 12.48 | +| <1, 4, 32> | 419.53 | 35.09 | 11.95 | 26.25 | 15.98 | +| <1, 4, 64> | 806.38 | 59.80 | 13.48 | 54.02 | 14.92 | +| <1, 4, 128> | 1570.90 | 123.67 | 12.70 | 115.83 | 13.56 | +| <8, 1, 32> | 410.31 | 36.86 | 11.13 | 26.83 | 15.29 | +| <8, 1, 64> | 795.15 | 63.40 | 12.54 | 58.65 | 13.55 | +| <8, 1, 128> | 1639.86 | 132.13 | 12.41 | 127.12 | 12.90 | +| <8, 4, 32> | 439.64 | 38.89 | 11.30 | 35.99 | 12.21 | +| <8, 4, 64> | 891.54 | 82.09 | 10.86 | 79.82 | 11.16 | +| <8, 4, 128> | 1766.03 | 182.58 | 9.67 | 193.54 | 9.12 | +| <32, 1, 32> | 466.24 | 40.58 | 11.48 | 35.76 | 13.03 | +| <32, 1, 64> | 886.57 | 82.15 | 10.79 | 80.28 | 11.04 | +| <32, 1, 128> | 1837.41 | 187.04 | 9.82 | 195.01 | 9.42 | +| <32, 4, 32> | 536.00 | 84.37 | 6.35 | 82.82 | 6.47 | +| <32, 4, 64> | 1116.74 | 189.16 | 5.90 | 198.95 | 5.61 | +| <32, 4, 128> | 2473.57 | 470.40 | 5.25 | 518.77 | 4.76 | +| <64, 1, 32> | 480.88 | 53.39 | 9.00 | 50.89 | 9.44 | +| <64, 1, 64> | 939.87 | 114.97 | 8.17 | 118.25 | 7.94 | +| <64, 1, 128> | 2051.09 | 280.67 | 7.30 | 305.32 | 6.71 | +| <64, 4, 32> | 668.45 | 143.41 | 4.66 | 144.53 | 4.62 | +| <64, 4, 64> | 1476.17 | 332.89 | 4.43 | 351.14 | 4.20 | +| <64, 4, 128> | 3282.27 | 860.21 | 3.81 | 966.68 | 3.39 | +| <128, 1, 32> | 587.50 | 80.61 | 7.28 | 80.79 | 7.27 | +| <128, 1, 64> | 1107.02 | 182.72 | 6.05 | 193.22 | 5.72 | +| <128, 1, 128> | 2635.13 | 467.93 | 5.63 | 518.73 | 5.07 | +| <128, 4, 32> | 996.88 | 265.51 | 3.75 | 271.80 | 3.66 | +| <128, 4, 64> | 2157.85 | 627.24 | 3.44 | 671.76 | 3.21 | +| <128, 4, 128> | 5389.81 | 1646.64 | 3.27 | 1848.24 | 2.91 | + +#### Decoding performance on V100 and TensorFlow + +* Performance of FP32 + +| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | +|:---------------------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| +| <1, 1, 32> | 247.70 | 20.99 | 11.80 | 19.17 | 12.92 | +| <1, 1, 64> | 495.89 | 43.63 | 11.36 | 39.93 | 12.41 | +| <1, 1, 128> | 936.57 | 90.46 | 10.35 | 87.20 | 10.74 | +| <1, 4, 32> | 234.78 | 30.85 | 7.61 | 28.12 | 8.34 | +| <1, 4, 64> | 464.19 | 54.83 | 8.46 | 52.79 | 8.79 | +| <1, 4, 128> | 909.90 | 117.46 | 7.74 | 113.13 | 8.04 | +| <8, 1, 32> | 231.98 | 28.18 | 8.23 | 25.61 | 9.05 | +| <8, 1, 64> | 457.38 | 56.72 | 8.06 | 53.44 | 8.55 | +| <8, 1, 128> | 923.71 | 121.91 | 7.57 | 117.66 | 7.85 | +| <8, 4, 32> | 249.10 | 31.72 | 7.85 | 29.34 | 8.49 | +| <8, 4, 64> | 503.95 | 65.72 | 7.66 | 64.22 | 7.84 | +| <8, 4, 128> | 1020.94 | 147.66 | 6.91 | 149.51 | 6.82 | +| <32, 1, 32> | 245.18 | 31.71 | 7.73 | 29.16 | 8.40 | +| <32, 1, 64> | 521.13 | 65.71 | 7.93 | 64.31 | 8.10 | +| <32, 1, 128> | 968.92 | 149.11 | 6.49 | 149.72 | 6.47 | +| <32, 4, 32> | 290.96 | 67.00 | 4.34 | 66.66 | 4.36 | +| <32, 4, 64> | 662.04 | 147.43 | 4.49 | 155.35 | 4.26 | +| <32, 4, 128> | 1445.38 | 352.77 | 4.09 | 382.38 | 3.77 | +| <64, 1, 32> | 267.80 | 42.61 | 6.28 | 42.18 | 6.34 | +| <64, 1, 64> | 573.75 | 93.68 | 6.12 | 94.01 | 
6.10 | +| <64, 1, 128> | 1204.28 | 217.32 | 5.54 | 228.94 | 5.26 | +| <64, 4, 32> | 369.10 | 113.17 | 3.26 | 114.41 | 3.22 | +| <64, 4, 64> | 811.20 | 251.04 | 3.23 | 265.57 | 3.05 | +| <64, 4, 128> | 1896.34 | 615.58 | 3.08 | 687.73 | 2.75 | +| <128, 1, 32> | 300.77 | 67.01 | 4.48 | 66.01 | 4.55 | +| <128, 1, 64> | 619.74 | 150.08 | 4.12 | 151.31 | 4.09 | +| <128, 1, 128> | 1406.48 | 356.22 | 3.94 | 387.80 | 3.62 | +| <128, 4, 32> | 497.61 | 202.93 | 2.45 | 207.86 | 2.39 | +| <128, 4, 64> | 1194.74 | 463.58 | 2.57 | 496.50 | 2.40 | +| <128, 4, 128> | 3068.19 | 1135.37 | 2.70 | 1259.20 | 2.43 | + +* Performance of FP16 + +| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | +|:---------------------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| +| <1, 1, 32> | 179.29 | 22.79 | 7.86 | 19.90 | 9.00 | +| <1, 1, 64> | 424.71 | 46.31 | 9.17 | 42.07 | 10.09 | +| <1, 1, 128> | 800.49 | 106.68 | 7.50 | 102.70 | 7.79 | +| <1, 4, 32> | 215.21 | 22.99 | 9.36 | 20.42 | 10.53 | +| <1, 4, 64> | 426.36 | 47.33 | 9.00 | 42.67 | 9.99 | +| <1, 4, 128> | 842.32 | 105.93 | 7.95 | 105.07 | 8.01 | +| <8, 1, 32> | 218.83 | 22.45 | 9.74 | 20.29 | 10.78 | +| <8, 1, 64> | 429.64 | 46.16 | 9.30 | 42.66 | 10.07 | +| <8, 1, 128> | 827.80 | 96.64 | 8.56 | 94.76 | 8.73 | +| <8, 4, 32> | 228.45 | 25.30 | 9.02 | 23.36 | 9.77 | +| <8, 4, 64> | 434.26 | 51.36 | 8.45 | 49.95 | 8.69 | +| <8, 4, 128> | 879.69 | 113.05 | 7.78 | 115.80 | 7.59 | +| <32, 1, 32> | 224.73 | 25.34 | 8.86 | 23.12 | 9.72 | +| <32, 1, 64> | 447.28 | 51.98 | 8.60 | 50.01 | 8.94 | +| <32, 1, 128> | 887.31 | 114.14 | 7.77 | 114.74 | 7.73 | +| <32, 4, 32> | 249.40 | 43.55 | 5.72 | 43.17 | 5.77 | +| <32, 4, 64> | 549.04 | 96.69 | 5.67 | 101.74 | 5.39 | +| <32, 4, 128> | 1182.18 | 225.50 | 5.24 | 248.09 | 4.76 | +| <64, 1, 32> | 227.12 | 30.99 | 7.32 | 29.93 | 7.58 | +| <64, 1, 64> | 494.82 | 67.05 | 7.37 | 67.49 | 7.33 | +| <64, 1, 128> | 1000.46 | 154.54 | 6.47 | 160.94 | 6.21 | +| <64, 4, 32> | 304.52 | 68.84 | 4.42 | 69.72 | 4.36 | +| <64, 4, 64> | 666.90 | 154.89 | 4.30 | 164.80 | 4.04 | +| <64, 4, 128> | 1494.30 | 373.57 | 4.00 | 425.44 | 3.51 | +| <128, 1, 32> | 252.69 | 43.08 | 5.86 | 42.74 | 5.91 | +| <128, 1, 64> | 535.56 | 93.53 | 5.72 | 97.05 | 5.51 | +| <128, 1, 128> | 1134.44 | 225.94 | 5.02 | 245.81 | 4.61 | +| <128, 4, 32> | 410.80 | 114.56 | 3.58 | 118.16 | 3.47 | +| <128, 4, 64> | 934.86 | 263.50 | 3.54 | 283.36 | 3.29 | +| <128, 4, 128> | 2236.95 | 653.69 | 3.42 | 746.66 | 2.99 | + + + +#### Decoder and decoding performance on T4 and PyTorch + +* Performance on FP32 + +| | PyTorch (ms) | Decoder (ms) | Decoding (ms) | Decoder Speedup | Decoding Speedup | +|:-----------------------:|:------:|:------:|:------:|:---------:|:---------:| +| <1, 32, 1> | 484.75 | 144.20 | 29.08 | 3.36 | 16.66 | +| <1, 64, 1> | 964.91 | 295.16 | 57.97 | 3.26 | 16.64 | +| <1, 128, 1> | 2482.00 | 716.21 | 118.97 | 3.46 | 20.86 | +| <8, 32, 1> | 640.09 | 198.37 | 41.27 | 3.22 | 15.50 | +| <8, 64, 1> | 1026.29 | 326.66 | 86.32 | 3.14 | 11.88 | +| <8, 128, 1> | 2077.31 | 683.36 | 180.75 | 3.03 | 11.49 | +| <32, 32, 1> | 539.02 | 182.05 | 55.35 | 2.96 | 9.73 | +| <32, 64, 1> | 1060.14 | 368.43 | 121.32 | 2.87 | 8.73 | +| <32, 128, 1> | 2198.63 | 822.78 | 294.63 | 2.67 | 7.46 | +| <64, 32, 1> | 544.38 | 216.06 | 87.28 | 2.51 | 6.23 | +| <64, 64, 1> | 1359.49 | 483.68 | 196.35 | 2.81 | 6.92 | +| <64, 128, 1> | 2409.26 | 1239.34 | 487.91 | 1.94 | 4.93 | +| <128, 32, 1> | 705.29 | 321.99 | 157.30 | 2.19 | 4.48 | +| 
<128, 64, 1> | 1490.15 | 765.70 | 359.43 | 1.94 | 4.14 | +| <128, 128, 1> | 3328.75 | 2032.92 | 900.86 | 1.63 | 3.69 | +| <1, 32, 4> | 519.91 | 170.90 | 37.49 | 3.04 | 13.86 | +| <1, 64, 4> | 1022.17 | 329.85 | 75.47 | 3.09 | 13.54 | +| <1, 128, 4> | 2087.35 | 654.85 | 156.97 | 3.18 | 13.29 | +| <8, 32, 4> | 653.81 | 212.86 | 55.83 | 3.07 | 11.71 | +| <8, 64, 4> | 1056.50 | 363.22 | 121.80 | 2.90 | 8.67 | +| <8, 128, 4> | 2187.94 | 842.20 | 298.90 | 2.59 | 7.31 | +| <32, 32, 4> | 588.74 | 320.21 | 160.45 | 1.83 | 3.66 | +| <32, 64, 4> | 1280.28 | 773.54 | 363.31 | 1.65 | 3.52 | +| <32, 128, 4> | 2869.27 | 2116.43 | 916.30 | 1.35 | 3.13 | +| <64, 32, 4> | 694.86 | 530.53 | 297.42 | 1.30 | 2.33 | +| <64, 64, 4> | 1777.26 | 1331.30 | 687.77 | 1.33 | 2.58 | +| <64, 128, 4> | 4769.54 | 3960.06 | 1740.75 | 1.20 | 2.73 | +| <128, 32, 4> | 990.83 | 975.95 | 576.75 | 1.01 | 1.71 | +| <128, 64, 4> | 2794.30 | 2610.29 | 1310.25 | 1.07 | 2.13 | + +* Performance on FP16 + +| | PyTorch (ms) | Decoder (ms) | Decoding (ms) | Decoder Speedup | Decoding Speedup | +|:-----------------------:|:------:|:------:|:------:|:---------:|:---------:| +| <1, 32, 1> | 636.17 | 187.04 | 28.32 | 3.40 | 22.46 | +| <1, 64, 1> | 1030.81 | 313.46 | 53.82 | 3.28 | 19.15 | +| <1, 128, 1> | 2029.57 | 612.47 | 121.08 | 3.31 | 16.76 | +| <8, 32, 1> | 546.08 | 163.20 | 34.43 | 3.34 | 15.86 | +| <8, 64, 1> | 1112.37 | 315.34 | 73.64 | 3.52 | 15.10 | +| <8, 128, 1> | 2237.78 | 638.65 | 160.04 | 3.50 | 13.98 | +| <32, 32, 1> | 546.68 | 171.72 | 40.91 | 3.18 | 13.36 | +| <32, 64, 1> | 1374.25 | 342.27 | 89.34 | 4.01 | 15.38 | +| <32, 128, 1> | 2219.99 | 712.94 | 206.78 | 3.11 | 10.73 | +| <64, 32, 1> | 557.29 | 196.28 | 60.96 | 2.83 | 9.14 | +| <64, 64, 1> | 1127.56 | 423.53 | 133.64 | 2.66 | 8.43 | +| <64, 128, 1> | 2431.01 | 1024.73 | 324.01 | 2.37 | 7.50 | +| <128, 32, 1> | 604.19 | 260.15 | 100.36 | 2.32 | 6.02 | +| <128, 64, 1> | 1252.95 | 594.85 | 228.57 | 2.10 | 5.48 | +| <128, 128, 1> | 2727.85 | 1526.56 | 567.00 | 1.78 | 4.81 | +| <1, 32, 4> | 568.26 | 165.05 | 33.89 | 3.44 | 16.76 | +| <1, 64, 4> | 1099.60 | 321.63 | 68.78 | 3.41 | 15.98 | +| <1, 128, 4> | 2177.06 | 630.75 | 146.24 | 3.45 | 14.88 | +| <8, 32, 4> | 558.22 | 173.52 | 41.02 | 3.21 | 13.60 | +| <8, 64, 4> | 1105.78 | 343.64 | 88.14 | 3.21 | 12.54 | +| <8, 128, 4> | 2240.45 | 728.21 | 205.81 | 3.07 | 10.88 | +| <32, 32, 4> | 606.68 | 267.60 | 104.44 | 2.26 | 5.80 | +| <32, 64, 4> | 1254.07 | 606.08 | 237.79 | 2.06 | 5.27 | +| <32, 128, 4> | 2741.17 | 1553.44 | 580.81 | 1.76 | 4.71 | +| <64, 32, 4> | 669.47 | 399.96 | 192.19 | 1.67 | 3.48 | +| <64, 64, 4> | 1424.02 | 966.43 | 436.73 | 1.47 | 3.26 | +| <64, 128, 4> | 3638.59 | 2843.25 | 1091.42 | 1.27 | 3.33 | +| <128, 32, 4> | 968.40 | 690.89 | 369.87 | 1.40 | 2.61 | +| <128, 64, 4> | 2087.75 | 1808.63 | 838.92 | 1.15 | 2.48 | +| <128, 128, 4> | 6735.41 | 5440.68 | 2082.84 | 1.23 | 3.23 | + +#### Decoder and decoding performance on V100 and PyTorch + +* Performance on FP32 + +| | PyTorch (ms) | Decoder (ms) | Decoding (ms) | Decoder Speedup | Decoding Speedup | +|:-----------------------:|:------:|:------:|:------:|:---------:|:---------:| +| <1, 32, 1> | 353.90 | 103.39 | 19.72 | 3.42 | 17.94 | +| <1, 64, 1> | 698.88 | 212.27 | 40.61 | 3.29 | 17.20 | +| <1, 128, 1> | 1449.20 | 441.20 | 79.19 | 3.28 | 18.30 | +| <8, 32, 1> | 439.07 | 139.12 | 27.43 | 3.15 | 16.00 | +| <8, 64, 1> | 761.94 | 237.07 | 55.40 | 3.21 | 13.75 | +| <8, 128, 1> | 1731.31 | 535.99 | 117.83 | 3.23 | 14.69 | +| <32, 32, 1> | 373.02 | 124.94 
| 30.53 | 2.98 | 12.21 | +| <32, 64, 1> | 771.97 | 250.84 | 66.12 | 3.07 | 11.67 | +| <32, 128, 1> | 1563.37 | 527.23 | 147.27 | 2.96 | 10.61 | +| <64, 32, 1> | 391.65 | 166.63 | 43.54 | 2.35 | 8.99 | +| <64, 64, 1> | 763.75 | 347.91 | 95.53 | 2.19 | 7.99 | +| <64, 128, 1> | 1626.91 | 734.35 | 225.06 | 2.21 | 7.22 | +| <128, 32, 1> | 399.32 | 205.76 | 65.84 | 1.94 | 6.06 | +| <128, 64, 1> | 845.62 | 428.30 | 147.87 | 1.97 | 5.71 | +| <128, 128, 1> | 1780.45 | 1061.66 | 362.33 | 1.67 | 4.91 | +| <1, 32, 4> | 361.21 | 113.60 | 29.08 | 3.17 | 12.42 | +| <1, 64, 4> | 733.17 | 220.84 | 52.21 | 3.31 | 14.04 | +| <1, 128, 4> | 1489.75 | 467.02 | 125.59 | 3.18 | 11.86 | +| <8, 32, 4> | 382.98 | 124.76 | 30.43 | 3.06 | 12.58 | +| <8, 64, 4> | 768.14 | 248.43 | 64.50 | 3.09 | 11.90 | +| <8, 128, 4> | 1535.88 | 532.08 | 149.88 | 2.88 | 10.24 | +| <32, 32, 4> | 401.86 | 196.38 | 69.34 | 2.04 | 5.79 | +| <32, 64, 4> | 842.37 | 435.26 | 151.97 | 1.93 | 5.54 | +| <32, 128, 4> | 1758.36 | 1076.28 | 367.99 | 1.63 | 4.77 | +| <64, 32, 4> | 433.80 | 283.74 | 114.21 | 1.52 | 3.79 | +| <64, 64, 4> | 955.72 | 698.55 | 256.37 | 1.36 | 3.72 | +| <64, 128, 4> | 2137.94 | 1777.37 | 642.46 | 1.20 | 3.32 | +| <128, 32, 4> | 510.07 | 456.99 | 213.86 | 1.11 | 2.38 | +| <128, 64, 4> | 1140.04 | 1192.74 | 485.95 | .95 | 2.34 | + +* Performance on FP16 + +| | PyTorch (ms) | Decoder (ms) | Decoding (ms) | Decoder Speedup | Decoding Speedup | +|:-----------------------:|:------:|:------:|:------:|:---------:|:---------:| +| <1, 32, 1> | 364.93 | 104.67 | 23.59 | 3.48 | 15.46 | +| <1, 64, 1> | 730.63 | 219.29 | 48.02 | 3.33 | 15.21 | +| <1, 128, 1> | 1448.80 | 435.08 | 90.06 | 3.32 | 16.08 | +| <8, 32, 1> | 396.70 | 113.47 | 28.43 | 3.49 | 13.95 | +| <8, 64, 1> | 766.96 | 213.44 | 58.41 | 3.59 | 13.13 | +| <8, 128, 1> | 1508.97 | 430.11 | 123.92 | 3.50 | 12.17 | +| <32, 32, 1> | 380.00 | 113.32 | 30.81 | 3.35 | 12.33 | +| <32, 64, 1> | 755.43 | 230.70 | 56.28 | 3.27 | 13.42 | +| <32, 128, 1> | 1592.17 | 481.88 | 140.00 | 3.30 | 11.37 | +| <64, 32, 1> | 385.02 | 150.23 | 36.38 | 2.56 | 10.58 | +| <64, 64, 1> | 1006.94 | 352.55 | 77.56 | 2.85 | 12.98 | +| <64, 128, 1> | 1647.93 | 669.11 | 174.38 | 2.46 | 9.45 | +| <128, 32, 1> | 393.47 | 172.10 | 49.39 | 2.28 | 7.96 | +| <128, 64, 1> | 846.32 | 371.34 | 109.92 | 2.27 | 7.69 | +| <128, 128, 1> | 1812.89 | 892.29 | 260.72 | 2.03 | 6.95 | +| <1, 32, 4> | 403.72 | 111.89 | 28.33 | 3.60 | 14.25 | +| <1, 64, 4> | 758.80 | 215.31 | 58.97 | 3.52 | 12.86 | +| <1, 128, 4> | 1565.94 | 431.89 | 113.51 | 3.62 | 13.79 | +| <8, 32, 4> | 388.91 | 117.17 | 31.56 | 3.31 | 12.32 | +| <8, 64, 4> | 768.24 | 232.11 | 61.85 | 3.30 | 12.42 | +| <8, 128, 4> | 1618.71 | 497.68 | 136.25 | 3.25 | 11.88 | +| <32, 32, 4> | 415.84 | 183.10 | 51.08 | 2.27 | 8.14 | +| <32, 64, 4> | 874.10 | 390.93 | 112.19 | 2.23 | 7.79 | +| <32, 128, 4> | 1806.96 | 876.53 | 255.26 | 2.06 | 7.07 | +| <64, 32, 4> | 453.94 | 234.66 | 84.20 | 1.93 | 5.39 | +| <64, 64, 4> | 948.13 | 517.52 | 185.68 | 1.83 | 5.10 | +| <64, 128, 4> | 2071.99 | 1333.14 | 446.57 | 1.55 | 4.63 | +| <128, 32, 4> | 486.71 | 349.62 | 146.36 | 1.39 | 3.32 | +| <128, 64, 4> | 1084.80 | 808.79 | 330.19 | 1.34 | 3.28 | +| <128, 128, 4> | 2638.70 | 2248.28 | 800.58 | 1.17 | 3.29 | + +#### TensorFlow performance on translation + +We test with batch_size 128, beam width 4 on V100. 
+
+| Type | tokens per second | BLEU |
+|:----:|:------------------:|:----:|
+| TensorFlow, beam search, FP32 | 2137 | BLEU 26.29 |
+| Decoder, beam search, FP32 | 6473 | BLEU 26.29 |
+| Decoding, beam search, FP32 | 8513 | BLEU 26.31 |
+| TensorFlow, sampling, FP32 | 4178 | BLEU 25.79 |
+| Decoder, sampling, FP32 | 10781 | BLEU 25.79 |
+| Decoding, sampling, FP32 | 16524 | BLEU 25.79 |
+| TensorFlow, beam search, FP16 | 2949 | BLEU 26.31 |
+| Decoder, beam search, FP16 | 8682 | BLEU 26.30 |
+| Decoding, beam search, FP16 | 12746 | BLEU 26.33 |
+| TensorFlow, sampling, FP16 | 6968 | BLEU 25.83 |
+| Decoder, sampling, FP16 | 13773 | BLEU 25.80 |
+| Decoding, sampling, FP16 | 26718 | BLEU 25.82 |
+
+#### PyTorch performance on translation
+
+We test with batch_size 128, beam width 4, and the beam search algorithm on V100.
+
+| Type | tokens per second | BLEU |
+|:----:|:------------------:|:----:|
+| PyTorch, FP32 | 2294 | BLEU 28.0 |
+| Decoder, FP32 | 2911 | BLEU 28.0 |
+| Decoding, FP32 | 3674 | BLEU 28.0 |
+| PyTorch, FP16 | 2245 | BLEU 28.0 |
+| Decoder, FP16 | 3711 | BLEU 28.0 |
+| Decoding, FP16 | 4932 | BLEU 28.0 |

## Release notes

### Changelog

+June 2020
+- Add the [effective transformer](https://github.com/bytedance/effective_transformer) idea into the encoder.
+- Optimize the beam search kernels.
+- Add PyTorch op support
+
+May 2020
+- Fix the bug that the seq_len of the encoder must be larger than 3.
+- Add the position_encoding of decoding as an input of FasterTransformer decoding. This makes it convenient to use different types of position encoding; FasterTransformer does not compute the position encoding values, but only looks them up in the table.
+- Modify the method of loading the model in `translate_sample.py`.
+
April 2020
-- Fix the bug of encoder tensorrt plugin.
+- Rename `decoding_opennmt.h` to `decoding_beamsearch.h`
+- Add DiverseSiblingsSearch for decoding.
+- Add sampling into Decoding
+  - The implementation is in `decoding_sampling.h`.
+  - Add top_k sampling and top_p sampling for decoding.
+- Refactor the TensorFlow custom op codes.
+  - Merge `bert_transformer_op.h`, `bert_transformer_op.cu.cc` into `bert_transformer_op.cc`
+  - Merge `decoder.h`, `decoder.cu.cc` into `decoder.cc`
+  - Merge `decoding_beamsearch.h`, `decoding_beamsearch.cu.cc` into `decoding_beamsearch.cc`
+- Fix the bugs of the finalize function in decoding.py.
+- Fix the bug of the TensorFlow DiverseSiblingsSearch.
+- Add the BLEU scorer `bleu_score.py` into `utils`. Note that the BLEU scorer requires Python 3.
+- Fuse the QKV GEMM of the encoder and the masked_multi_head_attention of the decoder.
+- Add dynamic batch size and dynamic sequence length features into all ops.

March 2020
- Add feature in FasterTransformer 2.0
@@ -884,19 +2788,26 @@ March 2020
- Add a normalization for inputs of decoder

Febuary 2020
-- Release the FasterTransformer 2.0
+- **Release the FasterTransformer 2.0**
- Provide a highly optimized OpenNMT-tf based decoder and decoding, including C++ API and TensorFlow op.
- Refine the sample codes of encoder.
- Add dynamic batch size feature into encoder op.

July 2019
-- Release the FasterTransformer 1.0
+- **Release the FasterTransformer 1.0**
- Provide a highly optimized bert equivalent transformer layer, including C++ API, TensorFlow op and TensorRT plugin.

### Known issues

+- Undefined symbol errors when importing the extension
+  - Please `import torch` first. If this has been done, it is due to the incompatible C++ ABI.
You may need to check the PyTorch used during compilation and execution are the same, or you need to check how your PyTorch is compiled, or the version of your GCC, etc. - batch_size should be smaller or equal to 1024 in Decoder. - batch_size x beam_width should be smaller or equal to 1024 in Decoding. - Results of TensorFlow and OP would be different in decoding. This problem is caused by the accumulated log probability, and we do not avoid this problem. - Cmake 15 or Cmake 16 fail to build this project. Cmake 14 is no problem. -- Max sequence length of encoder and decoder should be the same. +- If encounter some problem in the custom environment, try to use the gcc/g++ 4.8 to build the project of TensorFlow op, especially for TensorFlow 1.14. + +### TODO + +- Refactor the codes +- Split the initialization of top k and top p sampling \ No newline at end of file diff --git a/fastertransformer/CMakeLists.txt b/fastertransformer/CMakeLists.txt index 2d29aa541..9e3d4e75c 100644 --- a/fastertransformer/CMakeLists.txt +++ b/fastertransformer/CMakeLists.txt @@ -16,3 +16,7 @@ add_subdirectory(cuda) if(BUILD_TF) add_subdirectory(tf_op) endif() + +if(BUILD_THE OR BUILD_THS OR BUILD_THSOP) + add_subdirectory(th_op) +endif() diff --git a/fastertransformer/allocator.h b/fastertransformer/allocator.h index 82fe359c6..43d98c43d 100644 --- a/fastertransformer/allocator.h +++ b/fastertransformer/allocator.h @@ -22,6 +22,7 @@ #include "fastertransformer/common.h" #include "fastertransformer/utils.h" #include +#include #ifdef GOOGLE_CUDA #include "tensorflow/core/framework/op.h" @@ -35,6 +36,11 @@ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #endif +#ifdef TORCH_CUDA +#include +#include "torch/extension.h" +#endif + namespace fastertransformer { @@ -76,8 +82,6 @@ class Allocator : public IAllocator } }; -//TODO: allocator of TensorFlow -// You can add context to constructor #ifdef GOOGLE_CUDA using namespace tensorflow; template <> @@ -85,9 +89,10 @@ class Allocator : public IAllocator { OpKernelContext *context_; std::vector *allocated_tensor_vector; + cudaStream_t stream_; public: - Allocator(OpKernelContext *context) : context_(context) + Allocator(OpKernelContext *context, cudaStream_t stream) : context_(context), stream_(stream) { allocated_tensor_vector = new std::vector; } @@ -104,7 +109,7 @@ class Allocator : public IAllocator auto flat = buf.flat(); void *ptr = (void *)flat.data(); - cudaMemset(ptr, 0, buf_size); + cudaMemsetAsync(ptr, 0, buf_size, stream_); return ptr; } @@ -123,4 +128,31 @@ class Allocator : public IAllocator } }; #endif + +#ifdef TORCH_CUDA +template <> +class Allocator : public IAllocator +{ + std::shared_ptr> allocated_tensor_vector; + +public: + Allocator() : allocated_tensor_vector(std::make_shared>()) {} + + void *malloc(size_t size) const + { + int64_t buf_size = static_cast(size); + torch::Tensor buf = torch::empty({buf_size}, torch::dtype(torch::kUInt8).device(torch::kCUDA)); + allocated_tensor_vector->push_back(buf); + return (*allocated_tensor_vector)[allocated_tensor_vector->size()-1].data_ptr(); + } + + void free(void *ptr) const + { +#ifndef NDEBUG + printf("call from allocator free\n"); +#endif + return; + } +}; +#endif } //namespace fastertransformer diff --git a/fastertransformer/arguments.h b/fastertransformer/arguments.h new file mode 100644 index 000000000..cbd483319 --- /dev/null +++ b/fastertransformer/arguments.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Decoder transformer + **/ + +#pragma once + +#include "fastertransformer/common.h" +#include "fastertransformer/common_structure.h" +#include +#include + +namespace fastertransformer +{ + +template +class DecodingInitParam +{ +public: + /* weights for masked_multi_head_attention */ + const T *embedding_table; + const T *embedding_kernel; + const float *embedding_bias; + + const T *memory_tensor; + const int *memory_sequence_length; + + const T *position_encoding_table; + + LayerNormWeight layernorm; + + int *output_ids; + int *parent_ids; + int *sequence_length; + cublasHandle_t cublas_handle; + cudaStream_t stream; +}; + +struct TransformerArguments +{ + int batch_size_; + int seq_len_; + int head_num_; + int size_per_head_; + int hidden_units_; +}; + +struct DecodingArguments : public TransformerArguments +{ + int decoder_layers_; + int vocab_size_; + int start_id_; + int end_id_; +}; + +struct DecodingSamplingArguments : public DecodingArguments +{ + int candidate_num_; + float probability_threshold_; + size_t temp_storage_size_; +}; + +struct DecodingBeamsearchArguments : public DecodingArguments +{ + int beam_width_; + int temp_storage_size_; + float beam_search_diversity_rate_; +}; + +} // end of namespace transformer \ No newline at end of file diff --git a/fastertransformer/beamsearch_opennmt.h b/fastertransformer/beamsearch_opennmt.h deleted file mode 100644 index fff238c14..000000000 --- a/fastertransformer/beamsearch_opennmt.h +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/** - * BeamSearch OpenNMT - **/ - -#pragma once - -#include -#include "fastertransformer/allocator.h" -#include "fastertransformer/cuda/cuda_kernels.h" -#include "fastertransformer/cuda/open_attention.h" -#include "fastertransformer/cuda/decoding_kernel_check.h" - -namespace fastertransformer -{ - -template -void BeamSearch_OpenNMT( - float *log_probs, float *cum_log_probs, bool *finished, - T **key_cache, T **value_cache, - int *parent_ids, - int *sequence_length, - int *word_ids, - int *ids, - int *output_ids, - const int batch_size, const int beam_width, - const int vocab_size, const int hidden_dim, const int step, - const int cache_size, const int decoder_layers, cudaStream_t stream, - const int end_id, - int *finished_count) -{ -#ifdef NDEBUG - /* adding cum_log_probs to log_probs */ - broadcast_kernelLauncher(log_probs, cum_log_probs, batch_size, beam_width, vocab_size, stream); -#else - broadcast_kernelLauncher(log_probs, cum_log_probs, batch_size, beam_width, vocab_size, stream); - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); - - /* - User can check the broadcast_kernel by broadcast_kernel_check. - broadcast_kernel_check will compare the results of GPU and CPU. - Note that broadcast_kernel_check contains broadcast_kernelLauncher and uses do not need to call it again. - */ - // broadcast_kernel_check(log_probs, cum_log_probs, batch_size, beam_width, vocab_size, stream); -#endif - -#ifdef NDEBUG - /*Use two round kernels to pick the topK values for each batch */ - topK(log_probs, ids, batch_size, beam_width, vocab_size, stream); -#else - topK(log_probs, ids, batch_size, beam_width, vocab_size, stream); - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); - - /* - User can check the topK by topK_check. - topK_check will compare the results of GPU and CPU. - Note that topK_check contains topK and uses do not need to call it again. - */ - // topK_kernel_check(log_probs, ids, batch_size, beam_width, vocab_size, stream); -#endif - -#ifdef NDEBUG - update(log_probs, cum_log_probs, ids, finished, - parent_ids, sequence_length, word_ids, output_ids, - batch_size, beam_width, vocab_size, stream, - end_id, finished_count); -#else - update(log_probs, cum_log_probs, ids, finished, - parent_ids, sequence_length, word_ids, output_ids, - batch_size, beam_width, vocab_size, stream, - end_id, finished_count); - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); - - /* - User can check the update by update_kernel_check. - update_kernel_check will compare the results of GPU and CPU. - Note that update_kernel_check contains update and uses do not need to call it again. - */ - // update_kernel_check(log_probs, cum_log_probs, ids, finished, parent_ids, sequence_length, word_ids, output_ids, - // batch_size, beam_width, vocab_size, stream, end_id, finished_count); -#endif - -#ifdef NDEBUG - update_KV_cache(key_cache, value_cache, parent_ids, batch_size, - beam_width, hidden_dim, step, cache_size, - decoder_layers, stream); -#else - update_KV_cache(key_cache, value_cache, parent_ids, batch_size, - beam_width, hidden_dim, step, cache_size, - decoder_layers, stream); - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); - - /* - User can check the update_KV_cache by update_KV_cache_kernel_check. - update_KV_cache_kernel_check will compare the results of GPU and CPU. - Note that update_KV_cache_kernel_check contains update_KV_cache and uses do not need to call it again. 
- */ - // update_KV_cache_kernel_check(key_cache, value_cache, parent_ids, batch_size, beam_width, hidden_dim, step, cache_size, decoder_layers, stream); -#endif -} - -} // namespace fastertransformer diff --git a/fastertransformer/bert_encoder_transformer.h b/fastertransformer/bert_encoder_transformer.h index 2be66126f..fc56bb014 100644 --- a/fastertransformer/bert_encoder_transformer.h +++ b/fastertransformer/bert_encoder_transformer.h @@ -33,19 +33,23 @@ template class EncoderInitParam { public: - const T *from_tensor; - const T *to_tensor; + const T *from_tensor = nullptr; + const T *to_tensor = nullptr; AttentionWeight self_attention; - const T *attr_mask; + const T *attr_mask = nullptr; LayerNormWeight self_layernorm; FFNWeight ffn; LayerNormWeight ffn_layernorm; T *transformer_out; - cublasHandle_t cublas_handle; - cudaStream_t stream; + cublasHandle_t cublas_handle = nullptr; + cudaStream_t stream = 0; + + const int* sequence_id_offset = nullptr; + int valid_word_num = -1; + }; template class MultiHeadAttention_> @@ -85,6 +89,10 @@ class BertEncoderTransformer DataType_ *attr_out_buf_; DataType_ *attr_matmul_buf_; DataType_ *inter_matmul_buf_; + + DataType_ *attr_out_tmp_buf_; + DataType_ *out_tmp_buf_; + DataType_ *from_tensor_tmp_buf_; int batch_size_; int from_seq_len_; @@ -109,14 +117,18 @@ class BertEncoderTransformer try { - buf_ = reinterpret_cast(allocator_.malloc(sizeof(DataType_) * buf_size * 6)); + buf_ = reinterpret_cast(allocator_.malloc(sizeof(DataType_) * buf_size * (6 + 3))); if (buf_ == nullptr) - throw std::runtime_error(std::string("Tensorflow Allocator failed to allocate internal buffer.")); + throw std::runtime_error(std::string("Allocator failed to allocate internal buffer.")); attr_out_buf_ = buf_; attr_matmul_buf_ = attr_out_buf_ + buf_size; inter_matmul_buf_ = attr_matmul_buf_ + buf_size; - + + attr_out_tmp_buf_ = inter_matmul_buf_ + 4 * buf_size; + out_tmp_buf_ = attr_out_tmp_buf_ + buf_size; + from_tensor_tmp_buf_ = out_tmp_buf_ + buf_size; + attention_ = new typename Traits_::MultiHeadAttention(allocator_, batch_size_, from_seq_len_, to_seq_len_, head_num_, size_per_head_); FILE *fd = fopen("gemm_config.in", "r"); int err = 0; @@ -124,7 +136,7 @@ class BertEncoderTransformer printf("gemm_config.in is not found\n"); else { - err = fscanf(fd, "%d%d%d%*d%*d", &cublasAlgo_[0], &cublasAlgo_[1], &cublasAlgo_[2]); + err = fscanf(fd, "%d %*f %d %*f %d %*f %*d %*f %*d %*f %*d %*f", &cublasAlgo_[0], &cublasAlgo_[1], &cublasAlgo_[2]); fclose(fd); } if (err != 3) @@ -161,7 +173,6 @@ class BertEncoderTransformer #endif param_ = param; cuda::MultiHeadInitParam multi_head_init_param; - multi_head_init_param.from_tensor = param.from_tensor; multi_head_init_param.to_tensor = param.to_tensor; multi_head_init_param.self_attention = param.self_attention; @@ -169,6 +180,8 @@ class BertEncoderTransformer multi_head_init_param.stream = param.stream; multi_head_init_param.cublas_handle = param.cublas_handle; multi_head_init_param.attr_out = attr_out_buf_; + multi_head_init_param.valid_word_num = param.valid_word_num; + multi_head_init_param.sequence_id_offset = param.sequence_id_offset; attention_->initialize(multi_head_init_param); } @@ -192,10 +205,10 @@ class BertEncoderTransformer DataType_ alpha = (DataType_)1.0f; DataType_ beta = (DataType_)0.0f; - int m = batch_size_ * from_seq_len_; + const int m = param_.sequence_id_offset == nullptr ? 
batch_size_ * from_seq_len_ : param_.valid_word_num; int k = head_num_ * size_per_head_; int n = k; - + check_cuda_error(cublasGemmEx(param_.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, @@ -253,6 +266,7 @@ class BertEncoderTransformer param_.ffn_layernorm.gamma, param_.ffn_layernorm.beta, m, n, param_.stream); + #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); diff --git a/fastertransformer/common.h b/fastertransformer/common.h index e8d7ae40b..dddd5db9f 100644 --- a/fastertransformer/common.h +++ b/fastertransformer/common.h @@ -24,7 +24,7 @@ namespace fastertransformer{ enum class OperationType{FP32, FP16}; - enum class AllocatorType{CUDA, TF}; + enum class AllocatorType{CUDA, TF, TH}; #define PRINT_FUNC_NAME_() do{\ std::cout << "[FT][CALL] " << __FUNCTION__ << " " << std::endl; \ @@ -116,4 +116,4 @@ void check_max_val(const T* result, const int size){ printf("[INFO][CUDA] addr %p max val: %f \n", result, max_val); } -}//namespace fastertransformer +}//namespace fastertransformer \ No newline at end of file diff --git a/fastertransformer/common_structure.h b/fastertransformer/common_structure.h index c67a47089..b759296b9 100644 --- a/fastertransformer/common_structure.h +++ b/fastertransformer/common_structure.h @@ -18,14 +18,14 @@ template struct DenseWeight{ - const T* kernel; - const T* bias; + const T* kernel = nullptr; + const T* bias = nullptr; }; template struct LayerNormWeight{ - const T* gamma; - const T* beta; + const T* gamma = nullptr; + const T* beta = nullptr; }; template diff --git a/fastertransformer/cuda/CMakeLists.txt b/fastertransformer/cuda/CMakeLists.txt index 0d7df3e94..f187de381 100644 --- a/fastertransformer/cuda/CMakeLists.txt +++ b/fastertransformer/cuda/CMakeLists.txt @@ -13,13 +13,46 @@ # limitations under the License. 
cmake_minimum_required(VERSION 3.8) -set(cuda_kernel_files - cuda_kernels.cu +set(encoder_kernel_files open_attention.cu +) + +set(decoder_kernel_files open_decoder.cu - decoding_kernel_check.cpp ) -add_library(fastertransformer STATIC ${cuda_kernel_files}) -target_link_libraries(fastertransformer PUBLIC -lcublas -lcudart) +set(online_softmax_beamsearch_kernel_files + online_softmax_beamsearch_kernels.cu +) + +set(topk_kernel_files + topk_kernels.cu +) + +set(decoding_kernel_files + decoding_kernels.cu +) + +add_library(cuda_kernels STATIC cuda_kernels.cu) +target_link_libraries(cuda_kernels PUBLIC -lcublas -lcudart -lcurand) +set_property(TARGET cuda_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) + +add_library(encoder STATIC ${encoder_kernel_files}) +target_link_libraries(encoder PUBLIC -lcublas -lcudart -lcurand cuda_kernels) +set_property(TARGET encoder PROPERTY POSITION_INDEPENDENT_CODE ON) + +add_library(decoder STATIC ${decoder_kernel_files}) +target_link_libraries(decoder PUBLIC -lcublas -lcudart -lcurand) +set_property(TARGET decoder PROPERTY POSITION_INDEPENDENT_CODE ON) + +add_library(online_softmax_beamsearch STATIC ${online_softmax_beamsearch_kernel_files}) +target_link_libraries(online_softmax_beamsearch PUBLIC -lcublas -lcudart -lcurand) +set_property(TARGET online_softmax_beamsearch PROPERTY POSITION_INDEPENDENT_CODE ON) + +add_library(topk STATIC ${topk_kernel_files}) +target_link_libraries(topk PUBLIC -lcublas -lcudart -lcurand) +set_property(TARGET topk PROPERTY POSITION_INDEPENDENT_CODE ON) +add_library(decoding STATIC ${decoding_kernel_files}) +target_link_libraries(decoding PUBLIC -lcublas -lcudart -lcurand topk online_softmax_beamsearch cuda_kernels) +set_property(TARGET decoding PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/fastertransformer/cuda/cub/agent/agent_histogram.cuh b/fastertransformer/cuda/cub/agent/agent_histogram.cuh new file mode 100644 index 000000000..37b1ec973 --- /dev/null +++ b/fastertransformer/cuda/cub/agent/agent_histogram.cuh @@ -0,0 +1,787 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::AgentHistogram implements a stateful abstraction of CUDA thread blocks for participating in device-wide histogram . + */ + +#pragma once + +#include + +#include "../util_type.cuh" +#include "../block/block_load.cuh" +#include "../grid/grid_queue.cuh" +#include "../iterator/cache_modified_input_iterator.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Tuning policy + ******************************************************************************/ + +/** + * + */ +enum BlockHistogramMemoryPreference +{ + GMEM, + SMEM, + BLEND +}; + + +/** + * Parameterizable tuning policy type for AgentHistogram + */ +template < + int _BLOCK_THREADS, ///< Threads per thread block + int _PIXELS_PER_THREAD, ///< Pixels per thread (per tile of input) + BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use + CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements + bool _RLE_COMPRESS, ///< Whether to perform localized RLE to compress samples before histogramming + BlockHistogramMemoryPreference _MEM_PREFERENCE, ///< Whether to prefer privatized shared-memory bins (versus privatized global-memory bins) + bool _WORK_STEALING> ///< Whether to dequeue tiles from a global work queue +struct AgentHistogramPolicy +{ + enum + { + BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block + PIXELS_PER_THREAD = _PIXELS_PER_THREAD, ///< Pixels per thread (per tile of input) + IS_RLE_COMPRESS = _RLE_COMPRESS, ///< Whether to perform localized RLE to compress samples before histogramming + MEM_PREFERENCE = _MEM_PREFERENCE, ///< Whether to prefer privatized shared-memory bins (versus privatized global-memory bins) + IS_WORK_STEALING = _WORK_STEALING, ///< Whether to dequeue tiles from a global work queue + }; + + static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use + static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements +}; + + +/****************************************************************************** + * Thread block abstractions + ******************************************************************************/ + +/** + * \brief AgentHistogram implements a stateful abstraction of CUDA thread blocks for participating in device-wide histogram . + */ +template < + typename AgentHistogramPolicyT, ///< Parameterized AgentHistogramPolicy tuning policy type + int PRIVATIZED_SMEM_BINS, ///< Number of privatized shared-memory histogram bins of any channel. Zero indicates privatized counters to be maintained in device-accessible memory. + int NUM_CHANNELS, ///< Number of channels interleaved in the input data. Supports up to four channels. 
+ int NUM_ACTIVE_CHANNELS, ///< Number of channels actively being histogrammed + typename SampleIteratorT, ///< Random-access input iterator type for reading samples + typename CounterT, ///< Integer type for counting sample occurrences per histogram bin + typename PrivatizedDecodeOpT, ///< The transform operator type for determining privatized counter indices from samples, one for each channel + typename OutputDecodeOpT, ///< The transform operator type for determining output bin-ids from privatized counter indices, one for each channel + typename OffsetT, ///< Signed integer type for global offsets + int PTX_ARCH = CUB_PTX_ARCH> ///< PTX compute capability +struct AgentHistogram +{ + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// The sample type of the input iterator + typedef typename std::iterator_traits::value_type SampleT; + + /// The pixel type of SampleT + typedef typename CubVector::Type PixelT; + + /// The quad type of SampleT + typedef typename CubVector::Type QuadT; + + /// Constants + enum + { + BLOCK_THREADS = AgentHistogramPolicyT::BLOCK_THREADS, + + PIXELS_PER_THREAD = AgentHistogramPolicyT::PIXELS_PER_THREAD, + SAMPLES_PER_THREAD = PIXELS_PER_THREAD * NUM_CHANNELS, + QUADS_PER_THREAD = SAMPLES_PER_THREAD / 4, + + TILE_PIXELS = PIXELS_PER_THREAD * BLOCK_THREADS, + TILE_SAMPLES = SAMPLES_PER_THREAD * BLOCK_THREADS, + + IS_RLE_COMPRESS = AgentHistogramPolicyT::IS_RLE_COMPRESS, + + MEM_PREFERENCE = (PRIVATIZED_SMEM_BINS > 0) ? + AgentHistogramPolicyT::MEM_PREFERENCE : + GMEM, + + IS_WORK_STEALING = AgentHistogramPolicyT::IS_WORK_STEALING, + }; + + /// Cache load modifier for reading input elements + static const CacheLoadModifier LOAD_MODIFIER = AgentHistogramPolicyT::LOAD_MODIFIER; + + + /// Input iterator wrapper type (for applying cache modifier) + typedef typename If::VALUE, + CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedInputIterator + SampleIteratorT>::Type // Directly use the supplied input iterator type + WrappedSampleIteratorT; + + /// Pixel input iterator type (for applying cache modifier) + typedef CacheModifiedInputIterator + WrappedPixelIteratorT; + + /// Qaud input iterator type (for applying cache modifier) + typedef CacheModifiedInputIterator + WrappedQuadIteratorT; + + /// Parameterized BlockLoad type for samples + typedef BlockLoad< + SampleT, + BLOCK_THREADS, + SAMPLES_PER_THREAD, + AgentHistogramPolicyT::LOAD_ALGORITHM> + BlockLoadSampleT; + + /// Parameterized BlockLoad type for pixels + typedef BlockLoad< + PixelT, + BLOCK_THREADS, + PIXELS_PER_THREAD, + AgentHistogramPolicyT::LOAD_ALGORITHM> + BlockLoadPixelT; + + /// Parameterized BlockLoad type for quads + typedef BlockLoad< + QuadT, + BLOCK_THREADS, + QUADS_PER_THREAD, + AgentHistogramPolicyT::LOAD_ALGORITHM> + BlockLoadQuadT; + + /// Shared memory type required by this thread block + struct _TempStorage + { + CounterT histograms[NUM_ACTIVE_CHANNELS][PRIVATIZED_SMEM_BINS + 1]; // Smem needed for block-privatized smem histogram (with 1 word of padding) + + int tile_idx; + + // Aliasable storage layout + union Aliasable + { + typename BlockLoadSampleT::TempStorage sample_load; // Smem needed for loading a tile of samples + typename BlockLoadPixelT::TempStorage pixel_load; // Smem needed for loading a tile of pixels + typename BlockLoadQuadT::TempStorage quad_load; // Smem needed for loading a tile of quads + + } aliasable; + }; + + + /// 
Temporary storage type (unionable) + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + /// Reference to temp_storage + _TempStorage &temp_storage; + + /// Sample input iterator (with cache modifier applied, if possible) + WrappedSampleIteratorT d_wrapped_samples; + + /// Native pointer for input samples (possibly NULL if unavailable) + SampleT* d_native_samples; + + /// The number of output bins for each channel + int (&num_output_bins)[NUM_ACTIVE_CHANNELS]; + + /// The number of privatized bins for each channel + int (&num_privatized_bins)[NUM_ACTIVE_CHANNELS]; + + /// Reference to gmem privatized histograms for each channel + CounterT* d_privatized_histograms[NUM_ACTIVE_CHANNELS]; + + /// Reference to final output histograms (gmem) + CounterT* (&d_output_histograms)[NUM_ACTIVE_CHANNELS]; + + /// The transform operator for determining output bin-ids from privatized counter indices, one for each channel + OutputDecodeOpT (&output_decode_op)[NUM_ACTIVE_CHANNELS]; + + /// The transform operator for determining privatized counter indices from samples, one for each channel + PrivatizedDecodeOpT (&privatized_decode_op)[NUM_ACTIVE_CHANNELS]; + + /// Whether to prefer privatized smem counters vs privatized global counters + bool prefer_smem; + + + //--------------------------------------------------------------------- + // Initialize privatized bin counters + //--------------------------------------------------------------------- + + // Initialize privatized bin counters + __device__ __forceinline__ void InitBinCounters(CounterT* privatized_histograms[NUM_ACTIVE_CHANNELS]) + { + // Initialize histogram bin counts to zeros + #pragma unroll + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + { + for (int privatized_bin = threadIdx.x; privatized_bin < num_privatized_bins[CHANNEL]; privatized_bin += BLOCK_THREADS) + { + privatized_histograms[CHANNEL][privatized_bin] = 0; + } + } + + // Barrier to make sure all threads are done updating counters + CTA_SYNC(); + } + + + // Initialize privatized bin counters. Specialized for privatized shared-memory counters + __device__ __forceinline__ void InitSmemBinCounters() + { + CounterT* privatized_histograms[NUM_ACTIVE_CHANNELS]; + + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + privatized_histograms[CHANNEL] = temp_storage.histograms[CHANNEL]; + + InitBinCounters(privatized_histograms); + } + + + // Initialize privatized bin counters. 
Specialized for privatized global-memory counters + __device__ __forceinline__ void InitGmemBinCounters() + { + InitBinCounters(d_privatized_histograms); + } + + + //--------------------------------------------------------------------- + // Update final output histograms + //--------------------------------------------------------------------- + + // Update final output histograms from privatized histograms + __device__ __forceinline__ void StoreOutput(CounterT* privatized_histograms[NUM_ACTIVE_CHANNELS]) + { + // Barrier to make sure all threads are done updating counters + CTA_SYNC(); + + // Apply privatized bin counts to output bin counts + #pragma unroll + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + { + int channel_bins = num_privatized_bins[CHANNEL]; + for (int privatized_bin = threadIdx.x; + privatized_bin < channel_bins; + privatized_bin += BLOCK_THREADS) + { + int output_bin = -1; + CounterT count = privatized_histograms[CHANNEL][privatized_bin]; + bool is_valid = count > 0; + + output_decode_op[CHANNEL].template BinSelect((SampleT) privatized_bin, output_bin, is_valid); + + if (output_bin >= 0) + { + atomicAdd(&d_output_histograms[CHANNEL][output_bin], count); + } + + } + } + } + + + // Update final output histograms from privatized histograms. Specialized for privatized shared-memory counters + __device__ __forceinline__ void StoreSmemOutput() + { + CounterT* privatized_histograms[NUM_ACTIVE_CHANNELS]; + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + privatized_histograms[CHANNEL] = temp_storage.histograms[CHANNEL]; + + StoreOutput(privatized_histograms); + } + + + // Update final output histograms from privatized histograms. Specialized for privatized global-memory counters + __device__ __forceinline__ void StoreGmemOutput() + { + StoreOutput(d_privatized_histograms); + } + + + //--------------------------------------------------------------------- + // Tile accumulation + //--------------------------------------------------------------------- + + // Accumulate pixels. Specialized for RLE compression. + __device__ __forceinline__ void AccumulatePixels( + SampleT samples[PIXELS_PER_THREAD][NUM_CHANNELS], + bool is_valid[PIXELS_PER_THREAD], + CounterT* privatized_histograms[NUM_ACTIVE_CHANNELS], + Int2Type is_rle_compress) + { + #pragma unroll + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + { + // Bin pixels + int bins[PIXELS_PER_THREAD]; + + #pragma unroll + for (int PIXEL = 0; PIXEL < PIXELS_PER_THREAD; ++PIXEL) + { + bins[PIXEL] = -1; + privatized_decode_op[CHANNEL].template BinSelect(samples[PIXEL][CHANNEL], bins[PIXEL], is_valid[PIXEL]); + } + + CounterT accumulator = 1; + + #pragma unroll + for (int PIXEL = 0; PIXEL < PIXELS_PER_THREAD - 1; ++PIXEL) + { + if (bins[PIXEL] != bins[PIXEL + 1]) + { + if (bins[PIXEL] >= 0) + atomicAdd(privatized_histograms[CHANNEL] + bins[PIXEL], accumulator); + + accumulator = 0; + } + accumulator++; + } + + // Last pixel + if (bins[PIXELS_PER_THREAD - 1] >= 0) + atomicAdd(privatized_histograms[CHANNEL] + bins[PIXELS_PER_THREAD - 1], accumulator); + } + } + + + // Accumulate pixels. Specialized for individual accumulation of each pixel. 
+ __device__ __forceinline__ void AccumulatePixels( + SampleT samples[PIXELS_PER_THREAD][NUM_CHANNELS], + bool is_valid[PIXELS_PER_THREAD], + CounterT* privatized_histograms[NUM_ACTIVE_CHANNELS], + Int2Type is_rle_compress) + { + #pragma unroll + for (int PIXEL = 0; PIXEL < PIXELS_PER_THREAD; ++PIXEL) + { + #pragma unroll + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + { + int bin = -1; + privatized_decode_op[CHANNEL].template BinSelect(samples[PIXEL][CHANNEL], bin, is_valid[PIXEL]); + if (bin >= 0) + atomicAdd(privatized_histograms[CHANNEL] + bin, 1); + } + } + } + + + /** + * Accumulate pixel, specialized for smem privatized histogram + */ + __device__ __forceinline__ void AccumulateSmemPixels( + SampleT samples[PIXELS_PER_THREAD][NUM_CHANNELS], + bool is_valid[PIXELS_PER_THREAD]) + { + CounterT* privatized_histograms[NUM_ACTIVE_CHANNELS]; + + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + privatized_histograms[CHANNEL] = temp_storage.histograms[CHANNEL]; + + AccumulatePixels(samples, is_valid, privatized_histograms, Int2Type()); + } + + + /** + * Accumulate pixel, specialized for gmem privatized histogram + */ + __device__ __forceinline__ void AccumulateGmemPixels( + SampleT samples[PIXELS_PER_THREAD][NUM_CHANNELS], + bool is_valid[PIXELS_PER_THREAD]) + { + AccumulatePixels(samples, is_valid, d_privatized_histograms, Int2Type()); + } + + + + //--------------------------------------------------------------------- + // Tile loading + //--------------------------------------------------------------------- + + // Load full, aligned tile using pixel iterator (multi-channel) + template + __device__ __forceinline__ void LoadFullAlignedTile( + OffsetT block_offset, + int valid_samples, + SampleT (&samples)[PIXELS_PER_THREAD][NUM_CHANNELS], + Int2Type<_NUM_ACTIVE_CHANNELS> num_active_channels) + { + typedef PixelT AliasedPixels[PIXELS_PER_THREAD]; + + WrappedPixelIteratorT d_wrapped_pixels((PixelT*) (d_native_samples + block_offset)); + + // Load using a wrapped pixel iterator + BlockLoadPixelT(temp_storage.aliasable.pixel_load).Load( + d_wrapped_pixels, + reinterpret_cast(samples)); + } + + // Load full, aligned tile using quad iterator (single-channel) + __device__ __forceinline__ void LoadFullAlignedTile( + OffsetT block_offset, + int valid_samples, + SampleT (&samples)[PIXELS_PER_THREAD][NUM_CHANNELS], + Int2Type<1> num_active_channels) + { + typedef QuadT AliasedQuads[QUADS_PER_THREAD]; + + WrappedQuadIteratorT d_wrapped_quads((QuadT*) (d_native_samples + block_offset)); + + // Load using a wrapped quad iterator + BlockLoadQuadT(temp_storage.aliasable.quad_load).Load( + d_wrapped_quads, + reinterpret_cast(samples)); + } + + // Load full, aligned tile + __device__ __forceinline__ void LoadTile( + OffsetT block_offset, + int valid_samples, + SampleT (&samples)[PIXELS_PER_THREAD][NUM_CHANNELS], + Int2Type is_full_tile, + Int2Type is_aligned) + { + LoadFullAlignedTile(block_offset, valid_samples, samples, Int2Type()); + } + + // Load full, mis-aligned tile using sample iterator + __device__ __forceinline__ void LoadTile( + OffsetT block_offset, + int valid_samples, + SampleT (&samples)[PIXELS_PER_THREAD][NUM_CHANNELS], + Int2Type is_full_tile, + Int2Type is_aligned) + { + typedef SampleT AliasedSamples[SAMPLES_PER_THREAD]; + + // Load using sample iterator + BlockLoadSampleT(temp_storage.aliasable.sample_load).Load( + d_wrapped_samples + block_offset, + reinterpret_cast(samples)); + } + + // Load partially-full, aligned tile using the pixel iterator + 
__device__ __forceinline__ void LoadTile( + OffsetT block_offset, + int valid_samples, + SampleT (&samples)[PIXELS_PER_THREAD][NUM_CHANNELS], + Int2Type is_full_tile, + Int2Type is_aligned) + { + typedef PixelT AliasedPixels[PIXELS_PER_THREAD]; + + WrappedPixelIteratorT d_wrapped_pixels((PixelT*) (d_native_samples + block_offset)); + + int valid_pixels = valid_samples / NUM_CHANNELS; + + // Load using a wrapped pixel iterator + BlockLoadPixelT(temp_storage.aliasable.pixel_load).Load( + d_wrapped_pixels, + reinterpret_cast(samples), + valid_pixels); + } + + // Load partially-full, mis-aligned tile using sample iterator + __device__ __forceinline__ void LoadTile( + OffsetT block_offset, + int valid_samples, + SampleT (&samples)[PIXELS_PER_THREAD][NUM_CHANNELS], + Int2Type is_full_tile, + Int2Type is_aligned) + { + typedef SampleT AliasedSamples[SAMPLES_PER_THREAD]; + + BlockLoadSampleT(temp_storage.aliasable.sample_load).Load( + d_wrapped_samples + block_offset, + reinterpret_cast(samples), + valid_samples); + } + + + //--------------------------------------------------------------------- + // Tile processing + //--------------------------------------------------------------------- + + // Consume a tile of data samples + template < + bool IS_ALIGNED, // Whether the tile offset is aligned (quad-aligned for single-channel, pixel-aligned for multi-channel) + bool IS_FULL_TILE> // Whether the tile is full + __device__ __forceinline__ void ConsumeTile(OffsetT block_offset, int valid_samples) + { + SampleT samples[PIXELS_PER_THREAD][NUM_CHANNELS]; + bool is_valid[PIXELS_PER_THREAD]; + + // Load tile + LoadTile( + block_offset, + valid_samples, + samples, + Int2Type(), + Int2Type()); + + // Set valid flags + #pragma unroll + for (int PIXEL = 0; PIXEL < PIXELS_PER_THREAD; ++PIXEL) + is_valid[PIXEL] = IS_FULL_TILE || (((threadIdx.x * PIXELS_PER_THREAD + PIXEL) * NUM_CHANNELS) < valid_samples); + + // Accumulate samples +#if CUB_PTX_ARCH >= 120 + if (prefer_smem) + AccumulateSmemPixels(samples, is_valid); + else + AccumulateGmemPixels(samples, is_valid); +#else + AccumulateGmemPixels(samples, is_valid); +#endif + + } + + + // Consume row tiles. 
Specialized for work-stealing from queue + template + __device__ __forceinline__ void ConsumeTiles( + OffsetT num_row_pixels, ///< The number of multi-channel pixels per row in the region of interest + OffsetT num_rows, ///< The number of rows in the region of interest + OffsetT row_stride_samples, ///< The number of samples between starts of consecutive rows in the region of interest + int tiles_per_row, ///< Number of image tiles per row + GridQueue tile_queue, + Int2Type is_work_stealing) + { + + int num_tiles = num_rows * tiles_per_row; + int tile_idx = (blockIdx.y * gridDim.x) + blockIdx.x; + OffsetT num_even_share_tiles = gridDim.x * gridDim.y; + + while (tile_idx < num_tiles) + { + int row = tile_idx / tiles_per_row; + int col = tile_idx - (row * tiles_per_row); + OffsetT row_offset = row * row_stride_samples; + OffsetT col_offset = (col * TILE_SAMPLES); + OffsetT tile_offset = row_offset + col_offset; + + if (col == tiles_per_row - 1) + { + // Consume a partially-full tile at the end of the row + OffsetT num_remaining = (num_row_pixels * NUM_CHANNELS) - col_offset; + ConsumeTile(tile_offset, num_remaining); + } + else + { + // Consume full tile + ConsumeTile(tile_offset, TILE_SAMPLES); + } + + CTA_SYNC(); + + // Get next tile + if (threadIdx.x == 0) + temp_storage.tile_idx = tile_queue.Drain(1) + num_even_share_tiles; + + CTA_SYNC(); + + tile_idx = temp_storage.tile_idx; + } + } + + + // Consume row tiles. Specialized for even-share (striped across thread blocks) + template + __device__ __forceinline__ void ConsumeTiles( + OffsetT num_row_pixels, ///< The number of multi-channel pixels per row in the region of interest + OffsetT num_rows, ///< The number of rows in the region of interest + OffsetT row_stride_samples, ///< The number of samples between starts of consecutive rows in the region of interest + int tiles_per_row, ///< Number of image tiles per row + GridQueue tile_queue, + Int2Type is_work_stealing) + { + for (int row = blockIdx.y; row < num_rows; row += gridDim.y) + { + OffsetT row_begin = row * row_stride_samples; + OffsetT row_end = row_begin + (num_row_pixels * NUM_CHANNELS); + OffsetT tile_offset = row_begin + (blockIdx.x * TILE_SAMPLES); + + while (tile_offset < row_end) + { + OffsetT num_remaining = row_end - tile_offset; + + if (num_remaining < TILE_SAMPLES) + { + // Consume partial tile + ConsumeTile(tile_offset, num_remaining); + break; + } + + // Consume full tile + ConsumeTile(tile_offset, TILE_SAMPLES); + tile_offset += gridDim.x * TILE_SAMPLES; + } + } + } + + + //--------------------------------------------------------------------- + // Parameter extraction + //--------------------------------------------------------------------- + + // Return a native pixel pointer (specialized for CacheModifiedInputIterator types) + template < + CacheLoadModifier _MODIFIER, + typename _ValueT, + typename _OffsetT> + __device__ __forceinline__ SampleT* NativePointer(CacheModifiedInputIterator<_MODIFIER, _ValueT, _OffsetT> itr) + { + return itr.ptr; + } + + // Return a native pixel pointer (specialized for other types) + template + __device__ __forceinline__ SampleT* NativePointer(IteratorT itr) + { + return NULL; + } + + + + //--------------------------------------------------------------------- + // Interface + //--------------------------------------------------------------------- + + + /** + * Constructor + */ + __device__ __forceinline__ AgentHistogram( + TempStorage &temp_storage, ///< Reference to temp_storage + SampleIteratorT d_samples, ///< Input data to reduce + 
int (&num_output_bins)[NUM_ACTIVE_CHANNELS], ///< The number bins per final output histogram + int (&num_privatized_bins)[NUM_ACTIVE_CHANNELS], ///< The number bins per privatized histogram + CounterT* (&d_output_histograms)[NUM_ACTIVE_CHANNELS], ///< Reference to final output histograms + CounterT* (&d_privatized_histograms)[NUM_ACTIVE_CHANNELS], ///< Reference to privatized histograms + OutputDecodeOpT (&output_decode_op)[NUM_ACTIVE_CHANNELS], ///< The transform operator for determining output bin-ids from privatized counter indices, one for each channel + PrivatizedDecodeOpT (&privatized_decode_op)[NUM_ACTIVE_CHANNELS]) ///< The transform operator for determining privatized counter indices from samples, one for each channel + : + temp_storage(temp_storage.Alias()), + d_wrapped_samples(d_samples), + num_output_bins(num_output_bins), + num_privatized_bins(num_privatized_bins), + d_output_histograms(d_output_histograms), + privatized_decode_op(privatized_decode_op), + output_decode_op(output_decode_op), + d_native_samples(NativePointer(d_wrapped_samples)), + prefer_smem((MEM_PREFERENCE == SMEM) ? + true : // prefer smem privatized histograms + (MEM_PREFERENCE == GMEM) ? + false : // prefer gmem privatized histograms + blockIdx.x & 1) // prefer blended privatized histograms + { + int blockId = (blockIdx.y * gridDim.x) + blockIdx.x; + + // Initialize the locations of this block's privatized histograms + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + this->d_privatized_histograms[CHANNEL] = d_privatized_histograms[CHANNEL] + (blockId * num_privatized_bins[CHANNEL]); + } + + + /** + * Consume image + */ + __device__ __forceinline__ void ConsumeTiles( + OffsetT num_row_pixels, ///< The number of multi-channel pixels per row in the region of interest + OffsetT num_rows, ///< The number of rows in the region of interest + OffsetT row_stride_samples, ///< The number of samples between starts of consecutive rows in the region of interest + int tiles_per_row, ///< Number of image tiles per row + GridQueue tile_queue) ///< Queue descriptor for assigning tiles of work to thread blocks + { + // Check whether all row starting offsets are quad-aligned (in single-channel) or pixel-aligned (in multi-channel) + int quad_mask = AlignBytes::ALIGN_BYTES - 1; + int pixel_mask = AlignBytes::ALIGN_BYTES - 1; + size_t row_bytes = sizeof(SampleT) * row_stride_samples; + + bool quad_aligned_rows = (NUM_CHANNELS == 1) && (SAMPLES_PER_THREAD % 4 == 0) && // Single channel + ((size_t(d_native_samples) & quad_mask) == 0) && // ptr is quad-aligned + ((num_rows == 1) || ((row_bytes & quad_mask) == 0)); // number of row-samples is a multiple of the alignment of the quad + + bool pixel_aligned_rows = (NUM_CHANNELS > 1) && // Multi channel + ((size_t(d_native_samples) & pixel_mask) == 0) && // ptr is pixel-aligned + ((row_bytes & pixel_mask) == 0); // number of row-samples is a multiple of the alignment of the pixel + + // Whether rows are aligned and can be vectorized + if ((d_native_samples != NULL) && (quad_aligned_rows || pixel_aligned_rows)) + ConsumeTiles(num_row_pixels, num_rows, row_stride_samples, tiles_per_row, tile_queue, Int2Type()); + else + ConsumeTiles(num_row_pixels, num_rows, row_stride_samples, tiles_per_row, tile_queue, Int2Type()); + } + + + /** + * Initialize privatized bin counters. 
Specialized for privatized shared-memory counters + */ + __device__ __forceinline__ void InitBinCounters() + { + if (prefer_smem) + InitSmemBinCounters(); + else + InitGmemBinCounters(); + } + + + /** + * Store privatized histogram to device-accessible memory. Specialized for privatized shared-memory counters + */ + __device__ __forceinline__ void StoreOutput() + { + if (prefer_smem) + StoreSmemOutput(); + else + StoreGmemOutput(); + } + + +}; + + + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/agent/agent_radix_sort_downsweep.cuh b/fastertransformer/cuda/cub/agent/agent_radix_sort_downsweep.cuh new file mode 100644 index 000000000..faea88138 --- /dev/null +++ b/fastertransformer/cuda/cub/agent/agent_radix_sort_downsweep.cuh @@ -0,0 +1,789 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * AgentRadixSortDownsweep implements a stateful abstraction of CUDA thread blocks for participating in device-wide radix sort downsweep . 
+ */ + + +#pragma once + +#include + +#include "../thread/thread_load.cuh" +#include "../block/block_load.cuh" +#include "../block/block_store.cuh" +#include "../block/block_radix_rank.cuh" +#include "../block/block_exchange.cuh" +#include "../util_type.cuh" +#include "../iterator/cache_modified_input_iterator.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Tuning policy types + ******************************************************************************/ + +/** + * Radix ranking algorithm + */ +enum RadixRankAlgorithm +{ + RADIX_RANK_BASIC, + RADIX_RANK_MEMOIZE, + RADIX_RANK_MATCH +}; + +/** + * Parameterizable tuning policy type for AgentRadixSortDownsweep + */ +template < + int _BLOCK_THREADS, ///< Threads per thread block + int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use + CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading keys (and values) + RadixRankAlgorithm _RANK_ALGORITHM, ///< The radix ranking algorithm to use + BlockScanAlgorithm _SCAN_ALGORITHM, ///< The block scan algorithm to use + int _RADIX_BITS> ///< The number of radix bits, i.e., log2(bins) +struct AgentRadixSortDownsweepPolicy +{ + enum + { + BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block + ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + RADIX_BITS = _RADIX_BITS, ///< The number of radix bits, i.e., log2(bins) + }; + + static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use + static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading keys (and values) + static const RadixRankAlgorithm RANK_ALGORITHM = _RANK_ALGORITHM; ///< The radix ranking algorithm to use + static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use +}; + + +/****************************************************************************** + * Thread block abstractions + ******************************************************************************/ + + + + + +/** + * \brief AgentRadixSortDownsweep implements a stateful abstraction of CUDA thread blocks for participating in device-wide radix sort downsweep . 
+ */ +template < + typename AgentRadixSortDownsweepPolicy, ///< Parameterized AgentRadixSortDownsweepPolicy tuning policy type + bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low + typename KeyT, ///< KeyT type + typename ValueT, ///< ValueT type + typename OffsetT> ///< Signed integer type for global offsets +struct AgentRadixSortDownsweep +{ + //--------------------------------------------------------------------- + // Type definitions and constants + //--------------------------------------------------------------------- + + // Appropriate unsigned-bits representation of KeyT + typedef typename Traits::UnsignedBits UnsignedBits; + + static const UnsignedBits LOWEST_KEY = Traits::LOWEST_KEY; + static const UnsignedBits MAX_KEY = Traits::MAX_KEY; + + static const BlockLoadAlgorithm LOAD_ALGORITHM = AgentRadixSortDownsweepPolicy::LOAD_ALGORITHM; + static const CacheLoadModifier LOAD_MODIFIER = AgentRadixSortDownsweepPolicy::LOAD_MODIFIER; + static const RadixRankAlgorithm RANK_ALGORITHM = AgentRadixSortDownsweepPolicy::RANK_ALGORITHM; + static const BlockScanAlgorithm SCAN_ALGORITHM = AgentRadixSortDownsweepPolicy::SCAN_ALGORITHM; + + enum + { + BLOCK_THREADS = AgentRadixSortDownsweepPolicy::BLOCK_THREADS, + ITEMS_PER_THREAD = AgentRadixSortDownsweepPolicy::ITEMS_PER_THREAD, + RADIX_BITS = AgentRadixSortDownsweepPolicy::RADIX_BITS, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + + RADIX_DIGITS = 1 << RADIX_BITS, + KEYS_ONLY = Equals::VALUE, + }; + + // Input iterator wrapper type (for applying cache modifier)s + typedef CacheModifiedInputIterator KeysItr; + typedef CacheModifiedInputIterator ValuesItr; + + // Radix ranking type to use + typedef typename If<(RANK_ALGORITHM == RADIX_RANK_BASIC), + BlockRadixRank, + typename If<(RANK_ALGORITHM == RADIX_RANK_MEMOIZE), + BlockRadixRank, + BlockRadixRankMatch + >::Type + >::Type BlockRadixRankT; + + enum + { + /// Number of bin-starting offsets tracked per thread + BINS_TRACKED_PER_THREAD = BlockRadixRankT::BINS_TRACKED_PER_THREAD + }; + + // BlockLoad type (keys) + typedef BlockLoad< + UnsignedBits, + BLOCK_THREADS, + ITEMS_PER_THREAD, + LOAD_ALGORITHM> BlockLoadKeysT; + + // BlockLoad type (values) + typedef BlockLoad< + ValueT, + BLOCK_THREADS, + ITEMS_PER_THREAD, + LOAD_ALGORITHM> BlockLoadValuesT; + + // Value exchange array type + typedef ValueT ValueExchangeT[TILE_ITEMS]; + + /** + * Shared memory storage layout + */ + union __align__(16) _TempStorage + { + typename BlockLoadKeysT::TempStorage load_keys; + typename BlockLoadValuesT::TempStorage load_values; + typename BlockRadixRankT::TempStorage radix_rank; + + struct + { + UnsignedBits exchange_keys[TILE_ITEMS]; + OffsetT relative_bin_offsets[RADIX_DIGITS]; + }; + + Uninitialized exchange_values; + + OffsetT exclusive_digit_prefix[RADIX_DIGITS]; + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Thread fields + //--------------------------------------------------------------------- + + // Shared storage for this CTA + _TempStorage &temp_storage; + + // Input and output device pointers + KeysItr d_keys_in; + ValuesItr d_values_in; + UnsignedBits *d_keys_out; + ValueT *d_values_out; + + // The global scatter base offset for each digit (valid in the first RADIX_DIGITS threads) + OffsetT bin_offset[BINS_TRACKED_PER_THREAD]; + + // The least-significant bit position of the current digit to extract + int current_bit; + + // 
Number of bits in current digit + int num_bits; + + // Whether to short-cirucit + int short_circuit; + + //--------------------------------------------------------------------- + // Utility methods + //--------------------------------------------------------------------- + + + /** + * Scatter ranked keys through shared memory, then to device-accessible memory + */ + template + __device__ __forceinline__ void ScatterKeys( + UnsignedBits (&twiddled_keys)[ITEMS_PER_THREAD], + OffsetT (&relative_bin_offsets)[ITEMS_PER_THREAD], + int (&ranks)[ITEMS_PER_THREAD], + OffsetT valid_items) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + temp_storage.exchange_keys[ranks[ITEM]] = twiddled_keys[ITEM]; + } + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + UnsignedBits key = temp_storage.exchange_keys[threadIdx.x + (ITEM * BLOCK_THREADS)]; + UnsignedBits digit = BFE(key, current_bit, num_bits); + relative_bin_offsets[ITEM] = temp_storage.relative_bin_offsets[digit]; + + // Un-twiddle + key = Traits::TwiddleOut(key); + + if (FULL_TILE || + (static_cast(threadIdx.x + (ITEM * BLOCK_THREADS)) < valid_items)) + { + d_keys_out[relative_bin_offsets[ITEM] + threadIdx.x + (ITEM * BLOCK_THREADS)] = key; + } + } + } + + + /** + * Scatter ranked values through shared memory, then to device-accessible memory + */ + template + __device__ __forceinline__ void ScatterValues( + ValueT (&values)[ITEMS_PER_THREAD], + OffsetT (&relative_bin_offsets)[ITEMS_PER_THREAD], + int (&ranks)[ITEMS_PER_THREAD], + OffsetT valid_items) + { + CTA_SYNC(); + + ValueExchangeT &exchange_values = temp_storage.exchange_values.Alias(); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + exchange_values[ranks[ITEM]] = values[ITEM]; + } + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + ValueT value = exchange_values[threadIdx.x + (ITEM * BLOCK_THREADS)]; + + if (FULL_TILE || + (static_cast(threadIdx.x + (ITEM * BLOCK_THREADS)) < valid_items)) + { + d_values_out[relative_bin_offsets[ITEM] + threadIdx.x + (ITEM * BLOCK_THREADS)] = value; + } + } + } + + /** + * Load a tile of keys (specialized for full tile, any ranking algorithm) + */ + template + __device__ __forceinline__ void LoadKeys( + UnsignedBits (&keys)[ITEMS_PER_THREAD], + OffsetT block_offset, + OffsetT valid_items, + UnsignedBits oob_item, + Int2Type is_full_tile, + Int2Type<_RANK_ALGORITHM> rank_algorithm) + { + BlockLoadKeysT(temp_storage.load_keys).Load( + d_keys_in + block_offset, keys); + + CTA_SYNC(); + } + + + /** + * Load a tile of keys (specialized for partial tile, any ranking algorithm) + */ + template + __device__ __forceinline__ void LoadKeys( + UnsignedBits (&keys)[ITEMS_PER_THREAD], + OffsetT block_offset, + OffsetT valid_items, + UnsignedBits oob_item, + Int2Type is_full_tile, + Int2Type<_RANK_ALGORITHM> rank_algorithm) + { + // Register pressure work-around: moving valid_items through shfl prevents compiler + // from reusing guards/addressing from prior guarded loads + valid_items = ShuffleIndex(valid_items, 0, 0xffffffff); + + BlockLoadKeysT(temp_storage.load_keys).Load( + d_keys_in + block_offset, keys, valid_items, oob_item); + + CTA_SYNC(); + } + + + /** + * Load a tile of keys (specialized for full tile, match ranking algorithm) + */ + __device__ __forceinline__ void LoadKeys( + UnsignedBits (&keys)[ITEMS_PER_THREAD], + OffsetT block_offset, + OffsetT valid_items, + UnsignedBits oob_item, + Int2Type is_full_tile, + 
Int2Type rank_algorithm) + { + LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys); + } + + + /** + * Load a tile of keys (specialized for partial tile, match ranking algorithm) + */ + __device__ __forceinline__ void LoadKeys( + UnsignedBits (&keys)[ITEMS_PER_THREAD], + OffsetT block_offset, + OffsetT valid_items, + UnsignedBits oob_item, + Int2Type is_full_tile, + Int2Type rank_algorithm) + { + // Register pressure work-around: moving valid_items through shfl prevents compiler + // from reusing guards/addressing from prior guarded loads + valid_items = ShuffleIndex(valid_items, 0, 0xffffffff); + + LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys, valid_items, oob_item); + } + + + /** + * Load a tile of values (specialized for full tile, any ranking algorithm) + */ + template + __device__ __forceinline__ void LoadValues( + ValueT (&values)[ITEMS_PER_THREAD], + OffsetT block_offset, + OffsetT valid_items, + Int2Type is_full_tile, + Int2Type<_RANK_ALGORITHM> rank_algorithm) + { + BlockLoadValuesT(temp_storage.load_values).Load( + d_values_in + block_offset, values); + + CTA_SYNC(); + } + + + /** + * Load a tile of values (specialized for partial tile, any ranking algorithm) + */ + template + __device__ __forceinline__ void LoadValues( + ValueT (&values)[ITEMS_PER_THREAD], + OffsetT block_offset, + OffsetT valid_items, + Int2Type is_full_tile, + Int2Type<_RANK_ALGORITHM> rank_algorithm) + { + // Register pressure work-around: moving valid_items through shfl prevents compiler + // from reusing guards/addressing from prior guarded loads + valid_items = ShuffleIndex(valid_items, 0, 0xffffffff); + + BlockLoadValuesT(temp_storage.load_values).Load( + d_values_in + block_offset, values, valid_items); + + CTA_SYNC(); + } + + + /** + * Load a tile of items (specialized for full tile, match ranking algorithm) + */ + __device__ __forceinline__ void LoadValues( + ValueT (&values)[ITEMS_PER_THREAD], + OffsetT block_offset, + OffsetT valid_items, + Int2Type is_full_tile, + Int2Type rank_algorithm) + { + LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values); + } + + + /** + * Load a tile of items (specialized for partial tile, match ranking algorithm) + */ + __device__ __forceinline__ void LoadValues( + ValueT (&values)[ITEMS_PER_THREAD], + OffsetT block_offset, + OffsetT valid_items, + Int2Type is_full_tile, + Int2Type rank_algorithm) + { + // Register pressure work-around: moving valid_items through shfl prevents compiler + // from reusing guards/addressing from prior guarded loads + valid_items = ShuffleIndex(valid_items, 0, 0xffffffff); + + LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values, valid_items); + } + + + /** + * Truck along associated values + */ + template + __device__ __forceinline__ void GatherScatterValues( + OffsetT (&relative_bin_offsets)[ITEMS_PER_THREAD], + int (&ranks)[ITEMS_PER_THREAD], + OffsetT block_offset, + OffsetT valid_items, + Int2Type /*is_keys_only*/) + { + ValueT values[ITEMS_PER_THREAD]; + + CTA_SYNC(); + + LoadValues( + values, + block_offset, + valid_items, + Int2Type(), + Int2Type()); + + ScatterValues( + values, + relative_bin_offsets, + ranks, + valid_items); + } + + + /** + * Truck along associated values (specialized for key-only sorting) + */ + template + __device__ __forceinline__ void GatherScatterValues( + OffsetT (&/*relative_bin_offsets*/)[ITEMS_PER_THREAD], + int (&/*ranks*/)[ITEMS_PER_THREAD], + OffsetT /*block_offset*/, + OffsetT /*valid_items*/, + Int2Type /*is_keys_only*/) + {} + + + 
/** + * Process tile + */ + template + __device__ __forceinline__ void ProcessTile( + OffsetT block_offset, + const OffsetT &valid_items = TILE_ITEMS) + { + UnsignedBits keys[ITEMS_PER_THREAD]; + int ranks[ITEMS_PER_THREAD]; + OffsetT relative_bin_offsets[ITEMS_PER_THREAD]; + + // Assign default (min/max) value to all keys + UnsignedBits default_key = (IS_DESCENDING) ? LOWEST_KEY : MAX_KEY; + + // Load tile of keys + LoadKeys( + keys, + block_offset, + valid_items, + default_key, + Int2Type(), + Int2Type()); + + // Twiddle key bits if necessary + #pragma unroll + for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) + { + keys[KEY] = Traits::TwiddleIn(keys[KEY]); + } + + // Rank the twiddled keys + int exclusive_digit_prefix[BINS_TRACKED_PER_THREAD]; + BlockRadixRankT(temp_storage.radix_rank).RankKeys( + keys, + ranks, + current_bit, + num_bits, + exclusive_digit_prefix); + + CTA_SYNC(); + + // Share exclusive digit prefix + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + { + // Store exclusive prefix + temp_storage.exclusive_digit_prefix[bin_idx] = + exclusive_digit_prefix[track]; + } + } + + CTA_SYNC(); + + // Get inclusive digit prefix + int inclusive_digit_prefix[BINS_TRACKED_PER_THREAD]; + + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + { + if (IS_DESCENDING) + { + // Get inclusive digit prefix from exclusive prefix (higher bins come first) + inclusive_digit_prefix[track] = (bin_idx == 0) ? + (BLOCK_THREADS * ITEMS_PER_THREAD) : + temp_storage.exclusive_digit_prefix[bin_idx - 1]; + } + else + { + // Get inclusive digit prefix from exclusive prefix (lower bins come first) + inclusive_digit_prefix[track] = (bin_idx == RADIX_DIGITS - 1) ? 
+ (BLOCK_THREADS * ITEMS_PER_THREAD) : + temp_storage.exclusive_digit_prefix[bin_idx + 1]; + } + } + } + + CTA_SYNC(); + + // Update global scatter base offsets for each digit + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + { + bin_offset[track] -= exclusive_digit_prefix[track]; + temp_storage.relative_bin_offsets[bin_idx] = bin_offset[track]; + bin_offset[track] += inclusive_digit_prefix[track]; + } + } + + CTA_SYNC(); + + // Scatter keys + ScatterKeys(keys, relative_bin_offsets, ranks, valid_items); + + // Gather/scatter values + GatherScatterValues(relative_bin_offsets , ranks, block_offset, valid_items, Int2Type()); + } + + //--------------------------------------------------------------------- + // Copy shortcut + //--------------------------------------------------------------------- + + /** + * Copy tiles within the range of input + */ + template < + typename InputIteratorT, + typename T> + __device__ __forceinline__ void Copy( + InputIteratorT d_in, + T *d_out, + OffsetT block_offset, + OffsetT block_end) + { + // Simply copy the input + while (block_offset + TILE_ITEMS <= block_end) + { + T items[ITEMS_PER_THREAD]; + + LoadDirectStriped(threadIdx.x, d_in + block_offset, items); + CTA_SYNC(); + StoreDirectStriped(threadIdx.x, d_out + block_offset, items); + + block_offset += TILE_ITEMS; + } + + // Clean up last partial tile with guarded-I/O + if (block_offset < block_end) + { + OffsetT valid_items = block_end - block_offset; + + T items[ITEMS_PER_THREAD]; + + LoadDirectStriped(threadIdx.x, d_in + block_offset, items, valid_items); + CTA_SYNC(); + StoreDirectStriped(threadIdx.x, d_out + block_offset, items, valid_items); + } + } + + + /** + * Copy tiles within the range of input (specialized for NullType) + */ + template + __device__ __forceinline__ void Copy( + InputIteratorT /*d_in*/, + NullType * /*d_out*/, + OffsetT /*block_offset*/, + OffsetT /*block_end*/) + {} + + + //--------------------------------------------------------------------- + // Interface + //--------------------------------------------------------------------- + + /** + * Constructor + */ + __device__ __forceinline__ AgentRadixSortDownsweep( + TempStorage &temp_storage, + OffsetT (&bin_offset)[BINS_TRACKED_PER_THREAD], + OffsetT num_items, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + int current_bit, + int num_bits) + : + temp_storage(temp_storage.Alias()), + d_keys_in(reinterpret_cast(d_keys_in)), + d_values_in(d_values_in), + d_keys_out(reinterpret_cast(d_keys_out)), + d_values_out(d_values_out), + current_bit(current_bit), + num_bits(num_bits), + short_circuit(1) + { + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + this->bin_offset[track] = bin_offset[track]; + + int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + { + // Short circuit if the histogram has only bin counts of only zeros or problem-size + short_circuit = short_circuit && ((bin_offset[track] == 0) || (bin_offset[track] == num_items)); + } + } + + short_circuit = CTA_SYNC_AND(short_circuit); + } + + + /** + * Constructor + */ + __device__ __forceinline__ AgentRadixSortDownsweep( + TempStorage &temp_storage, + OffsetT num_items, + OffsetT *d_spine, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT 
*d_values_in, + ValueT *d_values_out, + int current_bit, + int num_bits) + : + temp_storage(temp_storage.Alias()), + d_keys_in(reinterpret_cast(d_keys_in)), + d_values_in(d_values_in), + d_keys_out(reinterpret_cast(d_keys_out)), + d_values_out(d_values_out), + current_bit(current_bit), + num_bits(num_bits), + short_circuit(1) + { + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; + + // Load digit bin offsets (each of the first RADIX_DIGITS threads will load an offset for that digit) + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + { + if (IS_DESCENDING) + bin_idx = RADIX_DIGITS - bin_idx - 1; + + // Short circuit if the first block's histogram has only bin counts of only zeros or problem-size + OffsetT first_block_bin_offset = d_spine[gridDim.x * bin_idx]; + short_circuit = short_circuit && ((first_block_bin_offset == 0) || (first_block_bin_offset == num_items)); + + // Load my block's bin offset for my bin + bin_offset[track] = d_spine[(gridDim.x * bin_idx) + blockIdx.x]; + } + } + + short_circuit = CTA_SYNC_AND(short_circuit); + } + + + /** + * Distribute keys from a segment of input tiles. + */ + __device__ __forceinline__ void ProcessRegion( + OffsetT block_offset, + OffsetT block_end) + { + if (short_circuit) + { + // Copy keys + Copy(d_keys_in, d_keys_out, block_offset, block_end); + + // Copy values + Copy(d_values_in, d_values_out, block_offset, block_end); + } + else + { + // Process full tiles of tile_items + #pragma unroll 1 + while (block_offset + TILE_ITEMS <= block_end) + { + ProcessTile(block_offset); + block_offset += TILE_ITEMS; + + CTA_SYNC(); + } + + // Clean up last partial tile with guarded-I/O + if (block_offset < block_end) + { + ProcessTile(block_offset, block_end - block_offset); + } + + } + } + +}; + + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/agent/agent_radix_sort_upsweep.cuh b/fastertransformer/cuda/cub/agent/agent_radix_sort_upsweep.cuh new file mode 100644 index 000000000..2081cefba --- /dev/null +++ b/fastertransformer/cuda/cub/agent/agent_radix_sort_upsweep.cuh @@ -0,0 +1,526 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * AgentRadixSortUpsweep implements a stateful abstraction of CUDA thread blocks for participating in device-wide radix sort upsweep . + */ + +#pragma once + +#include "../thread/thread_reduce.cuh" +#include "../thread/thread_load.cuh" +#include "../warp/warp_reduce.cuh" +#include "../block/block_load.cuh" +#include "../util_type.cuh" +#include "../iterator/cache_modified_input_iterator.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/****************************************************************************** + * Tuning policy types + ******************************************************************************/ + +/** + * Parameterizable tuning policy type for AgentRadixSortUpsweep + */ +template < + int _BLOCK_THREADS, ///< Threads per thread block + int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading keys + int _RADIX_BITS> ///< The number of radix bits, i.e., log2(bins) +struct AgentRadixSortUpsweepPolicy +{ + enum + { + BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block + ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + RADIX_BITS = _RADIX_BITS, ///< The number of radix bits, i.e., log2(bins) + }; + + static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading keys +}; + + +/****************************************************************************** + * Thread block abstractions + ******************************************************************************/ + +/** + * \brief AgentRadixSortUpsweep implements a stateful abstraction of CUDA thread blocks for participating in device-wide radix sort upsweep . 
+ */ +template < + typename AgentRadixSortUpsweepPolicy, ///< Parameterized AgentRadixSortUpsweepPolicy tuning policy type + typename KeyT, ///< KeyT type + typename OffsetT> ///< Signed integer type for global offsets +struct AgentRadixSortUpsweep +{ + + //--------------------------------------------------------------------- + // Type definitions and constants + //--------------------------------------------------------------------- + + typedef typename Traits::UnsignedBits UnsignedBits; + + // Integer type for digit counters (to be packed into words of PackedCounters) + typedef unsigned char DigitCounter; + + // Integer type for packing DigitCounters into columns of shared memory banks + typedef unsigned int PackedCounter; + + static const CacheLoadModifier LOAD_MODIFIER = AgentRadixSortUpsweepPolicy::LOAD_MODIFIER; + + enum + { + RADIX_BITS = AgentRadixSortUpsweepPolicy::RADIX_BITS, + BLOCK_THREADS = AgentRadixSortUpsweepPolicy::BLOCK_THREADS, + KEYS_PER_THREAD = AgentRadixSortUpsweepPolicy::ITEMS_PER_THREAD, + + RADIX_DIGITS = 1 << RADIX_BITS, + + LOG_WARP_THREADS = CUB_PTX_LOG_WARP_THREADS, + WARP_THREADS = 1 << LOG_WARP_THREADS, + WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS, + + TILE_ITEMS = BLOCK_THREADS * KEYS_PER_THREAD, + + BYTES_PER_COUNTER = sizeof(DigitCounter), + LOG_BYTES_PER_COUNTER = Log2::VALUE, + + PACKING_RATIO = sizeof(PackedCounter) / sizeof(DigitCounter), + LOG_PACKING_RATIO = Log2::VALUE, + + LOG_COUNTER_LANES = CUB_MAX(0, RADIX_BITS - LOG_PACKING_RATIO), + COUNTER_LANES = 1 << LOG_COUNTER_LANES, + + // To prevent counter overflow, we must periodically unpack and aggregate the + // digit counters back into registers. Each counter lane is assigned to a + // warp for aggregation. + + LANES_PER_WARP = CUB_MAX(1, (COUNTER_LANES + WARPS - 1) / WARPS), + + // Unroll tiles in batches without risk of counter overflow + UNROLL_COUNT = CUB_MIN(64, 255 / KEYS_PER_THREAD), + UNROLLED_ELEMENTS = UNROLL_COUNT * TILE_ITEMS, + }; + + + // Input iterator wrapper type (for applying cache modifier)s + typedef CacheModifiedInputIterator KeysItr; + + /** + * Shared memory storage layout + */ + union __align__(16) _TempStorage + { + DigitCounter thread_counters[COUNTER_LANES][BLOCK_THREADS][PACKING_RATIO]; + PackedCounter packed_thread_counters[COUNTER_LANES][BLOCK_THREADS]; + OffsetT block_counters[WARP_THREADS][RADIX_DIGITS]; + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Thread fields (aggregate state bundle) + //--------------------------------------------------------------------- + + // Shared storage for this CTA + _TempStorage &temp_storage; + + // Thread-local counters for periodically aggregating composite-counter lanes + OffsetT local_counts[LANES_PER_WARP][PACKING_RATIO]; + + // Input and output device pointers + KeysItr d_keys_in; + + // The least-significant bit position of the current digit to extract + int current_bit; + + // Number of bits in current digit + int num_bits; + + + + //--------------------------------------------------------------------- + // Helper structure for templated iteration + //--------------------------------------------------------------------- + + // Iterate + template + struct Iterate + { + // BucketKeys + static __device__ __forceinline__ void BucketKeys( + AgentRadixSortUpsweep &cta, + UnsignedBits keys[KEYS_PER_THREAD]) + { + cta.Bucket(keys[COUNT]); + + // Next + 
Iterate::BucketKeys(cta, keys); + } + }; + + // Terminate + template + struct Iterate + { + // BucketKeys + static __device__ __forceinline__ void BucketKeys(AgentRadixSortUpsweep &/*cta*/, UnsignedBits /*keys*/[KEYS_PER_THREAD]) {} + }; + + + //--------------------------------------------------------------------- + // Utility methods + //--------------------------------------------------------------------- + + /** + * Decode a key and increment corresponding smem digit counter + */ + __device__ __forceinline__ void Bucket(UnsignedBits key) + { + // Perform transform op + UnsignedBits converted_key = Traits::TwiddleIn(key); + + // Extract current digit bits + UnsignedBits digit = BFE(converted_key, current_bit, num_bits); + + // Get sub-counter offset + UnsignedBits sub_counter = digit & (PACKING_RATIO - 1); + + // Get row offset + UnsignedBits row_offset = digit >> LOG_PACKING_RATIO; + + // Increment counter + temp_storage.thread_counters[row_offset][threadIdx.x][sub_counter]++; + } + + + /** + * Reset composite counters + */ + __device__ __forceinline__ void ResetDigitCounters() + { + #pragma unroll + for (int LANE = 0; LANE < COUNTER_LANES; LANE++) + { + temp_storage.packed_thread_counters[LANE][threadIdx.x] = 0; + } + } + + + /** + * Reset the unpacked counters in each thread + */ + __device__ __forceinline__ void ResetUnpackedCounters() + { + #pragma unroll + for (int LANE = 0; LANE < LANES_PER_WARP; LANE++) + { + #pragma unroll + for (int UNPACKED_COUNTER = 0; UNPACKED_COUNTER < PACKING_RATIO; UNPACKED_COUNTER++) + { + local_counts[LANE][UNPACKED_COUNTER] = 0; + } + } + } + + + /** + * Extracts and aggregates the digit counters for each counter lane + * owned by this warp + */ + __device__ __forceinline__ void UnpackDigitCounts() + { + unsigned int warp_id = threadIdx.x >> LOG_WARP_THREADS; + unsigned int warp_tid = LaneId(); + + #pragma unroll + for (int LANE = 0; LANE < LANES_PER_WARP; LANE++) + { + const int counter_lane = (LANE * WARPS) + warp_id; + if (counter_lane < COUNTER_LANES) + { + #pragma unroll + for (int PACKED_COUNTER = 0; PACKED_COUNTER < BLOCK_THREADS; PACKED_COUNTER += WARP_THREADS) + { + #pragma unroll + for (int UNPACKED_COUNTER = 0; UNPACKED_COUNTER < PACKING_RATIO; UNPACKED_COUNTER++) + { + OffsetT counter = temp_storage.thread_counters[counter_lane][warp_tid + PACKED_COUNTER][UNPACKED_COUNTER]; + local_counts[LANE][UNPACKED_COUNTER] += counter; + } + } + } + } + } + + + /** + * Processes a single, full tile + */ + __device__ __forceinline__ void ProcessFullTile(OffsetT block_offset) + { + // Tile of keys + UnsignedBits keys[KEYS_PER_THREAD]; + + LoadDirectStriped(threadIdx.x, d_keys_in + block_offset, keys); + + // Prevent hoisting + CTA_SYNC(); + + // Bucket tile of keys + Iterate<0, KEYS_PER_THREAD>::BucketKeys(*this, keys); + } + + + /** + * Processes a single load (may have some threads masked off) + */ + __device__ __forceinline__ void ProcessPartialTile( + OffsetT block_offset, + const OffsetT &block_end) + { + // Process partial tile if necessary using single loads + block_offset += threadIdx.x; + while (block_offset < block_end) + { + // Load and bucket key + UnsignedBits key = d_keys_in[block_offset]; + Bucket(key); + block_offset += BLOCK_THREADS; + } + } + + + //--------------------------------------------------------------------- + // Interface + //--------------------------------------------------------------------- + + /** + * Constructor + */ + __device__ __forceinline__ AgentRadixSortUpsweep( + TempStorage &temp_storage, + const KeyT *d_keys_in, 
+ int current_bit, + int num_bits) + : + temp_storage(temp_storage.Alias()), + d_keys_in(reinterpret_cast(d_keys_in)), + current_bit(current_bit), + num_bits(num_bits) + {} + + + /** + * Compute radix digit histograms from a segment of input tiles. + */ + __device__ __forceinline__ void ProcessRegion( + OffsetT block_offset, + const OffsetT &block_end) + { + // Reset digit counters in smem and unpacked counters in registers + ResetDigitCounters(); + ResetUnpackedCounters(); + + // Unroll batches of full tiles + while (block_offset + UNROLLED_ELEMENTS <= block_end) + { + for (int i = 0; i < UNROLL_COUNT; ++i) + { + ProcessFullTile(block_offset); + block_offset += TILE_ITEMS; + } + + CTA_SYNC(); + + // Aggregate back into local_count registers to prevent overflow + UnpackDigitCounts(); + + CTA_SYNC(); + + // Reset composite counters in lanes + ResetDigitCounters(); + } + + // Unroll single full tiles + while (block_offset + TILE_ITEMS <= block_end) + { + ProcessFullTile(block_offset); + block_offset += TILE_ITEMS; + } + + // Process partial tile if necessary + ProcessPartialTile( + block_offset, + block_end); + + CTA_SYNC(); + + // Aggregate back into local_count registers + UnpackDigitCounts(); + } + + + /** + * Extract counts (saving them to the external array) + */ + template + __device__ __forceinline__ void ExtractCounts( + OffsetT *counters, + int bin_stride = 1, + int bin_offset = 0) + { + unsigned int warp_id = threadIdx.x >> LOG_WARP_THREADS; + unsigned int warp_tid = LaneId(); + + // Place unpacked digit counters in shared memory + #pragma unroll + for (int LANE = 0; LANE < LANES_PER_WARP; LANE++) + { + int counter_lane = (LANE * WARPS) + warp_id; + if (counter_lane < COUNTER_LANES) + { + int digit_row = counter_lane << LOG_PACKING_RATIO; + + #pragma unroll + for (int UNPACKED_COUNTER = 0; UNPACKED_COUNTER < PACKING_RATIO; UNPACKED_COUNTER++) + { + int bin_idx = digit_row + UNPACKED_COUNTER; + + temp_storage.block_counters[warp_tid][bin_idx] = + local_counts[LANE][UNPACKED_COUNTER]; + } + } + } + + CTA_SYNC(); + + // Rake-reduce bin_count reductions + + // Whole blocks + #pragma unroll + for (int BIN_BASE = RADIX_DIGITS % BLOCK_THREADS; + (BIN_BASE + BLOCK_THREADS) <= RADIX_DIGITS; + BIN_BASE += BLOCK_THREADS) + { + int bin_idx = BIN_BASE + threadIdx.x; + + OffsetT bin_count = 0; + #pragma unroll + for (int i = 0; i < WARP_THREADS; ++i) + bin_count += temp_storage.block_counters[i][bin_idx]; + + if (IS_DESCENDING) + bin_idx = RADIX_DIGITS - bin_idx - 1; + + counters[(bin_stride * bin_idx) + bin_offset] = bin_count; + } + + // Remainder + if ((RADIX_DIGITS % BLOCK_THREADS != 0) && (threadIdx.x < RADIX_DIGITS)) + { + int bin_idx = threadIdx.x; + + OffsetT bin_count = 0; + #pragma unroll + for (int i = 0; i < WARP_THREADS; ++i) + bin_count += temp_storage.block_counters[i][bin_idx]; + + if (IS_DESCENDING) + bin_idx = RADIX_DIGITS - bin_idx - 1; + + counters[(bin_stride * bin_idx) + bin_offset] = bin_count; + } + } + + + /** + * Extract counts + */ + template + __device__ __forceinline__ void ExtractCounts( + OffsetT (&bin_count)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... 
(threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1] + { + unsigned int warp_id = threadIdx.x >> LOG_WARP_THREADS; + unsigned int warp_tid = LaneId(); + + // Place unpacked digit counters in shared memory + #pragma unroll + for (int LANE = 0; LANE < LANES_PER_WARP; LANE++) + { + int counter_lane = (LANE * WARPS) + warp_id; + if (counter_lane < COUNTER_LANES) + { + int digit_row = counter_lane << LOG_PACKING_RATIO; + + #pragma unroll + for (int UNPACKED_COUNTER = 0; UNPACKED_COUNTER < PACKING_RATIO; UNPACKED_COUNTER++) + { + int bin_idx = digit_row + UNPACKED_COUNTER; + + temp_storage.block_counters[warp_tid][bin_idx] = + local_counts[LANE][UNPACKED_COUNTER]; + } + } + } + + CTA_SYNC(); + + // Rake-reduce bin_count reductions + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; + + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + { + bin_count[track] = 0; + + #pragma unroll + for (int i = 0; i < WARP_THREADS; ++i) + bin_count[track] += temp_storage.block_counters[i][bin_idx]; + } + } + } + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/agent/agent_reduce.cuh b/fastertransformer/cuda/cub/agent/agent_reduce.cuh new file mode 100644 index 000000000..000a905cc --- /dev/null +++ b/fastertransformer/cuda/cub/agent/agent_reduce.cuh @@ -0,0 +1,385 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::AgentReduce implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduction . 
+ */ + +#pragma once + +#include + +#include "../block/block_load.cuh" +#include "../block/block_reduce.cuh" +#include "../grid/grid_mapping.cuh" +#include "../grid/grid_even_share.cuh" +#include "../util_type.cuh" +#include "../iterator/cache_modified_input_iterator.cuh" +#include "../util_namespace.cuh" + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Tuning policy types + ******************************************************************************/ + +/** + * Parameterizable tuning policy type for AgentReduce + */ +template < + int _BLOCK_THREADS, ///< Threads per thread block + int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + int _VECTOR_LOAD_LENGTH, ///< Number of items per vectorized load + BlockReduceAlgorithm _BLOCK_ALGORITHM, ///< Cooperative block-wide reduction algorithm to use + CacheLoadModifier _LOAD_MODIFIER> ///< Cache load modifier for reading input elements +struct AgentReducePolicy +{ + enum + { + BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block + ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + VECTOR_LOAD_LENGTH = _VECTOR_LOAD_LENGTH, ///< Number of items per vectorized load + }; + + static const BlockReduceAlgorithm BLOCK_ALGORITHM = _BLOCK_ALGORITHM; ///< Cooperative block-wide reduction algorithm to use + static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements +}; + + + +/****************************************************************************** + * Thread block abstractions + ******************************************************************************/ + +/** + * \brief AgentReduce implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduction . + * + * Each thread reduces only the values it loads. If \p FIRST_TILE, this + * partial reduction is stored into \p thread_aggregate. Otherwise it is + * accumulated into \p thread_aggregate. + */ +template < + typename AgentReducePolicy, ///< Parameterized AgentReducePolicy tuning policy type + typename InputIteratorT, ///< Random-access iterator type for input + typename OutputIteratorT, ///< Random-access iterator type for output + typename OffsetT, ///< Signed integer type for global offsets + typename ReductionOp> ///< Binary reduction operator type having member T operator()(const T &a, const T &b) +struct AgentReduce +{ + + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// The input value type + typedef typename std::iterator_traits::value_type InputT; + + /// The output value type + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type OutputT; // ... 
else the output iterator's value type + + /// Vector type of InputT for data movement + typedef typename CubVector::Type VectorT; + + /// Input iterator wrapper type (for applying cache modifier) + typedef typename If::VALUE, + CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedInputIterator + InputIteratorT>::Type // Directly use the supplied input iterator type + WrappedInputIteratorT; + + /// Constants + enum + { + BLOCK_THREADS = AgentReducePolicy::BLOCK_THREADS, + ITEMS_PER_THREAD = AgentReducePolicy::ITEMS_PER_THREAD, + VECTOR_LOAD_LENGTH = CUB_MIN(ITEMS_PER_THREAD, AgentReducePolicy::VECTOR_LOAD_LENGTH), + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + + // Can vectorize according to the policy if the input iterator is a native pointer to a primitive type + ATTEMPT_VECTORIZATION = (VECTOR_LOAD_LENGTH > 1) && + (ITEMS_PER_THREAD % VECTOR_LOAD_LENGTH == 0) && + (IsPointer::VALUE) && Traits::PRIMITIVE, + + }; + + static const CacheLoadModifier LOAD_MODIFIER = AgentReducePolicy::LOAD_MODIFIER; + static const BlockReduceAlgorithm BLOCK_ALGORITHM = AgentReducePolicy::BLOCK_ALGORITHM; + + /// Parameterized BlockReduce primitive + typedef BlockReduce BlockReduceT; + + /// Shared memory type required by this thread block + struct _TempStorage + { + typename BlockReduceT::TempStorage reduce; + }; + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + _TempStorage& temp_storage; ///< Reference to temp_storage + InputIteratorT d_in; ///< Input data to reduce + WrappedInputIteratorT d_wrapped_in; ///< Wrapped input data to reduce + ReductionOp reduction_op; ///< Binary reduction operator + + + //--------------------------------------------------------------------- + // Utility + //--------------------------------------------------------------------- + + + // Whether or not the input is aligned with the vector type (specialized for types we can vectorize) + template + static __device__ __forceinline__ bool IsAligned( + Iterator d_in, + Int2Type /*can_vectorize*/) + { + return (size_t(d_in) & (sizeof(VectorT) - 1)) == 0; + } + + // Whether or not the input is aligned with the vector type (specialized for types we cannot vectorize) + template + static __device__ __forceinline__ bool IsAligned( + Iterator /*d_in*/, + Int2Type /*can_vectorize*/) + { + return false; + } + + + //--------------------------------------------------------------------- + // Constructor + //--------------------------------------------------------------------- + + /** + * Constructor + */ + __device__ __forceinline__ AgentReduce( + TempStorage& temp_storage, ///< Reference to temp_storage + InputIteratorT d_in, ///< Input data to reduce + ReductionOp reduction_op) ///< Binary reduction operator + : + temp_storage(temp_storage.Alias()), + d_in(d_in), + d_wrapped_in(d_in), + reduction_op(reduction_op) + {} + + + //--------------------------------------------------------------------- + // Tile consumption + //--------------------------------------------------------------------- + + /** + * Consume a full tile of input (non-vectorized) + */ + template + __device__ __forceinline__ void ConsumeTile( + OutputT &thread_aggregate, + OffsetT block_offset, ///< The offset the tile to consume + int /*valid_items*/, ///< The number of valid items in the tile + Int2Type 
/*is_full_tile*/, ///< Whether or not this is a full tile + Int2Type /*can_vectorize*/) ///< Whether or not we can vectorize loads + { + OutputT items[ITEMS_PER_THREAD]; + + // Load items in striped fashion + LoadDirectStriped(threadIdx.x, d_wrapped_in + block_offset, items); + + // Reduce items within each thread stripe + thread_aggregate = (IS_FIRST_TILE) ? + internal::ThreadReduce(items, reduction_op) : + internal::ThreadReduce(items, reduction_op, thread_aggregate); + } + + + /** + * Consume a full tile of input (vectorized) + */ + template + __device__ __forceinline__ void ConsumeTile( + OutputT &thread_aggregate, + OffsetT block_offset, ///< The offset the tile to consume + int /*valid_items*/, ///< The number of valid items in the tile + Int2Type /*is_full_tile*/, ///< Whether or not this is a full tile + Int2Type /*can_vectorize*/) ///< Whether or not we can vectorize loads + { + // Alias items as an array of VectorT and load it in striped fashion + enum { WORDS = ITEMS_PER_THREAD / VECTOR_LOAD_LENGTH }; + + // Fabricate a vectorized input iterator + InputT *d_in_unqualified = const_cast(d_in) + block_offset + (threadIdx.x * VECTOR_LOAD_LENGTH); + CacheModifiedInputIterator d_vec_in( + reinterpret_cast(d_in_unqualified)); + + // Load items as vector items + InputT input_items[ITEMS_PER_THREAD]; + VectorT *vec_items = reinterpret_cast(input_items); + #pragma unroll + for (int i = 0; i < WORDS; ++i) + vec_items[i] = d_vec_in[BLOCK_THREADS * i]; + + // Convert from input type to output type + OutputT items[ITEMS_PER_THREAD]; + #pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) + items[i] = input_items[i]; + + // Reduce items within each thread stripe + thread_aggregate = (IS_FIRST_TILE) ? + internal::ThreadReduce(items, reduction_op) : + internal::ThreadReduce(items, reduction_op, thread_aggregate); + } + + + /** + * Consume a partial tile of input + */ + template + __device__ __forceinline__ void ConsumeTile( + OutputT &thread_aggregate, + OffsetT block_offset, ///< The offset the tile to consume + int valid_items, ///< The number of valid items in the tile + Int2Type /*is_full_tile*/, ///< Whether or not this is a full tile + Int2Type /*can_vectorize*/) ///< Whether or not we can vectorize loads + { + // Partial tile + int thread_offset = threadIdx.x; + + // Read first item + if ((IS_FIRST_TILE) && (thread_offset < valid_items)) + { + thread_aggregate = d_wrapped_in[block_offset + thread_offset]; + thread_offset += BLOCK_THREADS; + } + + // Continue reading items (block-striped) + while (thread_offset < valid_items) + { + OutputT item = d_wrapped_in[block_offset + thread_offset]; + thread_aggregate = reduction_op(thread_aggregate, item); + thread_offset += BLOCK_THREADS; + } + } + + + //--------------------------------------------------------------- + // Consume a contiguous segment of tiles + //--------------------------------------------------------------------- + + /** + * \brief Reduce a contiguous segment of input tiles + */ + template + __device__ __forceinline__ OutputT ConsumeRange( + GridEvenShare &even_share, ///< GridEvenShare descriptor + Int2Type can_vectorize) ///< Whether or not we can vectorize loads + { + OutputT thread_aggregate; + + if (even_share.block_offset + TILE_ITEMS > even_share.block_end) + { + // First tile isn't full (not all threads have valid items) + int valid_items = even_share.block_end - even_share.block_offset; + ConsumeTile(thread_aggregate, even_share.block_offset, valid_items, Int2Type(), can_vectorize); + return 
BlockReduceT(temp_storage.reduce).Reduce(thread_aggregate, reduction_op, valid_items); + } + + // At least one full block + ConsumeTile(thread_aggregate, even_share.block_offset, TILE_ITEMS, Int2Type(), can_vectorize); + even_share.block_offset += even_share.block_stride; + + // Consume subsequent full tiles of input + while (even_share.block_offset + TILE_ITEMS <= even_share.block_end) + { + ConsumeTile(thread_aggregate, even_share.block_offset, TILE_ITEMS, Int2Type(), can_vectorize); + even_share.block_offset += even_share.block_stride; + } + + // Consume a partially-full tile + if (even_share.block_offset < even_share.block_end) + { + int valid_items = even_share.block_end - even_share.block_offset; + ConsumeTile(thread_aggregate, even_share.block_offset, valid_items, Int2Type(), can_vectorize); + } + + // Compute block-wide reduction (all threads have valid items) + return BlockReduceT(temp_storage.reduce).Reduce(thread_aggregate, reduction_op); + } + + + /** + * \brief Reduce a contiguous segment of input tiles + */ + __device__ __forceinline__ OutputT ConsumeRange( + OffsetT block_offset, ///< [in] Threadblock begin offset (inclusive) + OffsetT block_end) ///< [in] Threadblock end offset (exclusive) + { + GridEvenShare even_share; + even_share.template BlockInit(block_offset, block_end); + + return (IsAligned(d_in + block_offset, Int2Type())) ? + ConsumeRange(even_share, Int2Type()) : + ConsumeRange(even_share, Int2Type()); + } + + + /** + * Reduce a contiguous segment of input tiles + */ + __device__ __forceinline__ OutputT ConsumeTiles( + GridEvenShare &even_share) ///< [in] GridEvenShare descriptor + { + // Initialize GRID_MAPPING_STRIP_MINE even-share descriptor for this thread block + even_share.template BlockInit(); + + return (IsAligned(d_in, Int2Type())) ? + ConsumeRange(even_share, Int2Type()) : + ConsumeRange(even_share, Int2Type()); + + } + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/agent/agent_reduce_by_key.cuh b/fastertransformer/cuda/cub/agent/agent_reduce_by_key.cuh new file mode 100644 index 000000000..51964d3e6 --- /dev/null +++ b/fastertransformer/cuda/cub/agent/agent_reduce_by_key.cuh @@ -0,0 +1,547 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::AgentReduceByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key. + */ + +#pragma once + +#include + +#include "single_pass_scan_operators.cuh" +#include "../block/block_load.cuh" +#include "../block/block_store.cuh" +#include "../block/block_scan.cuh" +#include "../block/block_discontinuity.cuh" +#include "../iterator/cache_modified_input_iterator.cuh" +#include "../iterator/constant_input_iterator.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Tuning policy types + ******************************************************************************/ + +/** + * Parameterizable tuning policy type for AgentReduceByKey + */ +template < + int _BLOCK_THREADS, ///< Threads per thread block + int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use + CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements + BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use +struct AgentReduceByKeyPolicy +{ + enum + { + BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block + ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + }; + + static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use + static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements + static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use +}; + + +/****************************************************************************** + * Thread block abstractions + ******************************************************************************/ + +/** + * \brief AgentReduceByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key + */ +template < + typename AgentReduceByKeyPolicyT, ///< Parameterized AgentReduceByKeyPolicy tuning policy type + typename KeysInputIteratorT, ///< Random-access input iterator type for keys + typename UniqueOutputIteratorT, ///< Random-access output iterator type for keys + typename ValuesInputIteratorT, ///< Random-access input iterator type for values + typename AggregatesOutputIteratorT, ///< Random-access output iterator type for values + typename NumRunsOutputIteratorT, ///< Output iterator type for recording number of items selected + typename EqualityOpT, ///< KeyT equality operator type + typename ReductionOpT, ///< ValueT reduction operator type + typename OffsetT> ///< Signed integer type for global offsets +struct AgentReduceByKey +{ + //--------------------------------------------------------------------- 
+ // Types and constants + //--------------------------------------------------------------------- + + // The input keys type + typedef typename std::iterator_traits::value_type KeyInputT; + + // The output keys type + typedef typename If<(Equals::value_type, void>::VALUE), // KeyOutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type KeyOutputT; // ... else the output iterator's value type + + // The input values type + typedef typename std::iterator_traits::value_type ValueInputT; + + // The output values type + typedef typename If<(Equals::value_type, void>::VALUE), // ValueOutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type ValueOutputT; // ... else the output iterator's value type + + // Tuple type for scanning (pairs accumulated segment-value with segment-index) + typedef KeyValuePair OffsetValuePairT; + + // Tuple type for pairing keys and values + typedef KeyValuePair KeyValuePairT; + + // Tile status descriptor interface type + typedef ReduceByKeyScanTileState ScanTileStateT; + + // Guarded inequality functor + template + struct GuardedInequalityWrapper + { + _EqualityOpT op; ///< Wrapped equality operator + int num_remaining; ///< Items remaining + + /// Constructor + __host__ __device__ __forceinline__ + GuardedInequalityWrapper(_EqualityOpT op, int num_remaining) : op(op), num_remaining(num_remaining) {} + + /// Boolean inequality operator, returns (a != b) + template + __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b, int idx) const + { + if (idx < num_remaining) + return !op(a, b); // In bounds + + // Return true if first out-of-bounds item, false otherwise + return (idx == num_remaining); + } + }; + + + // Constants + enum + { + BLOCK_THREADS = AgentReduceByKeyPolicyT::BLOCK_THREADS, + ITEMS_PER_THREAD = AgentReduceByKeyPolicyT::ITEMS_PER_THREAD, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + TWO_PHASE_SCATTER = (ITEMS_PER_THREAD > 1), + + // Whether or not the scan operation has a zero-valued identity value (true if we're performing addition on a primitive type) + HAS_IDENTITY_ZERO = (Equals::VALUE) && (Traits::PRIMITIVE), + }; + + // Cache-modified Input iterator wrapper type (for applying cache modifier) for keys + typedef typename If::VALUE, + CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedValuesInputIterator + KeysInputIteratorT>::Type // Directly use the supplied input iterator type + WrappedKeysInputIteratorT; + + // Cache-modified Input iterator wrapper type (for applying cache modifier) for values + typedef typename If::VALUE, + CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedValuesInputIterator + ValuesInputIteratorT>::Type // Directly use the supplied input iterator type + WrappedValuesInputIteratorT; + + // Cache-modified Input iterator wrapper type (for applying cache modifier) for fixup values + typedef typename If::VALUE, + CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedValuesInputIterator + AggregatesOutputIteratorT>::Type // Directly use the supplied input iterator type + WrappedFixupInputIteratorT; + + // Reduce-value-by-segment scan operator + typedef ReduceBySegmentOp ReduceBySegmentOpT; + + // Parameterized BlockLoad type for keys + typedef BlockLoad< + 
KeyOutputT, + BLOCK_THREADS, + ITEMS_PER_THREAD, + AgentReduceByKeyPolicyT::LOAD_ALGORITHM> + BlockLoadKeysT; + + // Parameterized BlockLoad type for values + typedef BlockLoad< + ValueOutputT, + BLOCK_THREADS, + ITEMS_PER_THREAD, + AgentReduceByKeyPolicyT::LOAD_ALGORITHM> + BlockLoadValuesT; + + // Parameterized BlockDiscontinuity type for keys + typedef BlockDiscontinuity< + KeyOutputT, + BLOCK_THREADS> + BlockDiscontinuityKeys; + + // Parameterized BlockScan type + typedef BlockScan< + OffsetValuePairT, + BLOCK_THREADS, + AgentReduceByKeyPolicyT::SCAN_ALGORITHM> + BlockScanT; + + // Callback type for obtaining tile prefix during block scan + typedef TilePrefixCallbackOp< + OffsetValuePairT, + ReduceBySegmentOpT, + ScanTileStateT> + TilePrefixCallbackOpT; + + // Key and value exchange types + typedef KeyOutputT KeyExchangeT[TILE_ITEMS + 1]; + typedef ValueOutputT ValueExchangeT[TILE_ITEMS + 1]; + + // Shared memory type for this thread block + union _TempStorage + { + struct + { + typename BlockScanT::TempStorage scan; // Smem needed for tile scanning + typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback + typename BlockDiscontinuityKeys::TempStorage discontinuity; // Smem needed for discontinuity detection + }; + + // Smem needed for loading keys + typename BlockLoadKeysT::TempStorage load_keys; + + // Smem needed for loading values + typename BlockLoadValuesT::TempStorage load_values; + + // Smem needed for compacting key value pairs(allows non POD items in this union) + Uninitialized raw_exchange; + }; + + // Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + _TempStorage& temp_storage; ///< Reference to temp_storage + WrappedKeysInputIteratorT d_keys_in; ///< Input keys + UniqueOutputIteratorT d_unique_out; ///< Unique output keys + WrappedValuesInputIteratorT d_values_in; ///< Input values + AggregatesOutputIteratorT d_aggregates_out; ///< Output value aggregates + NumRunsOutputIteratorT d_num_runs_out; ///< Output pointer for total number of segments identified + EqualityOpT equality_op; ///< KeyT equality operator + ReductionOpT reduction_op; ///< Reduction operator + ReduceBySegmentOpT scan_op; ///< Reduce-by-segment scan operator + + + //--------------------------------------------------------------------- + // Constructor + //--------------------------------------------------------------------- + + // Constructor + __device__ __forceinline__ + AgentReduceByKey( + TempStorage& temp_storage, ///< Reference to temp_storage + KeysInputIteratorT d_keys_in, ///< Input keys + UniqueOutputIteratorT d_unique_out, ///< Unique output keys + ValuesInputIteratorT d_values_in, ///< Input values + AggregatesOutputIteratorT d_aggregates_out, ///< Output value aggregates + NumRunsOutputIteratorT d_num_runs_out, ///< Output pointer for total number of segments identified + EqualityOpT equality_op, ///< KeyT equality operator + ReductionOpT reduction_op) ///< ValueT reduction operator + : + temp_storage(temp_storage.Alias()), + d_keys_in(d_keys_in), + d_unique_out(d_unique_out), + d_values_in(d_values_in), + d_aggregates_out(d_aggregates_out), + d_num_runs_out(d_num_runs_out), + equality_op(equality_op), + reduction_op(reduction_op), + scan_op(reduction_op) + {} + + + 
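+    // Note on the tile-processing flow implemented below: ConsumeRange() assigns
+    // one tile per thread block.  For each tile, ConsumeTile() loads the keys and
+    // flags segment heads by comparing each key against its predecessor, exclusive-
+    // scans the zipped (head-flag, value) pairs with ReduceBySegmentOpT to produce
+    // per-segment aggregates and global segment indices, and Scatter() then writes
+    // the previous segment's key and reduced value at every flagged head to
+    // d_unique_out / d_aggregates_out.  The last thread of the last tile records
+    // the total number of segments in d_num_runs_out.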
//--------------------------------------------------------------------- + // Scatter utility methods + //--------------------------------------------------------------------- + + /** + * Directly scatter flagged items to output offsets + */ + __device__ __forceinline__ void ScatterDirect( + KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], + OffsetT (&segment_flags)[ITEMS_PER_THREAD], + OffsetT (&segment_indices)[ITEMS_PER_THREAD]) + { + // Scatter flagged keys and values + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + if (segment_flags[ITEM]) + { + d_unique_out[segment_indices[ITEM]] = scatter_items[ITEM].key; + d_aggregates_out[segment_indices[ITEM]] = scatter_items[ITEM].value; + } + } + } + + + /** + * 2-phase scatter flagged items to output offsets + * + * The exclusive scan causes each head flag to be paired with the previous + * value aggregate: the scatter offsets must be decremented for value aggregates + */ + __device__ __forceinline__ void ScatterTwoPhase( + KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], + OffsetT (&segment_flags)[ITEMS_PER_THREAD], + OffsetT (&segment_indices)[ITEMS_PER_THREAD], + OffsetT num_tile_segments, + OffsetT num_tile_segments_prefix) + { + CTA_SYNC(); + + // Compact and scatter pairs + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + if (segment_flags[ITEM]) + { + temp_storage.raw_exchange.Alias()[segment_indices[ITEM] - num_tile_segments_prefix] = scatter_items[ITEM]; + } + } + + CTA_SYNC(); + + for (int item = threadIdx.x; item < num_tile_segments; item += BLOCK_THREADS) + { + KeyValuePairT pair = temp_storage.raw_exchange.Alias()[item]; + d_unique_out[num_tile_segments_prefix + item] = pair.key; + d_aggregates_out[num_tile_segments_prefix + item] = pair.value; + } + } + + + /** + * Scatter flagged items + */ + __device__ __forceinline__ void Scatter( + KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], + OffsetT (&segment_flags)[ITEMS_PER_THREAD], + OffsetT (&segment_indices)[ITEMS_PER_THREAD], + OffsetT num_tile_segments, + OffsetT num_tile_segments_prefix) + { + // Do a one-phase scatter if (a) two-phase is disabled or (b) the average number of selected items per thread is less than one + if (TWO_PHASE_SCATTER && (num_tile_segments > BLOCK_THREADS)) + { + ScatterTwoPhase( + scatter_items, + segment_flags, + segment_indices, + num_tile_segments, + num_tile_segments_prefix); + } + else + { + ScatterDirect( + scatter_items, + segment_flags, + segment_indices); + } + } + + + //--------------------------------------------------------------------- + // Cooperatively scan a device-wide sequence of tiles with other CTAs + //--------------------------------------------------------------------- + + /** + * Process a tile of input (dynamic chained scan) + */ + template ///< Whether the current tile is the last tile + __device__ __forceinline__ void ConsumeTile( + OffsetT num_remaining, ///< Number of global input items remaining (including this tile) + int tile_idx, ///< Tile index + OffsetT tile_offset, ///< Tile offset + ScanTileStateT& tile_state) ///< Global tile state descriptor + { + KeyOutputT keys[ITEMS_PER_THREAD]; // Tile keys + KeyOutputT prev_keys[ITEMS_PER_THREAD]; // Tile keys shuffled up + ValueOutputT values[ITEMS_PER_THREAD]; // Tile values + OffsetT head_flags[ITEMS_PER_THREAD]; // Segment head flags + OffsetT segment_indices[ITEMS_PER_THREAD]; // Segment indices + OffsetValuePairT scan_items[ITEMS_PER_THREAD]; // Zipped values and segment flags|indices + KeyValuePairT 
scatter_items[ITEMS_PER_THREAD]; // Zipped key value pairs for scattering + + // Load keys + if (IS_LAST_TILE) + BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys, num_remaining); + else + BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys); + + // Load tile predecessor key in first thread + KeyOutputT tile_predecessor; + if (threadIdx.x == 0) + { + tile_predecessor = (tile_idx == 0) ? + keys[0] : // First tile gets repeat of first item (thus first item will not be flagged as a head) + d_keys_in[tile_offset - 1]; // Subsequent tiles get last key from previous tile + } + + CTA_SYNC(); + + // Load values + if (IS_LAST_TILE) + BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values, num_remaining); + else + BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values); + + CTA_SYNC(); + + // Initialize head-flags and shuffle up the previous keys + if (IS_LAST_TILE) + { + // Use custom flag operator to additionally flag the first out-of-bounds item + GuardedInequalityWrapper flag_op(equality_op, num_remaining); + BlockDiscontinuityKeys(temp_storage.discontinuity).FlagHeads( + head_flags, keys, prev_keys, flag_op, tile_predecessor); + } + else + { + InequalityWrapper flag_op(equality_op); + BlockDiscontinuityKeys(temp_storage.discontinuity).FlagHeads( + head_flags, keys, prev_keys, flag_op, tile_predecessor); + } + + // Zip values and head flags + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + scan_items[ITEM].value = values[ITEM]; + scan_items[ITEM].key = head_flags[ITEM]; + } + + // Perform exclusive tile scan + OffsetValuePairT block_aggregate; // Inclusive block-wide scan aggregate + OffsetT num_segments_prefix; // Number of segments prior to this tile + OffsetValuePairT total_aggregate; // The tile prefix folded with block_aggregate + if (tile_idx == 0) + { + // Scan first tile + BlockScanT(temp_storage.scan).ExclusiveScan(scan_items, scan_items, scan_op, block_aggregate); + num_segments_prefix = 0; + total_aggregate = block_aggregate; + + // Update tile status if there are successor tiles + if ((!IS_LAST_TILE) && (threadIdx.x == 0)) + tile_state.SetInclusive(0, block_aggregate); + } + else + { + // Scan non-first tile + TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx); + BlockScanT(temp_storage.scan).ExclusiveScan(scan_items, scan_items, scan_op, prefix_op); + + block_aggregate = prefix_op.GetBlockAggregate(); + num_segments_prefix = prefix_op.GetExclusivePrefix().key; + total_aggregate = prefix_op.GetInclusivePrefix(); + } + + // Rezip scatter items and segment indices + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + scatter_items[ITEM].key = prev_keys[ITEM]; + scatter_items[ITEM].value = scan_items[ITEM].value; + segment_indices[ITEM] = scan_items[ITEM].key; + } + + // At this point, each flagged segment head has: + // - The key for the previous segment + // - The reduced value from the previous segment + // - The segment index for the reduced value + + // Scatter flagged keys and values + OffsetT num_tile_segments = block_aggregate.key; + Scatter(scatter_items, head_flags, segment_indices, num_tile_segments, num_segments_prefix); + + // Last thread in last tile will output final count (and last pair, if necessary) + if ((IS_LAST_TILE) && (threadIdx.x == BLOCK_THREADS - 1)) + { + OffsetT num_segments = num_segments_prefix + num_tile_segments; + + // If the last tile is a whole tile, output the final_value 
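+            // (when the last tile is exactly full there is no out-of-bounds key to
+            // close the final run, so its key and aggregate must be emitted here)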
+ if (num_remaining == TILE_ITEMS) + { + d_unique_out[num_segments] = keys[ITEMS_PER_THREAD - 1]; + d_aggregates_out[num_segments] = total_aggregate.value; + num_segments++; + } + + // Output the total number of items selected + *d_num_runs_out = num_segments; + } + } + + + /** + * Scan tiles of items as part of a dynamic chained scan + */ + __device__ __forceinline__ void ConsumeRange( + int num_items, ///< Total number of input items + ScanTileStateT& tile_state, ///< Global tile state descriptor + int start_tile) ///< The starting tile for the current grid + { + // Blocks are launched in increasing order, so just assign one tile per block + int tile_idx = start_tile + blockIdx.x; // Current tile index + OffsetT tile_offset = OffsetT(TILE_ITEMS) * tile_idx; // Global offset for the current tile + OffsetT num_remaining = num_items - tile_offset; // Remaining items (including this tile) + + if (num_remaining > TILE_ITEMS) + { + // Not last tile + ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); + } + else if (num_remaining > 0) + { + // Last tile + ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); + } + } + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/agent/agent_rle.cuh b/fastertransformer/cuda/cub/agent/agent_rle.cuh new file mode 100644 index 000000000..cb7a4a652 --- /dev/null +++ b/fastertransformer/cuda/cub/agent/agent_rle.cuh @@ -0,0 +1,837 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::AgentRle implements a stateful abstraction of CUDA thread blocks for participating in device-wide run-length-encode. 
+ */ + +#pragma once + +#include + +#include "single_pass_scan_operators.cuh" +#include "../block/block_load.cuh" +#include "../block/block_store.cuh" +#include "../block/block_scan.cuh" +#include "../block/block_exchange.cuh" +#include "../block/block_discontinuity.cuh" +#include "../grid/grid_queue.cuh" +#include "../iterator/cache_modified_input_iterator.cuh" +#include "../iterator/constant_input_iterator.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Tuning policy types + ******************************************************************************/ + +/** + * Parameterizable tuning policy type for AgentRle + */ +template < + int _BLOCK_THREADS, ///< Threads per thread block + int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use + CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements + bool _STORE_WARP_TIME_SLICING, ///< Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage) + BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use +struct AgentRlePolicy +{ + enum + { + BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block + ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + STORE_WARP_TIME_SLICING = _STORE_WARP_TIME_SLICING, ///< Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage) + }; + + static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use + static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements + static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use +}; + + + + + +/****************************************************************************** + * Thread block abstractions + ******************************************************************************/ + +/** + * \brief AgentRle implements a stateful abstraction of CUDA thread blocks for participating in device-wide run-length-encode + */ +template < + typename AgentRlePolicyT, ///< Parameterized AgentRlePolicyT tuning policy type + typename InputIteratorT, ///< Random-access input iterator type for data + typename OffsetsOutputIteratorT, ///< Random-access output iterator type for offset values + typename LengthsOutputIteratorT, ///< Random-access output iterator type for length values + typename EqualityOpT, ///< T equality operator type + typename OffsetT> ///< Signed integer type for global offsets +struct AgentRle +{ + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// The input value type + typedef typename std::iterator_traits::value_type T; + + /// The lengths output value type + typedef typename If<(Equals::value_type, void>::VALUE), // LengthT = (if output iterator's value type is void) ? + OffsetT, // ... then the OffsetT type, + typename std::iterator_traits::value_type>::Type LengthT; // ... 
else the output iterator's value type + + /// Tuple type for scanning (pairs run-length and run-index) + typedef KeyValuePair LengthOffsetPair; + + /// Tile status descriptor interface type + typedef ReduceByKeyScanTileState ScanTileStateT; + + // Constants + enum + { + WARP_THREADS = CUB_WARP_THREADS(PTX_ARCH), + BLOCK_THREADS = AgentRlePolicyT::BLOCK_THREADS, + ITEMS_PER_THREAD = AgentRlePolicyT::ITEMS_PER_THREAD, + WARP_ITEMS = WARP_THREADS * ITEMS_PER_THREAD, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS, + + /// Whether or not to sync after loading data + SYNC_AFTER_LOAD = (AgentRlePolicyT::LOAD_ALGORITHM != BLOCK_LOAD_DIRECT), + + /// Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage) + STORE_WARP_TIME_SLICING = AgentRlePolicyT::STORE_WARP_TIME_SLICING, + ACTIVE_EXCHANGE_WARPS = (STORE_WARP_TIME_SLICING) ? 1 : WARPS, + }; + + + /** + * Special operator that signals all out-of-bounds items are not equal to everything else, + * forcing both (1) the last item to be tail-flagged and (2) all oob items to be marked + * trivial. + */ + template + struct OobInequalityOp + { + OffsetT num_remaining; + EqualityOpT equality_op; + + __device__ __forceinline__ OobInequalityOp( + OffsetT num_remaining, + EqualityOpT equality_op) + : + num_remaining(num_remaining), + equality_op(equality_op) + {} + + template + __host__ __device__ __forceinline__ bool operator()(T first, T second, Index idx) + { + if (!LAST_TILE || (idx < num_remaining)) + return !equality_op(first, second); + else + return true; + } + }; + + + // Cache-modified Input iterator wrapper type (for applying cache modifier) for data + typedef typename If::VALUE, + CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedVLengthnputIterator + InputIteratorT>::Type // Directly use the supplied input iterator type + WrappedInputIteratorT; + + // Parameterized BlockLoad type for data + typedef BlockLoad< + T, + AgentRlePolicyT::BLOCK_THREADS, + AgentRlePolicyT::ITEMS_PER_THREAD, + AgentRlePolicyT::LOAD_ALGORITHM> + BlockLoadT; + + // Parameterized BlockDiscontinuity type for data + typedef BlockDiscontinuity BlockDiscontinuityT; + + // Parameterized WarpScan type + typedef WarpScan WarpScanPairs; + + // Reduce-length-by-run scan operator + typedef ReduceBySegmentOp ReduceBySegmentOpT; + + // Callback type for obtaining tile prefix during block scan + typedef TilePrefixCallbackOp< + LengthOffsetPair, + ReduceBySegmentOpT, + ScanTileStateT> + TilePrefixCallbackOpT; + + // Warp exchange types + typedef WarpExchange WarpExchangePairs; + + typedef typename If::Type WarpExchangePairsStorage; + + typedef WarpExchange WarpExchangeOffsets; + typedef WarpExchange WarpExchangeLengths; + + typedef LengthOffsetPair WarpAggregates[WARPS]; + + // Shared memory type for this thread block + struct _TempStorage + { + // Aliasable storage layout + union Aliasable + { + struct + { + typename BlockDiscontinuityT::TempStorage discontinuity; // Smem needed for discontinuity detection + typename WarpScanPairs::TempStorage warp_scan[WARPS]; // Smem needed for warp-synchronous scans + Uninitialized warp_aggregates; // Smem needed for sharing warp-wide aggregates + typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback + }; + + // Smem needed for input loading + typename BlockLoadT::TempStorage 
load; + + // Aliasable layout needed for two-phase scatter + union ScatterAliasable + { + unsigned long long align; + WarpExchangePairsStorage exchange_pairs[ACTIVE_EXCHANGE_WARPS]; + typename WarpExchangeOffsets::TempStorage exchange_offsets[ACTIVE_EXCHANGE_WARPS]; + typename WarpExchangeLengths::TempStorage exchange_lengths[ACTIVE_EXCHANGE_WARPS]; + + } scatter_aliasable; + + } aliasable; + + OffsetT tile_idx; // Shared tile index + LengthOffsetPair tile_inclusive; // Inclusive tile prefix + LengthOffsetPair tile_exclusive; // Exclusive tile prefix + }; + + // Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + _TempStorage& temp_storage; ///< Reference to temp_storage + + WrappedInputIteratorT d_in; ///< Pointer to input sequence of data items + OffsetsOutputIteratorT d_offsets_out; ///< Input run offsets + LengthsOutputIteratorT d_lengths_out; ///< Output run lengths + + EqualityOpT equality_op; ///< T equality operator + ReduceBySegmentOpT scan_op; ///< Reduce-length-by-flag scan operator + OffsetT num_items; ///< Total number of input items + + + //--------------------------------------------------------------------- + // Constructor + //--------------------------------------------------------------------- + + // Constructor + __device__ __forceinline__ + AgentRle( + TempStorage &temp_storage, ///< [in] Reference to temp_storage + InputIteratorT d_in, ///< [in] Pointer to input sequence of data items + OffsetsOutputIteratorT d_offsets_out, ///< [out] Pointer to output sequence of run offsets + LengthsOutputIteratorT d_lengths_out, ///< [out] Pointer to output sequence of run lengths + EqualityOpT equality_op, ///< [in] T equality operator + OffsetT num_items) ///< [in] Total number of input items + : + temp_storage(temp_storage.Alias()), + d_in(d_in), + d_offsets_out(d_offsets_out), + d_lengths_out(d_lengths_out), + equality_op(equality_op), + scan_op(cub::Sum()), + num_items(num_items) + {} + + + //--------------------------------------------------------------------- + // Utility methods for initializing the selections + //--------------------------------------------------------------------- + + template + __device__ __forceinline__ void InitializeSelections( + OffsetT tile_offset, + OffsetT num_remaining, + T (&items)[ITEMS_PER_THREAD], + LengthOffsetPair (&lengths_and_num_runs)[ITEMS_PER_THREAD]) + { + bool head_flags[ITEMS_PER_THREAD]; + bool tail_flags[ITEMS_PER_THREAD]; + + OobInequalityOp inequality_op(num_remaining, equality_op); + + if (FIRST_TILE && LAST_TILE) + { + // First-and-last-tile always head-flags the first item and tail-flags the last item + + BlockDiscontinuityT(temp_storage.aliasable.discontinuity).FlagHeadsAndTails( + head_flags, tail_flags, items, inequality_op); + } + else if (FIRST_TILE) + { + // First-tile always head-flags the first item + + // Get the first item from the next tile + T tile_successor_item; + if (threadIdx.x == BLOCK_THREADS - 1) + tile_successor_item = d_in[tile_offset + TILE_ITEMS]; + + BlockDiscontinuityT(temp_storage.aliasable.discontinuity).FlagHeadsAndTails( + head_flags, tail_flags, tile_successor_item, items, inequality_op); + } + else if (LAST_TILE) + { + // Last-tile always flags the last item + + // Get the last item from the previous tile + T tile_predecessor_item; + if (threadIdx.x == 0) + 
tile_predecessor_item = d_in[tile_offset - 1]; + + BlockDiscontinuityT(temp_storage.aliasable.discontinuity).FlagHeadsAndTails( + head_flags, tile_predecessor_item, tail_flags, items, inequality_op); + } + else + { + // Get the first item from the next tile + T tile_successor_item; + if (threadIdx.x == BLOCK_THREADS - 1) + tile_successor_item = d_in[tile_offset + TILE_ITEMS]; + + // Get the last item from the previous tile + T tile_predecessor_item; + if (threadIdx.x == 0) + tile_predecessor_item = d_in[tile_offset - 1]; + + BlockDiscontinuityT(temp_storage.aliasable.discontinuity).FlagHeadsAndTails( + head_flags, tile_predecessor_item, tail_flags, tile_successor_item, items, inequality_op); + } + + // Zip counts and runs + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + lengths_and_num_runs[ITEM].key = head_flags[ITEM] && (!tail_flags[ITEM]); + lengths_and_num_runs[ITEM].value = ((!head_flags[ITEM]) || (!tail_flags[ITEM])); + } + } + + //--------------------------------------------------------------------- + // Scan utility methods + //--------------------------------------------------------------------- + + /** + * Scan of allocations + */ + __device__ __forceinline__ void WarpScanAllocations( + LengthOffsetPair &tile_aggregate, + LengthOffsetPair &warp_aggregate, + LengthOffsetPair &warp_exclusive_in_tile, + LengthOffsetPair &thread_exclusive_in_warp, + LengthOffsetPair (&lengths_and_num_runs)[ITEMS_PER_THREAD]) + { + // Perform warpscans + unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS); + int lane_id = LaneId(); + + LengthOffsetPair identity; + identity.key = 0; + identity.value = 0; + + LengthOffsetPair thread_inclusive; + LengthOffsetPair thread_aggregate = internal::ThreadReduce(lengths_and_num_runs, scan_op); + WarpScanPairs(temp_storage.aliasable.warp_scan[warp_id]).Scan( + thread_aggregate, + thread_inclusive, + thread_exclusive_in_warp, + identity, + scan_op); + + // Last lane in each warp shares its warp-aggregate + if (lane_id == WARP_THREADS - 1) + temp_storage.aliasable.warp_aggregates.Alias()[warp_id] = thread_inclusive; + + CTA_SYNC(); + + // Accumulate total selected and the warp-wide prefix + warp_exclusive_in_tile = identity; + warp_aggregate = temp_storage.aliasable.warp_aggregates.Alias()[warp_id]; + tile_aggregate = temp_storage.aliasable.warp_aggregates.Alias()[0]; + + #pragma unroll + for (int WARP = 1; WARP < WARPS; ++WARP) + { + if (warp_id == WARP) + warp_exclusive_in_tile = tile_aggregate; + + tile_aggregate = scan_op(tile_aggregate, temp_storage.aliasable.warp_aggregates.Alias()[WARP]); + } + } + + + //--------------------------------------------------------------------- + // Utility methods for scattering selections + //--------------------------------------------------------------------- + + /** + * Two-phase scatter, specialized for warp time-slicing + */ + template + __device__ __forceinline__ void ScatterTwoPhase( + OffsetT tile_num_runs_exclusive_in_global, + OffsetT warp_num_runs_aggregate, + OffsetT warp_num_runs_exclusive_in_tile, + OffsetT (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD], + LengthOffsetPair (&lengths_and_offsets)[ITEMS_PER_THREAD], + Int2Type is_warp_time_slice) + { + unsigned int warp_id = ((WARPS == 1) ? 
0 : threadIdx.x / WARP_THREADS); + int lane_id = LaneId(); + + // Locally compact items within the warp (first warp) + if (warp_id == 0) + { + WarpExchangePairs(temp_storage.aliasable.scatter_aliasable.exchange_pairs[0]).ScatterToStriped( + lengths_and_offsets, thread_num_runs_exclusive_in_warp); + } + + // Locally compact items within the warp (remaining warps) + #pragma unroll + for (int SLICE = 1; SLICE < WARPS; ++SLICE) + { + CTA_SYNC(); + + if (warp_id == SLICE) + { + WarpExchangePairs(temp_storage.aliasable.scatter_aliasable.exchange_pairs[0]).ScatterToStriped( + lengths_and_offsets, thread_num_runs_exclusive_in_warp); + } + } + + // Global scatter + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + if ((ITEM * WARP_THREADS) < warp_num_runs_aggregate - lane_id) + { + OffsetT item_offset = + tile_num_runs_exclusive_in_global + + warp_num_runs_exclusive_in_tile + + (ITEM * WARP_THREADS) + lane_id; + + // Scatter offset + d_offsets_out[item_offset] = lengths_and_offsets[ITEM].key; + + // Scatter length if not the first (global) length + if ((!FIRST_TILE) || (ITEM != 0) || (threadIdx.x > 0)) + { + d_lengths_out[item_offset - 1] = lengths_and_offsets[ITEM].value; + } + } + } + } + + + /** + * Two-phase scatter + */ + template + __device__ __forceinline__ void ScatterTwoPhase( + OffsetT tile_num_runs_exclusive_in_global, + OffsetT warp_num_runs_aggregate, + OffsetT warp_num_runs_exclusive_in_tile, + OffsetT (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD], + LengthOffsetPair (&lengths_and_offsets)[ITEMS_PER_THREAD], + Int2Type is_warp_time_slice) + { + unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS); + int lane_id = LaneId(); + + // Unzip + OffsetT run_offsets[ITEMS_PER_THREAD]; + LengthT run_lengths[ITEMS_PER_THREAD]; + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + run_offsets[ITEM] = lengths_and_offsets[ITEM].key; + run_lengths[ITEM] = lengths_and_offsets[ITEM].value; + } + + WarpExchangeOffsets(temp_storage.aliasable.scatter_aliasable.exchange_offsets[warp_id]).ScatterToStriped( + run_offsets, thread_num_runs_exclusive_in_warp); + + WARP_SYNC(0xffffffff); + + WarpExchangeLengths(temp_storage.aliasable.scatter_aliasable.exchange_lengths[warp_id]).ScatterToStriped( + run_lengths, thread_num_runs_exclusive_in_warp); + + // Global scatter + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + if ((ITEM * WARP_THREADS) + lane_id < warp_num_runs_aggregate) + { + OffsetT item_offset = + tile_num_runs_exclusive_in_global + + warp_num_runs_exclusive_in_tile + + (ITEM * WARP_THREADS) + lane_id; + + // Scatter offset + d_offsets_out[item_offset] = run_offsets[ITEM]; + + // Scatter length if not the first (global) length + if ((!FIRST_TILE) || (ITEM != 0) || (threadIdx.x > 0)) + { + d_lengths_out[item_offset - 1] = run_lengths[ITEM]; + } + } + } + } + + + /** + * Direct scatter + */ + template + __device__ __forceinline__ void ScatterDirect( + OffsetT tile_num_runs_exclusive_in_global, + OffsetT warp_num_runs_aggregate, + OffsetT warp_num_runs_exclusive_in_tile, + OffsetT (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD], + LengthOffsetPair (&lengths_and_offsets)[ITEMS_PER_THREAD]) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + if (thread_num_runs_exclusive_in_warp[ITEM] < warp_num_runs_aggregate) + { + OffsetT item_offset = + tile_num_runs_exclusive_in_global + + warp_num_runs_exclusive_in_tile + + thread_num_runs_exclusive_in_warp[ITEM]; + + // Scatter 
offset + d_offsets_out[item_offset] = lengths_and_offsets[ITEM].key; + + // Scatter length if not the first (global) length + if (item_offset >= 1) + { + d_lengths_out[item_offset - 1] = lengths_and_offsets[ITEM].value; + } + } + } + } + + + /** + * Scatter + */ + template + __device__ __forceinline__ void Scatter( + OffsetT tile_num_runs_aggregate, + OffsetT tile_num_runs_exclusive_in_global, + OffsetT warp_num_runs_aggregate, + OffsetT warp_num_runs_exclusive_in_tile, + OffsetT (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD], + LengthOffsetPair (&lengths_and_offsets)[ITEMS_PER_THREAD]) + { + if ((ITEMS_PER_THREAD == 1) || (tile_num_runs_aggregate < BLOCK_THREADS)) + { + // Direct scatter if the warp has any items + if (warp_num_runs_aggregate) + { + ScatterDirect( + tile_num_runs_exclusive_in_global, + warp_num_runs_aggregate, + warp_num_runs_exclusive_in_tile, + thread_num_runs_exclusive_in_warp, + lengths_and_offsets); + } + } + else + { + // Scatter two phase + ScatterTwoPhase( + tile_num_runs_exclusive_in_global, + warp_num_runs_aggregate, + warp_num_runs_exclusive_in_tile, + thread_num_runs_exclusive_in_warp, + lengths_and_offsets, + Int2Type()); + } + } + + + + //--------------------------------------------------------------------- + // Cooperatively scan a device-wide sequence of tiles with other CTAs + //--------------------------------------------------------------------- + + /** + * Process a tile of input (dynamic chained scan) + */ + template < + bool LAST_TILE> + __device__ __forceinline__ LengthOffsetPair ConsumeTile( + OffsetT num_items, ///< Total number of global input items + OffsetT num_remaining, ///< Number of global input items remaining (including this tile) + int tile_idx, ///< Tile index + OffsetT tile_offset, ///< Tile offset + ScanTileStateT &tile_status) ///< Global list of tile status + { + if (tile_idx == 0) + { + // First tile + + // Load items + T items[ITEMS_PER_THREAD]; + if (LAST_TILE) + BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items, num_remaining, T()); + else + BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items); + + if (SYNC_AFTER_LOAD) + CTA_SYNC(); + + // Set flags + LengthOffsetPair lengths_and_num_runs[ITEMS_PER_THREAD]; + + InitializeSelections( + tile_offset, + num_remaining, + items, + lengths_and_num_runs); + + // Exclusive scan of lengths and runs + LengthOffsetPair tile_aggregate; + LengthOffsetPair warp_aggregate; + LengthOffsetPair warp_exclusive_in_tile; + LengthOffsetPair thread_exclusive_in_warp; + + WarpScanAllocations( + tile_aggregate, + warp_aggregate, + warp_exclusive_in_tile, + thread_exclusive_in_warp, + lengths_and_num_runs); + + // Update tile status if this is not the last tile + if (!LAST_TILE && (threadIdx.x == 0)) + tile_status.SetInclusive(0, tile_aggregate); + + // Update thread_exclusive_in_warp to fold in warp run-length + if (thread_exclusive_in_warp.key == 0) + thread_exclusive_in_warp.value += warp_exclusive_in_tile.value; + + LengthOffsetPair lengths_and_offsets[ITEMS_PER_THREAD]; + OffsetT thread_num_runs_exclusive_in_warp[ITEMS_PER_THREAD]; + LengthOffsetPair lengths_and_num_runs2[ITEMS_PER_THREAD]; + + // Downsweep scan through lengths_and_num_runs + internal::ThreadScanExclusive(lengths_and_num_runs, lengths_and_num_runs2, scan_op, thread_exclusive_in_warp); + + // Zip + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + lengths_and_offsets[ITEM].value = lengths_and_num_runs2[ITEM].value; + lengths_and_offsets[ITEM].key = tile_offset 
+ (threadIdx.x * ITEMS_PER_THREAD) + ITEM; + thread_num_runs_exclusive_in_warp[ITEM] = (lengths_and_num_runs[ITEM].key) ? + lengths_and_num_runs2[ITEM].key : // keep + WARP_THREADS * ITEMS_PER_THREAD; // discard + } + + OffsetT tile_num_runs_aggregate = tile_aggregate.key; + OffsetT tile_num_runs_exclusive_in_global = 0; + OffsetT warp_num_runs_aggregate = warp_aggregate.key; + OffsetT warp_num_runs_exclusive_in_tile = warp_exclusive_in_tile.key; + + // Scatter + Scatter( + tile_num_runs_aggregate, + tile_num_runs_exclusive_in_global, + warp_num_runs_aggregate, + warp_num_runs_exclusive_in_tile, + thread_num_runs_exclusive_in_warp, + lengths_and_offsets); + + // Return running total (inclusive of this tile) + return tile_aggregate; + } + else + { + // Not first tile + + // Load items + T items[ITEMS_PER_THREAD]; + if (LAST_TILE) + BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items, num_remaining, T()); + else + BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items); + + if (SYNC_AFTER_LOAD) + CTA_SYNC(); + + // Set flags + LengthOffsetPair lengths_and_num_runs[ITEMS_PER_THREAD]; + + InitializeSelections( + tile_offset, + num_remaining, + items, + lengths_and_num_runs); + + // Exclusive scan of lengths and runs + LengthOffsetPair tile_aggregate; + LengthOffsetPair warp_aggregate; + LengthOffsetPair warp_exclusive_in_tile; + LengthOffsetPair thread_exclusive_in_warp; + + WarpScanAllocations( + tile_aggregate, + warp_aggregate, + warp_exclusive_in_tile, + thread_exclusive_in_warp, + lengths_and_num_runs); + + // First warp computes tile prefix in lane 0 + TilePrefixCallbackOpT prefix_op(tile_status, temp_storage.aliasable.prefix, Sum(), tile_idx); + unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS); + if (warp_id == 0) + { + prefix_op(tile_aggregate); + if (threadIdx.x == 0) + temp_storage.tile_exclusive = prefix_op.exclusive_prefix; + } + + CTA_SYNC(); + + LengthOffsetPair tile_exclusive_in_global = temp_storage.tile_exclusive; + + // Update thread_exclusive_in_warp to fold in warp and tile run-lengths + LengthOffsetPair thread_exclusive = scan_op(tile_exclusive_in_global, warp_exclusive_in_tile); + if (thread_exclusive_in_warp.key == 0) + thread_exclusive_in_warp.value += thread_exclusive.value; + + // Downsweep scan through lengths_and_num_runs + LengthOffsetPair lengths_and_num_runs2[ITEMS_PER_THREAD]; + LengthOffsetPair lengths_and_offsets[ITEMS_PER_THREAD]; + OffsetT thread_num_runs_exclusive_in_warp[ITEMS_PER_THREAD]; + + internal::ThreadScanExclusive(lengths_and_num_runs, lengths_and_num_runs2, scan_op, thread_exclusive_in_warp); + + // Zip + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + lengths_and_offsets[ITEM].value = lengths_and_num_runs2[ITEM].value; + lengths_and_offsets[ITEM].key = tile_offset + (threadIdx.x * ITEMS_PER_THREAD) + ITEM; + thread_num_runs_exclusive_in_warp[ITEM] = (lengths_and_num_runs[ITEM].key) ? 
+ lengths_and_num_runs2[ITEM].key : // keep + WARP_THREADS * ITEMS_PER_THREAD; // discard + } + + OffsetT tile_num_runs_aggregate = tile_aggregate.key; + OffsetT tile_num_runs_exclusive_in_global = tile_exclusive_in_global.key; + OffsetT warp_num_runs_aggregate = warp_aggregate.key; + OffsetT warp_num_runs_exclusive_in_tile = warp_exclusive_in_tile.key; + + // Scatter + Scatter( + tile_num_runs_aggregate, + tile_num_runs_exclusive_in_global, + warp_num_runs_aggregate, + warp_num_runs_exclusive_in_tile, + thread_num_runs_exclusive_in_warp, + lengths_and_offsets); + + // Return running total (inclusive of this tile) + return prefix_op.inclusive_prefix; + } + } + + + /** + * Scan tiles of items as part of a dynamic chained scan + */ + template ///< Output iterator type for recording number of items selected + __device__ __forceinline__ void ConsumeRange( + int num_tiles, ///< Total number of input tiles + ScanTileStateT& tile_status, ///< Global list of tile status + NumRunsIteratorT d_num_runs_out) ///< Output pointer for total number of runs identified + { + // Blocks are launched in increasing order, so just assign one tile per block + int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y; // Current tile index + OffsetT tile_offset = tile_idx * TILE_ITEMS; // Global offset for the current tile + OffsetT num_remaining = num_items - tile_offset; // Remaining items (including this tile) + + if (tile_idx < num_tiles - 1) + { + // Not the last tile (full) + ConsumeTile(num_items, num_remaining, tile_idx, tile_offset, tile_status); + } + else if (num_remaining > 0) + { + // The last tile (possibly partially-full) + LengthOffsetPair running_total = ConsumeTile(num_items, num_remaining, tile_idx, tile_offset, tile_status); + + if (threadIdx.x == 0) + { + // Output the total number of items selected + *d_num_runs_out = running_total.key; + + // The inclusive prefix contains accumulated length reduction for the last run + if (running_total.key > 0) + d_lengths_out[running_total.key - 1] = running_total.value; + } + } + } +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/agent/agent_scan.cuh b/fastertransformer/cuda/cub/agent/agent_scan.cuh new file mode 100644 index 000000000..9368615ef --- /dev/null +++ b/fastertransformer/cuda/cub/agent/agent_scan.cuh @@ -0,0 +1,471 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::AgentScan implements a stateful abstraction of CUDA thread blocks for participating in device-wide prefix scan . + */ + +#pragma once + +#include + +#include "single_pass_scan_operators.cuh" +#include "../block/block_load.cuh" +#include "../block/block_store.cuh" +#include "../block/block_scan.cuh" +#include "../grid/grid_queue.cuh" +#include "../iterator/cache_modified_input_iterator.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Tuning policy types + ******************************************************************************/ + +/** + * Parameterizable tuning policy type for AgentScan + */ +template < + int _BLOCK_THREADS, ///< Threads per thread block + int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use + CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements + BlockStoreAlgorithm _STORE_ALGORITHM, ///< The BlockStore algorithm to use + BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use +struct AgentScanPolicy +{ + enum + { + BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block + ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + }; + + static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use + static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements + static const BlockStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM; ///< The BlockStore algorithm to use + static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use +}; + + + + +/****************************************************************************** + * Thread block abstractions + ******************************************************************************/ + +/** + * \brief AgentScan implements a stateful abstraction of CUDA thread blocks for participating in device-wide prefix scan . 
+ */ +template < + typename AgentScanPolicyT, ///< Parameterized AgentScanPolicyT tuning policy type + typename InputIteratorT, ///< Random-access input iterator type + typename OutputIteratorT, ///< Random-access output iterator type + typename ScanOpT, ///< Scan functor type + typename InitValueT, ///< The init_value element for ScanOpT type (cub::NullType for inclusive scan) + typename OffsetT> ///< Signed integer type for global offsets +struct AgentScan +{ + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + // The input value type + typedef typename std::iterator_traits::value_type InputT; + + // The output value type + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type OutputT; // ... else the output iterator's value type + + // Tile status descriptor interface type + typedef ScanTileState ScanTileStateT; + + // Input iterator wrapper type (for applying cache modifier) + typedef typename If::VALUE, + CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedInputIterator + InputIteratorT>::Type // Directly use the supplied input iterator type + WrappedInputIteratorT; + + // Constants + enum + { + IS_INCLUSIVE = Equals::VALUE, // Inclusive scan if no init_value type is provided + BLOCK_THREADS = AgentScanPolicyT::BLOCK_THREADS, + ITEMS_PER_THREAD = AgentScanPolicyT::ITEMS_PER_THREAD, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + }; + + // Parameterized BlockLoad type + typedef BlockLoad< + OutputT, + AgentScanPolicyT::BLOCK_THREADS, + AgentScanPolicyT::ITEMS_PER_THREAD, + AgentScanPolicyT::LOAD_ALGORITHM> + BlockLoadT; + + // Parameterized BlockStore type + typedef BlockStore< + OutputT, + AgentScanPolicyT::BLOCK_THREADS, + AgentScanPolicyT::ITEMS_PER_THREAD, + AgentScanPolicyT::STORE_ALGORITHM> + BlockStoreT; + + // Parameterized BlockScan type + typedef BlockScan< + OutputT, + AgentScanPolicyT::BLOCK_THREADS, + AgentScanPolicyT::SCAN_ALGORITHM> + BlockScanT; + + // Callback type for obtaining tile prefix during block scan + typedef TilePrefixCallbackOp< + OutputT, + ScanOpT, + ScanTileStateT> + TilePrefixCallbackOpT; + + // Stateful BlockScan prefix callback type for managing a running total while scanning consecutive tiles + typedef BlockScanRunningPrefixOp< + OutputT, + ScanOpT> + RunningPrefixCallbackOp; + + // Shared memory type for this thread block + union _TempStorage + { + typename BlockLoadT::TempStorage load; // Smem needed for tile loading + typename BlockStoreT::TempStorage store; // Smem needed for tile storing + + struct + { + typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback + typename BlockScanT::TempStorage scan; // Smem needed for tile scanning + }; + }; + + // Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + _TempStorage& temp_storage; ///< Reference to temp_storage + WrappedInputIteratorT d_in; ///< Input data + OutputIteratorT d_out; ///< Output data + ScanOpT scan_op; ///< Binary scan operator + InitValueT init_value; ///< The init_value 
element for ScanOpT + + + //--------------------------------------------------------------------- + // Block scan utility methods + //--------------------------------------------------------------------- + + /** + * Exclusive scan specialization (first tile) + */ + __device__ __forceinline__ + void ScanTile( + OutputT (&items)[ITEMS_PER_THREAD], + OutputT init_value, + ScanOpT scan_op, + OutputT &block_aggregate, + Int2Type /*is_inclusive*/) + { + BlockScanT(temp_storage.scan).ExclusiveScan(items, items, init_value, scan_op, block_aggregate); + block_aggregate = scan_op(init_value, block_aggregate); + } + + + /** + * Inclusive scan specialization (first tile) + */ + __device__ __forceinline__ + void ScanTile( + OutputT (&items)[ITEMS_PER_THREAD], + InitValueT /*init_value*/, + ScanOpT scan_op, + OutputT &block_aggregate, + Int2Type /*is_inclusive*/) + { + BlockScanT(temp_storage.scan).InclusiveScan(items, items, scan_op, block_aggregate); + } + + + /** + * Exclusive scan specialization (subsequent tiles) + */ + template + __device__ __forceinline__ + void ScanTile( + OutputT (&items)[ITEMS_PER_THREAD], + ScanOpT scan_op, + PrefixCallback &prefix_op, + Int2Type /*is_inclusive*/) + { + BlockScanT(temp_storage.scan).ExclusiveScan(items, items, scan_op, prefix_op); + } + + + /** + * Inclusive scan specialization (subsequent tiles) + */ + template + __device__ __forceinline__ + void ScanTile( + OutputT (&items)[ITEMS_PER_THREAD], + ScanOpT scan_op, + PrefixCallback &prefix_op, + Int2Type /*is_inclusive*/) + { + BlockScanT(temp_storage.scan).InclusiveScan(items, items, scan_op, prefix_op); + } + + + //--------------------------------------------------------------------- + // Constructor + //--------------------------------------------------------------------- + + // Constructor + __device__ __forceinline__ + AgentScan( + TempStorage& temp_storage, ///< Reference to temp_storage + InputIteratorT d_in, ///< Input data + OutputIteratorT d_out, ///< Output data + ScanOpT scan_op, ///< Binary scan operator + InitValueT init_value) ///< Initial value to seed the exclusive scan + : + temp_storage(temp_storage.Alias()), + d_in(d_in), + d_out(d_out), + scan_op(scan_op), + init_value(init_value) + {} + + + //--------------------------------------------------------------------- + // Cooperatively scan a device-wide sequence of tiles with other CTAs + //--------------------------------------------------------------------- + + /** + * Process a tile of input (dynamic chained scan) + */ + template ///< Whether the current tile is the last tile + __device__ __forceinline__ void ConsumeTile( + OffsetT num_remaining, ///< Number of global input items remaining (including this tile) + int tile_idx, ///< Tile index + OffsetT tile_offset, ///< Tile offset + ScanTileStateT& tile_state) ///< Global tile state descriptor + { + // Load items + OutputT items[ITEMS_PER_THREAD]; + + if (IS_LAST_TILE) + BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items, num_remaining); + else + BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items); + + CTA_SYNC(); + + // Perform tile scan + if (tile_idx == 0) + { + // Scan first tile + OutputT block_aggregate; + ScanTile(items, init_value, scan_op, block_aggregate, Int2Type()); + if ((!IS_LAST_TILE) && (threadIdx.x == 0)) + tile_state.SetInclusive(0, block_aggregate); + } + else + { + // Scan non-first tile + TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx); + ScanTile(items, scan_op, prefix_op, Int2Type()); + } + + CTA_SYNC(); + + // 
Store items + if (IS_LAST_TILE) + BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items, num_remaining); + else + BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items); + } + + + /** + * Scan tiles of items as part of a dynamic chained scan + */ + __device__ __forceinline__ void ConsumeRange( + int num_items, ///< Total number of input items + ScanTileStateT& tile_state, ///< Global tile state descriptor + int start_tile) ///< The starting tile for the current grid + { + // Blocks are launched in increasing order, so just assign one tile per block + int tile_idx = start_tile + blockIdx.x; // Current tile index + OffsetT tile_offset = OffsetT(TILE_ITEMS) * tile_idx; // Global offset for the current tile + OffsetT num_remaining = num_items - tile_offset; // Remaining items (including this tile) + + if (num_remaining > TILE_ITEMS) + { + // Not last tile + ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); + } + else if (num_remaining > 0) + { + // Last tile + ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); + } + } + + + //--------------------------------------------------------------------- + // Scan an sequence of consecutive tiles (independent of other thread blocks) + //--------------------------------------------------------------------- + + /** + * Process a tile of input + */ + template < + bool IS_FIRST_TILE, + bool IS_LAST_TILE> + __device__ __forceinline__ void ConsumeTile( + OffsetT tile_offset, ///< Tile offset + RunningPrefixCallbackOp& prefix_op, ///< Running prefix operator + int valid_items = TILE_ITEMS) ///< Number of valid items in the tile + { + // Load items + OutputT items[ITEMS_PER_THREAD]; + + if (IS_LAST_TILE) + BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items, valid_items); + else + BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items); + + CTA_SYNC(); + + // Block scan + if (IS_FIRST_TILE) + { + OutputT block_aggregate; + ScanTile(items, init_value, scan_op, block_aggregate, Int2Type()); + prefix_op.running_total = block_aggregate; + } + else + { + ScanTile(items, scan_op, prefix_op, Int2Type()); + } + + CTA_SYNC(); + + // Store items + if (IS_LAST_TILE) + BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items, valid_items); + else + BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items); + } + + + /** + * Scan a consecutive share of input tiles + */ + __device__ __forceinline__ void ConsumeRange( + OffsetT range_offset, ///< [in] Threadblock begin offset (inclusive) + OffsetT range_end) ///< [in] Threadblock end offset (exclusive) + { + BlockScanRunningPrefixOp prefix_op(scan_op); + + if (range_offset + TILE_ITEMS <= range_end) + { + // Consume first tile of input (full) + ConsumeTile(range_offset, prefix_op); + range_offset += TILE_ITEMS; + + // Consume subsequent full tiles of input + while (range_offset + TILE_ITEMS <= range_end) + { + ConsumeTile(range_offset, prefix_op); + range_offset += TILE_ITEMS; + } + + // Consume a partially-full tile + if (range_offset < range_end) + { + int valid_items = range_end - range_offset; + ConsumeTile(range_offset, prefix_op, valid_items); + } + } + else + { + // Consume the first tile of input (partially-full) + int valid_items = range_end - range_offset; + ConsumeTile(range_offset, prefix_op, valid_items); + } + } + + + /** + * Scan a consecutive share of input tiles, seeded with the specified prefix value + */ + __device__ __forceinline__ void ConsumeRange( + OffsetT range_offset, ///< [in] Threadblock begin offset (inclusive) + OffsetT 
range_end, ///< [in] Threadblock end offset (exclusive) + OutputT prefix) ///< [in] The prefix to apply to the scan segment + { + BlockScanRunningPrefixOp prefix_op(prefix, scan_op); + + // Consume full tiles of input + while (range_offset + TILE_ITEMS <= range_end) + { + ConsumeTile(range_offset, prefix_op); + range_offset += TILE_ITEMS; + } + + // Consume a partially-full tile + if (range_offset < range_end) + { + int valid_items = range_end - range_offset; + ConsumeTile(range_offset, prefix_op, valid_items); + } + } + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/agent/agent_segment_fixup.cuh b/fastertransformer/cuda/cub/agent/agent_segment_fixup.cuh new file mode 100644 index 000000000..e2de58ed6 --- /dev/null +++ b/fastertransformer/cuda/cub/agent/agent_segment_fixup.cuh @@ -0,0 +1,375 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::AgentSegmentFixup implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key. 
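+ *
+ * \par
+ * (Orientation note.) "Fixup" refers to a second pass used by key-value primitives such as
+ * cub::DeviceSpmv: the first pass emits one partial aggregate per tile for the segment that
+ * straddles the tile boundary, and this agent folds those partials back into the output,
+ * either with atomicAdd (on sm_35+ when the value type has a native atomic) or with a
+ * reduce-by-key prefix scan over the partial pairs.
+ *
+ * \par
+ * \code
+ * // Illustrative public entry point that exercises this agent internally; the device pointers
+ * // and sizes below are hypothetical, following the usual CSR values/row_offsets/column_indices layout.
+ * void   *d_temp_storage     = NULL;
+ * size_t  temp_storage_bytes = 0;
+ * cub::DeviceSpmv::CsrMV(d_temp_storage, temp_storage_bytes, d_values, d_row_offsets,
+ *                        d_column_indices, d_vector_x, d_vector_y, num_rows, num_cols, num_nonzeros);
+ * cudaMalloc(&d_temp_storage, temp_storage_bytes);
+ * cub::DeviceSpmv::CsrMV(d_temp_storage, temp_storage_bytes, d_values, d_row_offsets,
+ *                        d_column_indices, d_vector_x, d_vector_y, num_rows, num_cols, num_nonzeros);
+ * \endcode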
+ */ + +#pragma once + +#include + +#include "single_pass_scan_operators.cuh" +#include "../block/block_load.cuh" +#include "../block/block_store.cuh" +#include "../block/block_scan.cuh" +#include "../block/block_discontinuity.cuh" +#include "../iterator/cache_modified_input_iterator.cuh" +#include "../iterator/constant_input_iterator.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Tuning policy types + ******************************************************************************/ + +/** + * Parameterizable tuning policy type for AgentSegmentFixup + */ +template < + int _BLOCK_THREADS, ///< Threads per thread block + int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use + CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements + BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use +struct AgentSegmentFixupPolicy +{ + enum + { + BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block + ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + }; + + static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use + static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements + static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use +}; + + +/****************************************************************************** + * Thread block abstractions + ******************************************************************************/ + +/** + * \brief AgentSegmentFixup implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key + */ +template < + typename AgentSegmentFixupPolicyT, ///< Parameterized AgentSegmentFixupPolicy tuning policy type + typename PairsInputIteratorT, ///< Random-access input iterator type for keys + typename AggregatesOutputIteratorT, ///< Random-access output iterator type for values + typename EqualityOpT, ///< KeyT equality operator type + typename ReductionOpT, ///< ValueT reduction operator type + typename OffsetT> ///< Signed integer type for global offsets +struct AgentSegmentFixup +{ + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + // Data type of key-value input iterator + typedef typename std::iterator_traits::value_type KeyValuePairT; + + // Value type + typedef typename KeyValuePairT::Value ValueT; + + // Tile status descriptor interface type + typedef ReduceByKeyScanTileState ScanTileStateT; + + // Constants + enum + { + BLOCK_THREADS = AgentSegmentFixupPolicyT::BLOCK_THREADS, + ITEMS_PER_THREAD = AgentSegmentFixupPolicyT::ITEMS_PER_THREAD, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + + // Whether or not do fixup using RLE + global atomics + USE_ATOMIC_FIXUP = (CUB_PTX_ARCH >= 350) && + (Equals::VALUE || + Equals::VALUE || + Equals::VALUE || + Equals::VALUE), + + // Whether or not the scan operation has a zero-valued identity value (true if we're performing addition on a primitive type) + HAS_IDENTITY_ZERO = (Equals::VALUE) && (Traits::PRIMITIVE), + }; + + // Cache-modified Input iterator wrapper type (for applying 
cache modifier) for keys + typedef typename If::VALUE, + CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedValuesInputIterator + PairsInputIteratorT>::Type // Directly use the supplied input iterator type + WrappedPairsInputIteratorT; + + // Cache-modified Input iterator wrapper type (for applying cache modifier) for fixup values + typedef typename If::VALUE, + CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedValuesInputIterator + AggregatesOutputIteratorT>::Type // Directly use the supplied input iterator type + WrappedFixupInputIteratorT; + + // Reduce-value-by-segment scan operator + typedef ReduceByKeyOp ReduceBySegmentOpT; + + // Parameterized BlockLoad type for pairs + typedef BlockLoad< + KeyValuePairT, + BLOCK_THREADS, + ITEMS_PER_THREAD, + AgentSegmentFixupPolicyT::LOAD_ALGORITHM> + BlockLoadPairs; + + // Parameterized BlockScan type + typedef BlockScan< + KeyValuePairT, + BLOCK_THREADS, + AgentSegmentFixupPolicyT::SCAN_ALGORITHM> + BlockScanT; + + // Callback type for obtaining tile prefix during block scan + typedef TilePrefixCallbackOp< + KeyValuePairT, + ReduceBySegmentOpT, + ScanTileStateT> + TilePrefixCallbackOpT; + + // Shared memory type for this thread block + union _TempStorage + { + struct + { + typename BlockScanT::TempStorage scan; // Smem needed for tile scanning + typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback + }; + + // Smem needed for loading keys + typename BlockLoadPairs::TempStorage load_pairs; + }; + + // Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + _TempStorage& temp_storage; ///< Reference to temp_storage + WrappedPairsInputIteratorT d_pairs_in; ///< Input keys + AggregatesOutputIteratorT d_aggregates_out; ///< Output value aggregates + WrappedFixupInputIteratorT d_fixup_in; ///< Fixup input values + InequalityWrapper inequality_op; ///< KeyT inequality operator + ReductionOpT reduction_op; ///< Reduction operator + ReduceBySegmentOpT scan_op; ///< Reduce-by-segment scan operator + + + //--------------------------------------------------------------------- + // Constructor + //--------------------------------------------------------------------- + + // Constructor + __device__ __forceinline__ + AgentSegmentFixup( + TempStorage& temp_storage, ///< Reference to temp_storage + PairsInputIteratorT d_pairs_in, ///< Input keys + AggregatesOutputIteratorT d_aggregates_out, ///< Output value aggregates + EqualityOpT equality_op, ///< KeyT equality operator + ReductionOpT reduction_op) ///< ValueT reduction operator + : + temp_storage(temp_storage.Alias()), + d_pairs_in(d_pairs_in), + d_aggregates_out(d_aggregates_out), + d_fixup_in(d_aggregates_out), + inequality_op(equality_op), + reduction_op(reduction_op), + scan_op(reduction_op) + {} + + + //--------------------------------------------------------------------- + // Cooperatively scan a device-wide sequence of tiles with other CTAs + //--------------------------------------------------------------------- + + + /** + * Process input tile. 
Specialized for atomic-fixup + */ + template + __device__ __forceinline__ void ConsumeTile( + OffsetT num_remaining, ///< Number of global input items remaining (including this tile) + int tile_idx, ///< Tile index + OffsetT tile_offset, ///< Tile offset + ScanTileStateT& tile_state, ///< Global tile state descriptor + Int2Type use_atomic_fixup) ///< Marker whether to use atomicAdd (instead of reduce-by-key) + { + KeyValuePairT pairs[ITEMS_PER_THREAD]; + + // Load pairs + KeyValuePairT oob_pair; + oob_pair.key = -1; + + if (IS_LAST_TILE) + BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs, num_remaining, oob_pair); + else + BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs); + + // RLE + #pragma unroll + for (int ITEM = 1; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + ValueT* d_scatter = d_aggregates_out + pairs[ITEM - 1].key; + if (pairs[ITEM].key != pairs[ITEM - 1].key) + atomicAdd(d_scatter, pairs[ITEM - 1].value); + else + pairs[ITEM].value = reduction_op(pairs[ITEM - 1].value, pairs[ITEM].value); + } + + // Flush last item if valid + ValueT* d_scatter = d_aggregates_out + pairs[ITEMS_PER_THREAD - 1].key; + if ((!IS_LAST_TILE) || (pairs[ITEMS_PER_THREAD - 1].key >= 0)) + atomicAdd(d_scatter, pairs[ITEMS_PER_THREAD - 1].value); + } + + + /** + * Process input tile. Specialized for reduce-by-key fixup + */ + template + __device__ __forceinline__ void ConsumeTile( + OffsetT num_remaining, ///< Number of global input items remaining (including this tile) + int tile_idx, ///< Tile index + OffsetT tile_offset, ///< Tile offset + ScanTileStateT& tile_state, ///< Global tile state descriptor + Int2Type use_atomic_fixup) ///< Marker whether to use atomicAdd (instead of reduce-by-key) + { + KeyValuePairT pairs[ITEMS_PER_THREAD]; + KeyValuePairT scatter_pairs[ITEMS_PER_THREAD]; + + // Load pairs + KeyValuePairT oob_pair; + oob_pair.key = -1; + + if (IS_LAST_TILE) + BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs, num_remaining, oob_pair); + else + BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs); + + CTA_SYNC(); + + KeyValuePairT tile_aggregate; + if (tile_idx == 0) + { + // Exclusive scan of values and segment_flags + BlockScanT(temp_storage.scan).ExclusiveScan(pairs, scatter_pairs, scan_op, tile_aggregate); + + // Update tile status if this is not the last tile + if (threadIdx.x == 0) + { + // Set first segment id to not trigger a flush (invalid from exclusive scan) + scatter_pairs[0].key = pairs[0].key; + + if (!IS_LAST_TILE) + tile_state.SetInclusive(0, tile_aggregate); + + } + } + else + { + // Exclusive scan of values and segment_flags + TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx); + BlockScanT(temp_storage.scan).ExclusiveScan(pairs, scatter_pairs, scan_op, prefix_op); + tile_aggregate = prefix_op.GetBlockAggregate(); + } + + // Scatter updated values + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + if (scatter_pairs[ITEM].key != pairs[ITEM].key) + { + // Update the value at the key location + ValueT value = d_fixup_in[scatter_pairs[ITEM].key]; + value = reduction_op(value, scatter_pairs[ITEM].value); + + d_aggregates_out[scatter_pairs[ITEM].key] = value; + } + } + + // Finalize the last item + if (IS_LAST_TILE) + { + // Last thread will output final count and last item, if necessary + if (threadIdx.x == BLOCK_THREADS - 1) + { + // If the last tile is a whole tile, the inclusive prefix contains accumulated value 
reduction for the last segment + if (num_remaining == TILE_ITEMS) + { + // Update the value at the key location + OffsetT last_key = pairs[ITEMS_PER_THREAD - 1].key; + d_aggregates_out[last_key] = reduction_op(tile_aggregate.value, d_fixup_in[last_key]); + } + } + } + } + + + /** + * Scan tiles of items as part of a dynamic chained scan + */ + __device__ __forceinline__ void ConsumeRange( + int num_items, ///< Total number of input items + int num_tiles, ///< Total number of input tiles + ScanTileStateT& tile_state) ///< Global tile state descriptor + { + // Blocks are launched in increasing order, so just assign one tile per block + int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y; // Current tile index + OffsetT tile_offset = tile_idx * TILE_ITEMS; // Global offset for the current tile + OffsetT num_remaining = num_items - tile_offset; // Remaining items (including this tile) + + if (num_remaining > TILE_ITEMS) + { + // Not the last tile (full) + ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state, Int2Type()); + } + else if (num_remaining > 0) + { + // The last tile (possibly partially-full) + ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state, Int2Type()); + } + } + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/agent/agent_select_if.cuh b/fastertransformer/cuda/cub/agent/agent_select_if.cuh new file mode 100644 index 000000000..52ca9fc28 --- /dev/null +++ b/fastertransformer/cuda/cub/agent/agent_select_if.cuh @@ -0,0 +1,703 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide select. 
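+ *
+ * \par
+ * (Orientation sketch, with hypothetical buffer names; this agent backs the device-wide wrappers
+ * cub::DeviceSelect and cub::DevicePartition.)
+ *
+ * \par
+ * \code
+ * // Keep the positive ints of d_in in d_out, writing the selected count to d_num_selected_out.
+ * struct IsPositive
+ * {
+ *     __host__ __device__ bool operator()(const int &x) const { return x > 0; }
+ * };
+ *
+ * void   *d_temp_storage     = NULL;
+ * size_t  temp_storage_bytes = 0;
+ * cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, IsPositive());
+ * cudaMalloc(&d_temp_storage, temp_storage_bytes);
+ * cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, IsPositive());
+ * \endcode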
+ */ + +#pragma once + +#include + +#include "single_pass_scan_operators.cuh" +#include "../block/block_load.cuh" +#include "../block/block_store.cuh" +#include "../block/block_scan.cuh" +#include "../block/block_exchange.cuh" +#include "../block/block_discontinuity.cuh" +#include "../grid/grid_queue.cuh" +#include "../iterator/cache_modified_input_iterator.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Tuning policy types + ******************************************************************************/ + +/** + * Parameterizable tuning policy type for AgentSelectIf + */ +template < + int _BLOCK_THREADS, ///< Threads per thread block + int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use + CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements + BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use +struct AgentSelectIfPolicy +{ + enum + { + BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block + ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + }; + + static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use + static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements + static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use +}; + + + + +/****************************************************************************** + * Thread block abstractions + ******************************************************************************/ + + +/** + * \brief AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide selection + * + * Performs functor-based selection if SelectOpT functor type != NullType + * Otherwise performs flag-based selection if FlagsInputIterator's value type != NullType + * Otherwise performs discontinuity selection (keep unique) + */ +template < + typename AgentSelectIfPolicyT, ///< Parameterized AgentSelectIfPolicy tuning policy type + typename InputIteratorT, ///< Random-access input iterator type for selection items + typename FlagsInputIteratorT, ///< Random-access input iterator type for selections (NullType* if a selection functor or discontinuity flagging is to be used for selection) + typename SelectedOutputIteratorT, ///< Random-access input iterator type for selection_flags items + typename SelectOpT, ///< Selection operator type (NullType if selections or discontinuity flagging is to be used for selection) + typename EqualityOpT, ///< Equality operator type (NullType if selection functor or selections is to be used for selection) + typename OffsetT, ///< Signed integer type for global offsets + bool KEEP_REJECTS> ///< Whether or not we push rejected items to the back of the output +struct AgentSelectIf +{ + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + // The input value type + typedef typename std::iterator_traits::value_type InputT; + + // The output value type + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? 
+ typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type OutputT; // ... else the output iterator's value type + + // The flag value type + typedef typename std::iterator_traits::value_type FlagT; + + // Tile status descriptor interface type + typedef ScanTileState ScanTileStateT; + + // Constants + enum + { + USE_SELECT_OP, + USE_SELECT_FLAGS, + USE_DISCONTINUITY, + + BLOCK_THREADS = AgentSelectIfPolicyT::BLOCK_THREADS, + ITEMS_PER_THREAD = AgentSelectIfPolicyT::ITEMS_PER_THREAD, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + TWO_PHASE_SCATTER = (ITEMS_PER_THREAD > 1), + + SELECT_METHOD = (!Equals::VALUE) ? + USE_SELECT_OP : + (!Equals::VALUE) ? + USE_SELECT_FLAGS : + USE_DISCONTINUITY + }; + + // Cache-modified Input iterator wrapper type (for applying cache modifier) for items + typedef typename If::VALUE, + CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedValuesInputIterator + InputIteratorT>::Type // Directly use the supplied input iterator type + WrappedInputIteratorT; + + // Cache-modified Input iterator wrapper type (for applying cache modifier) for values + typedef typename If::VALUE, + CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedValuesInputIterator + FlagsInputIteratorT>::Type // Directly use the supplied input iterator type + WrappedFlagsInputIteratorT; + + // Parameterized BlockLoad type for input data + typedef BlockLoad< + OutputT, + BLOCK_THREADS, + ITEMS_PER_THREAD, + AgentSelectIfPolicyT::LOAD_ALGORITHM> + BlockLoadT; + + // Parameterized BlockLoad type for flags + typedef BlockLoad< + FlagT, + BLOCK_THREADS, + ITEMS_PER_THREAD, + AgentSelectIfPolicyT::LOAD_ALGORITHM> + BlockLoadFlags; + + // Parameterized BlockDiscontinuity type for items + typedef BlockDiscontinuity< + OutputT, + BLOCK_THREADS> + BlockDiscontinuityT; + + // Parameterized BlockScan type + typedef BlockScan< + OffsetT, + BLOCK_THREADS, + AgentSelectIfPolicyT::SCAN_ALGORITHM> + BlockScanT; + + // Callback type for obtaining tile prefix during block scan + typedef TilePrefixCallbackOp< + OffsetT, + cub::Sum, + ScanTileStateT> + TilePrefixCallbackOpT; + + // Item exchange type + typedef OutputT ItemExchangeT[TILE_ITEMS]; + + // Shared memory type for this thread block + union _TempStorage + { + struct + { + typename BlockScanT::TempStorage scan; // Smem needed for tile scanning + typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback + typename BlockDiscontinuityT::TempStorage discontinuity; // Smem needed for discontinuity detection + }; + + // Smem needed for loading items + typename BlockLoadT::TempStorage load_items; + + // Smem needed for loading values + typename BlockLoadFlags::TempStorage load_flags; + + // Smem needed for compacting items (allows non POD items in this union) + Uninitialized raw_exchange; + }; + + // Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + _TempStorage& temp_storage; ///< Reference to temp_storage + WrappedInputIteratorT d_in; ///< Input items + SelectedOutputIteratorT d_selected_out; ///< Unique output items + WrappedFlagsInputIteratorT d_flags_in; ///< Input selection flags (if applicable) + InequalityWrapper inequality_op; ///< T inequality operator + 
SelectOpT select_op; ///< Selection operator + OffsetT num_items; ///< Total number of input items + + + //--------------------------------------------------------------------- + // Constructor + //--------------------------------------------------------------------- + + // Constructor + __device__ __forceinline__ + AgentSelectIf( + TempStorage &temp_storage, ///< Reference to temp_storage + InputIteratorT d_in, ///< Input data + FlagsInputIteratorT d_flags_in, ///< Input selection flags (if applicable) + SelectedOutputIteratorT d_selected_out, ///< Output data + SelectOpT select_op, ///< Selection operator + EqualityOpT equality_op, ///< Equality operator + OffsetT num_items) ///< Total number of input items + : + temp_storage(temp_storage.Alias()), + d_in(d_in), + d_flags_in(d_flags_in), + d_selected_out(d_selected_out), + select_op(select_op), + inequality_op(equality_op), + num_items(num_items) + {} + + + //--------------------------------------------------------------------- + // Utility methods for initializing the selections + //--------------------------------------------------------------------- + + /** + * Initialize selections (specialized for selection operator) + */ + template + __device__ __forceinline__ void InitializeSelections( + OffsetT /*tile_offset*/, + OffsetT num_tile_items, + OutputT (&items)[ITEMS_PER_THREAD], + OffsetT (&selection_flags)[ITEMS_PER_THREAD], + Int2Type /*select_method*/) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + // Out-of-bounds items are selection_flags + selection_flags[ITEM] = 1; + + if (!IS_LAST_TILE || (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM < num_tile_items)) + selection_flags[ITEM] = select_op(items[ITEM]); + } + } + + + /** + * Initialize selections (specialized for valid flags) + */ + template + __device__ __forceinline__ void InitializeSelections( + OffsetT tile_offset, + OffsetT num_tile_items, + OutputT (&/*items*/)[ITEMS_PER_THREAD], + OffsetT (&selection_flags)[ITEMS_PER_THREAD], + Int2Type /*select_method*/) + { + CTA_SYNC(); + + FlagT flags[ITEMS_PER_THREAD]; + + if (IS_LAST_TILE) + { + // Out-of-bounds items are selection_flags + BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags, num_tile_items, 1); + } + else + { + BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags); + } + + // Convert flag type to selection_flags type + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + selection_flags[ITEM] = flags[ITEM]; + } + } + + + /** + * Initialize selections (specialized for discontinuity detection) + */ + template + __device__ __forceinline__ void InitializeSelections( + OffsetT tile_offset, + OffsetT num_tile_items, + OutputT (&items)[ITEMS_PER_THREAD], + OffsetT (&selection_flags)[ITEMS_PER_THREAD], + Int2Type /*select_method*/) + { + if (IS_FIRST_TILE) + { + CTA_SYNC(); + + // Set head selection_flags. 
First tile sets the first flag for the first item + BlockDiscontinuityT(temp_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op); + } + else + { + OutputT tile_predecessor; + if (threadIdx.x == 0) + tile_predecessor = d_in[tile_offset - 1]; + + CTA_SYNC(); + + BlockDiscontinuityT(temp_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op, tile_predecessor); + } + + // Set selection flags for out-of-bounds items + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + // Set selection_flags for out-of-bounds items + if ((IS_LAST_TILE) && (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM >= num_tile_items)) + selection_flags[ITEM] = 1; + } + } + + + //--------------------------------------------------------------------- + // Scatter utility methods + //--------------------------------------------------------------------- + + /** + * Scatter flagged items to output offsets (specialized for direct scattering) + */ + template + __device__ __forceinline__ void ScatterDirect( + OutputT (&items)[ITEMS_PER_THREAD], + OffsetT (&selection_flags)[ITEMS_PER_THREAD], + OffsetT (&selection_indices)[ITEMS_PER_THREAD], + OffsetT num_selections) + { + // Scatter flagged items + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + if (selection_flags[ITEM]) + { + if ((!IS_LAST_TILE) || selection_indices[ITEM] < num_selections) + { + d_selected_out[selection_indices[ITEM]] = items[ITEM]; + } + } + } + } + + + /** + * Scatter flagged items to output offsets (specialized for two-phase scattering) + */ + template + __device__ __forceinline__ void ScatterTwoPhase( + OutputT (&items)[ITEMS_PER_THREAD], + OffsetT (&selection_flags)[ITEMS_PER_THREAD], + OffsetT (&selection_indices)[ITEMS_PER_THREAD], + int /*num_tile_items*/, ///< Number of valid items in this tile + int num_tile_selections, ///< Number of selections in this tile + OffsetT num_selections_prefix, ///< Total number of selections prior to this tile + OffsetT /*num_rejected_prefix*/, ///< Total number of rejections prior to this tile + Int2Type /*is_keep_rejects*/) ///< Marker type indicating whether to keep rejected items in the second partition + { + CTA_SYNC(); + + // Compact and scatter items + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + int local_scatter_offset = selection_indices[ITEM] - num_selections_prefix; + if (selection_flags[ITEM]) + { + temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM]; + } + } + + CTA_SYNC(); + + for (int item = threadIdx.x; item < num_tile_selections; item += BLOCK_THREADS) + { + d_selected_out[num_selections_prefix + item] = temp_storage.raw_exchange.Alias()[item]; + } + } + + + /** + * Scatter flagged items to output offsets (specialized for two-phase scattering) + */ + template + __device__ __forceinline__ void ScatterTwoPhase( + OutputT (&items)[ITEMS_PER_THREAD], + OffsetT (&selection_flags)[ITEMS_PER_THREAD], + OffsetT (&selection_indices)[ITEMS_PER_THREAD], + int num_tile_items, ///< Number of valid items in this tile + int num_tile_selections, ///< Number of selections in this tile + OffsetT num_selections_prefix, ///< Total number of selections prior to this tile + OffsetT num_rejected_prefix, ///< Total number of rejections prior to this tile + Int2Type /*is_keep_rejects*/) ///< Marker type indicating whether to keep rejected items in the second partition + { + CTA_SYNC(); + + int tile_num_rejections = num_tile_items - num_tile_selections; + + // Scatter items to shared memory 
(rejections first) + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + int item_idx = (threadIdx.x * ITEMS_PER_THREAD) + ITEM; + int local_selection_idx = selection_indices[ITEM] - num_selections_prefix; + int local_rejection_idx = item_idx - local_selection_idx; + int local_scatter_offset = (selection_flags[ITEM]) ? + tile_num_rejections + local_selection_idx : + local_rejection_idx; + + temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM]; + } + + CTA_SYNC(); + + // Gather items from shared memory and scatter to global + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + int item_idx = (ITEM * BLOCK_THREADS) + threadIdx.x; + int rejection_idx = item_idx; + int selection_idx = item_idx - tile_num_rejections; + OffsetT scatter_offset = (item_idx < tile_num_rejections) ? + num_items - num_rejected_prefix - rejection_idx - 1 : + num_selections_prefix + selection_idx; + + OutputT item = temp_storage.raw_exchange.Alias()[item_idx]; + + if (!IS_LAST_TILE || (item_idx < num_tile_items)) + { + d_selected_out[scatter_offset] = item; + } + } + } + + + /** + * Scatter flagged items + */ + template + __device__ __forceinline__ void Scatter( + OutputT (&items)[ITEMS_PER_THREAD], + OffsetT (&selection_flags)[ITEMS_PER_THREAD], + OffsetT (&selection_indices)[ITEMS_PER_THREAD], + int num_tile_items, ///< Number of valid items in this tile + int num_tile_selections, ///< Number of selections in this tile + OffsetT num_selections_prefix, ///< Total number of selections prior to this tile + OffsetT num_rejected_prefix, ///< Total number of rejections prior to this tile + OffsetT num_selections) ///< Total number of selections including this tile + { + // Do a two-phase scatter if (a) keeping both partitions or (b) two-phase is enabled and the average number of selection_flags items per thread is greater than one + if (KEEP_REJECTS || (TWO_PHASE_SCATTER && (num_tile_selections > BLOCK_THREADS))) + { + ScatterTwoPhase( + items, + selection_flags, + selection_indices, + num_tile_items, + num_tile_selections, + num_selections_prefix, + num_rejected_prefix, + Int2Type()); + } + else + { + ScatterDirect( + items, + selection_flags, + selection_indices, + num_selections); + } + } + + //--------------------------------------------------------------------- + // Cooperatively scan a device-wide sequence of tiles with other CTAs + //--------------------------------------------------------------------- + + + /** + * Process first tile of input (dynamic chained scan). 
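+     * (The first tile publishes its inclusive aggregate into tile_state so that subsequent tiles
+     * can pick it up through the decoupled look-back performed by TilePrefixCallbackOpT.)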
Returns the running count of selections (including this tile) + */ + template + __device__ __forceinline__ OffsetT ConsumeFirstTile( + int num_tile_items, ///< Number of input items comprising this tile + OffsetT tile_offset, ///< Tile offset + ScanTileStateT& tile_state) ///< Global tile state descriptor + { + OutputT items[ITEMS_PER_THREAD]; + OffsetT selection_flags[ITEMS_PER_THREAD]; + OffsetT selection_indices[ITEMS_PER_THREAD]; + + // Load items + if (IS_LAST_TILE) + BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items); + else + BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items); + + // Initialize selection_flags + InitializeSelections( + tile_offset, + num_tile_items, + items, + selection_flags, + Int2Type()); + + CTA_SYNC(); + + // Exclusive scan of selection_flags + OffsetT num_tile_selections; + BlockScanT(temp_storage.scan).ExclusiveSum(selection_flags, selection_indices, num_tile_selections); + + if (threadIdx.x == 0) + { + // Update tile status if this is not the last tile + if (!IS_LAST_TILE) + tile_state.SetInclusive(0, num_tile_selections); + } + + // Discount any out-of-bounds selections + if (IS_LAST_TILE) + num_tile_selections -= (TILE_ITEMS - num_tile_items); + + // Scatter flagged items + Scatter( + items, + selection_flags, + selection_indices, + num_tile_items, + num_tile_selections, + 0, + 0, + num_tile_selections); + + return num_tile_selections; + } + + + /** + * Process subsequent tile of input (dynamic chained scan). Returns the running count of selections (including this tile) + */ + template + __device__ __forceinline__ OffsetT ConsumeSubsequentTile( + int num_tile_items, ///< Number of input items comprising this tile + int tile_idx, ///< Tile index + OffsetT tile_offset, ///< Tile offset + ScanTileStateT& tile_state) ///< Global tile state descriptor + { + OutputT items[ITEMS_PER_THREAD]; + OffsetT selection_flags[ITEMS_PER_THREAD]; + OffsetT selection_indices[ITEMS_PER_THREAD]; + + // Load items + if (IS_LAST_TILE) + BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items); + else + BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items); + + // Initialize selection_flags + InitializeSelections( + tile_offset, + num_tile_items, + items, + selection_flags, + Int2Type()); + + CTA_SYNC(); + + // Exclusive scan of values and selection_flags + TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, cub::Sum(), tile_idx); + BlockScanT(temp_storage.scan).ExclusiveSum(selection_flags, selection_indices, prefix_op); + + OffsetT num_tile_selections = prefix_op.GetBlockAggregate(); + OffsetT num_selections = prefix_op.GetInclusivePrefix(); + OffsetT num_selections_prefix = prefix_op.GetExclusivePrefix(); + OffsetT num_rejected_prefix = (tile_idx * TILE_ITEMS) - num_selections_prefix; + + // Discount any out-of-bounds selections + if (IS_LAST_TILE) + { + int num_discount = TILE_ITEMS - num_tile_items; + num_selections -= num_discount; + num_tile_selections -= num_discount; + } + + // Scatter flagged items + Scatter( + items, + selection_flags, + selection_indices, + num_tile_items, + num_tile_selections, + num_selections_prefix, + num_rejected_prefix, + num_selections); + + return num_selections; + } + + + /** + * Process a tile of input + */ + template + __device__ __forceinline__ OffsetT ConsumeTile( + int num_tile_items, ///< Number of input items comprising this tile + int tile_idx, ///< Tile index + OffsetT tile_offset, ///< Tile offset + ScanTileStateT& tile_state) 
///< Global tile state descriptor + { + OffsetT num_selections; + if (tile_idx == 0) + { + num_selections = ConsumeFirstTile(num_tile_items, tile_offset, tile_state); + } + else + { + num_selections = ConsumeSubsequentTile(num_tile_items, tile_idx, tile_offset, tile_state); + } + + return num_selections; + } + + + /** + * Scan tiles of items as part of a dynamic chained scan + */ + template ///< Output iterator type for recording number of items selection_flags + __device__ __forceinline__ void ConsumeRange( + int num_tiles, ///< Total number of input tiles + ScanTileStateT& tile_state, ///< Global tile state descriptor + NumSelectedIteratorT d_num_selected_out) ///< Output total number selection_flags + { + // Blocks are launched in increasing order, so just assign one tile per block + int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y; // Current tile index + OffsetT tile_offset = tile_idx * TILE_ITEMS; // Global offset for the current tile + + if (tile_idx < num_tiles - 1) + { + // Not the last tile (full) + ConsumeTile(TILE_ITEMS, tile_idx, tile_offset, tile_state); + } + else + { + // The last tile (possibly partially-full) + OffsetT num_remaining = num_items - tile_offset; + OffsetT num_selections = ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); + + if (threadIdx.x == 0) + { + // Output the total number of items selection_flags + *d_num_selected_out = num_selections; + } + } + } + +}; + + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/agent/agent_spmv_orig.cuh b/fastertransformer/cuda/cub/agent/agent_spmv_orig.cuh new file mode 100644 index 000000000..54e2a1394 --- /dev/null +++ b/fastertransformer/cuda/cub/agent/agent_spmv_orig.cuh @@ -0,0 +1,670 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * \file + * cub::AgentSpmv implements a stateful abstraction of CUDA thread blocks for participating in device-wide SpMV. + */ + +#pragma once + +#include + +#include "../util_type.cuh" +#include "../block/block_reduce.cuh" +#include "../block/block_scan.cuh" +#include "../block/block_exchange.cuh" +#include "../thread/thread_search.cuh" +#include "../thread/thread_operators.cuh" +#include "../iterator/cache_modified_input_iterator.cuh" +#include "../iterator/counting_input_iterator.cuh" +#include "../iterator/tex_ref_input_iterator.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Tuning policy + ******************************************************************************/ + +/** + * Parameterizable tuning policy type for AgentSpmv + */ +template < + int _BLOCK_THREADS, ///< Threads per thread block + int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + CacheLoadModifier _ROW_OFFSETS_SEARCH_LOAD_MODIFIER, ///< Cache load modifier for reading CSR row-offsets during search + CacheLoadModifier _ROW_OFFSETS_LOAD_MODIFIER, ///< Cache load modifier for reading CSR row-offsets + CacheLoadModifier _COLUMN_INDICES_LOAD_MODIFIER, ///< Cache load modifier for reading CSR column-indices + CacheLoadModifier _VALUES_LOAD_MODIFIER, ///< Cache load modifier for reading CSR values + CacheLoadModifier _VECTOR_VALUES_LOAD_MODIFIER, ///< Cache load modifier for reading vector values + bool _DIRECT_LOAD_NONZEROS, ///< Whether to load nonzeros directly from global during sequential merging (vs. pre-staged through shared memory) + BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use +struct AgentSpmvPolicy +{ + enum + { + BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block + ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) + DIRECT_LOAD_NONZEROS = _DIRECT_LOAD_NONZEROS, ///< Whether to load nonzeros directly from global during sequential merging (pre-staged through shared memory) + }; + + static const CacheLoadModifier ROW_OFFSETS_SEARCH_LOAD_MODIFIER = _ROW_OFFSETS_SEARCH_LOAD_MODIFIER; ///< Cache load modifier for reading CSR row-offsets + static const CacheLoadModifier ROW_OFFSETS_LOAD_MODIFIER = _ROW_OFFSETS_LOAD_MODIFIER; ///< Cache load modifier for reading CSR row-offsets + static const CacheLoadModifier COLUMN_INDICES_LOAD_MODIFIER = _COLUMN_INDICES_LOAD_MODIFIER; ///< Cache load modifier for reading CSR column-indices + static const CacheLoadModifier VALUES_LOAD_MODIFIER = _VALUES_LOAD_MODIFIER; ///< Cache load modifier for reading CSR values + static const CacheLoadModifier VECTOR_VALUES_LOAD_MODIFIER = _VECTOR_VALUES_LOAD_MODIFIER; ///< Cache load modifier for reading vector values + static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use + +}; + + +/****************************************************************************** + * Thread block abstractions + ******************************************************************************/ + +template < + typename ValueT, ///< Matrix and vector value type + typename OffsetT> ///< Signed integer type for sequence offsets +struct SpmvParams +{ + ValueT* d_values; ///< Pointer to the array of \p num_nonzeros values of the corresponding nonzero elements of matrix A. 
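+
+    // Note: entry i of d_row_end_offsets is the end (exclusive) of row i, i.e. the conventional
+    // CSR row_offsets array advanced by one element; the device-level wrapper is expected to
+    // pass d_row_offsets + 1 here.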
+ OffsetT* d_row_end_offsets; ///< Pointer to the array of \p m offsets demarcating the end of every row in \p d_column_indices and \p d_values + OffsetT* d_column_indices; ///< Pointer to the array of \p num_nonzeros column-indices of the corresponding nonzero elements of matrix A. (Indices are zero-valued.) + ValueT* d_vector_x; ///< Pointer to the array of \p num_cols values corresponding to the dense input vector x + ValueT* d_vector_y; ///< Pointer to the array of \p num_rows values corresponding to the dense output vector y + int num_rows; ///< Number of rows of matrix A. + int num_cols; ///< Number of columns of matrix A. + int num_nonzeros; ///< Number of nonzero elements of matrix A. + ValueT alpha; ///< Alpha multiplicand + ValueT beta; ///< Beta addend-multiplicand + + TexRefInputIterator t_vector_x; +}; + + +/** + * \brief AgentSpmv implements a stateful abstraction of CUDA thread blocks for participating in device-wide SpMV. + */ +template < + typename AgentSpmvPolicyT, ///< Parameterized AgentSpmvPolicy tuning policy type + typename ValueT, ///< Matrix and vector value type + typename OffsetT, ///< Signed integer type for sequence offsets + bool HAS_ALPHA, ///< Whether the input parameter \p alpha is 1 + bool HAS_BETA, ///< Whether the input parameter \p beta is 0 + int PTX_ARCH = CUB_PTX_ARCH> ///< PTX compute capability +struct AgentSpmv +{ + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// Constants + enum + { + BLOCK_THREADS = AgentSpmvPolicyT::BLOCK_THREADS, + ITEMS_PER_THREAD = AgentSpmvPolicyT::ITEMS_PER_THREAD, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + }; + + /// 2D merge path coordinate type + typedef typename CubVector::Type CoordinateT; + + /// Input iterator wrapper types (for applying cache modifiers) + + typedef CacheModifiedInputIterator< + AgentSpmvPolicyT::ROW_OFFSETS_SEARCH_LOAD_MODIFIER, + OffsetT, + OffsetT> + RowOffsetsSearchIteratorT; + + typedef CacheModifiedInputIterator< + AgentSpmvPolicyT::ROW_OFFSETS_LOAD_MODIFIER, + OffsetT, + OffsetT> + RowOffsetsIteratorT; + + typedef CacheModifiedInputIterator< + AgentSpmvPolicyT::COLUMN_INDICES_LOAD_MODIFIER, + OffsetT, + OffsetT> + ColumnIndicesIteratorT; + + typedef CacheModifiedInputIterator< + AgentSpmvPolicyT::VALUES_LOAD_MODIFIER, + ValueT, + OffsetT> + ValueIteratorT; + + typedef CacheModifiedInputIterator< + AgentSpmvPolicyT::VECTOR_VALUES_LOAD_MODIFIER, + ValueT, + OffsetT> + VectorValueIteratorT; + + // Tuple type for scanning (pairs accumulated segment-value with segment-index) + typedef KeyValuePair KeyValuePairT; + + // Reduce-value-by-segment scan operator + typedef ReduceByKeyOp ReduceBySegmentOpT; + + // BlockReduce specialization + typedef BlockReduce< + ValueT, + BLOCK_THREADS, + BLOCK_REDUCE_WARP_REDUCTIONS> + BlockReduceT; + + // BlockScan specialization + typedef BlockScan< + KeyValuePairT, + BLOCK_THREADS, + AgentSpmvPolicyT::SCAN_ALGORITHM> + BlockScanT; + + // BlockScan specialization + typedef BlockScan< + ValueT, + BLOCK_THREADS, + AgentSpmvPolicyT::SCAN_ALGORITHM> + BlockPrefixSumT; + + // BlockExchange specialization + typedef BlockExchange< + ValueT, + BLOCK_THREADS, + ITEMS_PER_THREAD> + BlockExchangeT; + + /// Merge item type (either a non-zero value or a row-end offset) + union MergeItem + { + // Value type to pair with index type OffsetT (NullType if loading values directly during merge) + typedef typename If::Type MergeValueT; + + OffsetT 
row_end_offset; + MergeValueT nonzero; + }; + + /// Shared memory type required by this thread block + struct _TempStorage + { + CoordinateT tile_coords[2]; + + union Aliasable + { + // Smem needed for tile of merge items + MergeItem merge_items[ITEMS_PER_THREAD + TILE_ITEMS + 1]; + + // Smem needed for block exchange + typename BlockExchangeT::TempStorage exchange; + + // Smem needed for block-wide reduction + typename BlockReduceT::TempStorage reduce; + + // Smem needed for tile scanning + typename BlockScanT::TempStorage scan; + + // Smem needed for tile prefix sum + typename BlockPrefixSumT::TempStorage prefix_sum; + + } aliasable; + }; + + /// Temporary storage type (unionable) + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + + _TempStorage& temp_storage; /// Reference to temp_storage + + SpmvParams& spmv_params; + + ValueIteratorT wd_values; ///< Wrapped pointer to the array of \p num_nonzeros values of the corresponding nonzero elements of matrix A. + RowOffsetsIteratorT wd_row_end_offsets; ///< Wrapped Pointer to the array of \p m offsets demarcating the end of every row in \p d_column_indices and \p d_values + ColumnIndicesIteratorT wd_column_indices; ///< Wrapped Pointer to the array of \p num_nonzeros column-indices of the corresponding nonzero elements of matrix A. (Indices are zero-valued.) + VectorValueIteratorT wd_vector_x; ///< Wrapped Pointer to the array of \p num_cols values corresponding to the dense input vector x + VectorValueIteratorT wd_vector_y; ///< Wrapped Pointer to the array of \p num_cols values corresponding to the dense input vector x + + + //--------------------------------------------------------------------- + // Interface + //--------------------------------------------------------------------- + + /** + * Constructor + */ + __device__ __forceinline__ AgentSpmv( + TempStorage& temp_storage, ///< Reference to temp_storage + SpmvParams& spmv_params) ///< SpMV input parameter bundle + : + temp_storage(temp_storage.Alias()), + spmv_params(spmv_params), + wd_values(spmv_params.d_values), + wd_row_end_offsets(spmv_params.d_row_end_offsets), + wd_column_indices(spmv_params.d_column_indices), + wd_vector_x(spmv_params.d_vector_x), + wd_vector_y(spmv_params.d_vector_y) + {} + + + + + /** + * Consume a merge tile, specialized for direct-load of nonzeros + */ + __device__ __forceinline__ KeyValuePairT ConsumeTile( + int tile_idx, + CoordinateT tile_start_coord, + CoordinateT tile_end_coord, + Int2Type is_direct_load) ///< Marker type indicating whether to load nonzeros directly during path-discovery or beforehand in batch + { + int tile_num_rows = tile_end_coord.x - tile_start_coord.x; + int tile_num_nonzeros = tile_end_coord.y - tile_start_coord.y; + OffsetT* s_tile_row_end_offsets = &temp_storage.aliasable.merge_items[0].row_end_offset; + + // Gather the row end-offsets for the merge tile into shared memory + for (int item = threadIdx.x; item <= tile_num_rows; item += BLOCK_THREADS) + { + s_tile_row_end_offsets[item] = wd_row_end_offsets[tile_start_coord.x + item]; + } + + CTA_SYNC(); + + // Search for the thread's starting coordinate within the merge tile + CountingInputIterator tile_nonzero_indices(tile_start_coord.y); + CoordinateT thread_start_coord; + + MergePathSearch( + OffsetT(threadIdx.x * ITEMS_PER_THREAD), // Diagonal + s_tile_row_end_offsets, // List A + 
tile_nonzero_indices, // List B + tile_num_rows, + tile_num_nonzeros, + thread_start_coord); + + CTA_SYNC(); // Perf-sync + + // Compute the thread's merge path segment + CoordinateT thread_current_coord = thread_start_coord; + KeyValuePairT scan_segment[ITEMS_PER_THREAD]; + + ValueT running_total = 0.0; + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + OffsetT nonzero_idx = CUB_MIN(tile_nonzero_indices[thread_current_coord.y], spmv_params.num_nonzeros - 1); + OffsetT column_idx = wd_column_indices[nonzero_idx]; + ValueT value = wd_values[nonzero_idx]; + + ValueT vector_value = spmv_params.t_vector_x[column_idx]; +#if (CUB_PTX_ARCH >= 350) + vector_value = wd_vector_x[column_idx]; +#endif + ValueT nonzero = value * vector_value; + + OffsetT row_end_offset = s_tile_row_end_offsets[thread_current_coord.x]; + + if (tile_nonzero_indices[thread_current_coord.y] < row_end_offset) + { + // Move down (accumulate) + running_total += nonzero; + scan_segment[ITEM].value = running_total; + scan_segment[ITEM].key = tile_num_rows; + ++thread_current_coord.y; + } + else + { + // Move right (reset) + scan_segment[ITEM].value = running_total; + scan_segment[ITEM].key = thread_current_coord.x; + running_total = 0.0; + ++thread_current_coord.x; + } + } + + CTA_SYNC(); + + // Block-wide reduce-value-by-segment + KeyValuePairT tile_carry; + ReduceBySegmentOpT scan_op; + KeyValuePairT scan_item; + + scan_item.value = running_total; + scan_item.key = thread_current_coord.x; + + BlockScanT(temp_storage.aliasable.scan).ExclusiveScan(scan_item, scan_item, scan_op, tile_carry); + + if (tile_num_rows > 0) + { + if (threadIdx.x == 0) + scan_item.key = -1; + + // Direct scatter + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + if (scan_segment[ITEM].key < tile_num_rows) + { + if (scan_item.key == scan_segment[ITEM].key) + scan_segment[ITEM].value = scan_item.value + scan_segment[ITEM].value; + + if (HAS_ALPHA) + { + scan_segment[ITEM].value *= spmv_params.alpha; + } + + if (HAS_BETA) + { + // Update the output vector element + ValueT addend = spmv_params.beta * wd_vector_y[tile_start_coord.x + scan_segment[ITEM].key]; + scan_segment[ITEM].value += addend; + } + + // Set the output vector element + spmv_params.d_vector_y[tile_start_coord.x + scan_segment[ITEM].key] = scan_segment[ITEM].value; + } + } + } + + // Return the tile's running carry-out + return tile_carry; + } + + + + /** + * Consume a merge tile, specialized for indirect load of nonzeros + */ + __device__ __forceinline__ KeyValuePairT ConsumeTile( + int tile_idx, + CoordinateT tile_start_coord, + CoordinateT tile_end_coord, + Int2Type is_direct_load) ///< Marker type indicating whether to load nonzeros directly during path-discovery or beforehand in batch + { + int tile_num_rows = tile_end_coord.x - tile_start_coord.x; + int tile_num_nonzeros = tile_end_coord.y - tile_start_coord.y; + +#if (CUB_PTX_ARCH >= 520) + + OffsetT* s_tile_row_end_offsets = &temp_storage.aliasable.merge_items[0].row_end_offset; + ValueT* s_tile_nonzeros = &temp_storage.aliasable.merge_items[tile_num_rows + ITEMS_PER_THREAD].nonzero; + + // Gather the nonzeros for the merge tile into shared memory + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + int nonzero_idx = threadIdx.x + (ITEM * BLOCK_THREADS); + + ValueIteratorT a = wd_values + tile_start_coord.y + nonzero_idx; + ColumnIndicesIteratorT ci = wd_column_indices + tile_start_coord.y + nonzero_idx; + ValueT* s = s_tile_nonzeros + nonzero_idx; + + if 
(nonzero_idx < tile_num_nonzeros) + { + + OffsetT column_idx = *ci; + ValueT value = *a; + + ValueT vector_value = spmv_params.t_vector_x[column_idx]; + vector_value = wd_vector_x[column_idx]; + + ValueT nonzero = value * vector_value; + + *s = nonzero; + } + } + + +#else + + OffsetT* s_tile_row_end_offsets = &temp_storage.aliasable.merge_items[0].row_end_offset; + ValueT* s_tile_nonzeros = &temp_storage.aliasable.merge_items[tile_num_rows + ITEMS_PER_THREAD].nonzero; + + // Gather the nonzeros for the merge tile into shared memory + if (tile_num_nonzeros > 0) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + int nonzero_idx = threadIdx.x + (ITEM * BLOCK_THREADS); + nonzero_idx = CUB_MIN(nonzero_idx, tile_num_nonzeros - 1); + + OffsetT column_idx = wd_column_indices[tile_start_coord.y + nonzero_idx]; + ValueT value = wd_values[tile_start_coord.y + nonzero_idx]; + + ValueT vector_value = spmv_params.t_vector_x[column_idx]; +#if (CUB_PTX_ARCH >= 350) + vector_value = wd_vector_x[column_idx]; +#endif + ValueT nonzero = value * vector_value; + + s_tile_nonzeros[nonzero_idx] = nonzero; + } + } + +#endif + + // Gather the row end-offsets for the merge tile into shared memory + #pragma unroll 1 + for (int item = threadIdx.x; item <= tile_num_rows; item += BLOCK_THREADS) + { + s_tile_row_end_offsets[item] = wd_row_end_offsets[tile_start_coord.x + item]; + } + + CTA_SYNC(); + + // Search for the thread's starting coordinate within the merge tile + CountingInputIterator tile_nonzero_indices(tile_start_coord.y); + CoordinateT thread_start_coord; + + MergePathSearch( + OffsetT(threadIdx.x * ITEMS_PER_THREAD), // Diagonal + s_tile_row_end_offsets, // List A + tile_nonzero_indices, // List B + tile_num_rows, + tile_num_nonzeros, + thread_start_coord); + + CTA_SYNC(); // Perf-sync + + // Compute the thread's merge path segment + CoordinateT thread_current_coord = thread_start_coord; + KeyValuePairT scan_segment[ITEMS_PER_THREAD]; + ValueT running_total = 0.0; + + OffsetT row_end_offset = s_tile_row_end_offsets[thread_current_coord.x]; + ValueT nonzero = s_tile_nonzeros[thread_current_coord.y]; + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + if (tile_nonzero_indices[thread_current_coord.y] < row_end_offset) + { + // Move down (accumulate) + scan_segment[ITEM].value = nonzero; + running_total += nonzero; + ++thread_current_coord.y; + nonzero = s_tile_nonzeros[thread_current_coord.y]; + } + else + { + // Move right (reset) + scan_segment[ITEM].value = 0.0; + running_total = 0.0; + ++thread_current_coord.x; + row_end_offset = s_tile_row_end_offsets[thread_current_coord.x]; + } + + scan_segment[ITEM].key = thread_current_coord.x; + } + + CTA_SYNC(); + + // Block-wide reduce-value-by-segment + KeyValuePairT tile_carry; + ReduceBySegmentOpT scan_op; + KeyValuePairT scan_item; + + scan_item.value = running_total; + scan_item.key = thread_current_coord.x; + + BlockScanT(temp_storage.aliasable.scan).ExclusiveScan(scan_item, scan_item, scan_op, tile_carry); + + if (threadIdx.x == 0) + { + scan_item.key = thread_start_coord.x; + scan_item.value = 0.0; + } + + if (tile_num_rows > 0) + { + + CTA_SYNC(); + + // Scan downsweep and scatter + ValueT* s_partials = &temp_storage.aliasable.merge_items[0].nonzero; + + if (scan_item.key != scan_segment[0].key) + { + s_partials[scan_item.key] = scan_item.value; + } + else + { + scan_segment[0].value += scan_item.value; + } + + #pragma unroll + for (int ITEM = 1; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + if 
(scan_segment[ITEM - 1].key != scan_segment[ITEM].key) + { + s_partials[scan_segment[ITEM - 1].key] = scan_segment[ITEM - 1].value; + } + else + { + scan_segment[ITEM].value += scan_segment[ITEM - 1].value; + } + } + + CTA_SYNC(); + + #pragma unroll 1 + for (int item = threadIdx.x; item < tile_num_rows; item += BLOCK_THREADS) + { + spmv_params.d_vector_y[tile_start_coord.x + item] = s_partials[item]; + } + } + + // Return the tile's running carry-out + return tile_carry; + } + + + /** + * Consume input tile + */ + __device__ __forceinline__ void ConsumeTile( + CoordinateT* d_tile_coordinates, ///< [in] Pointer to the temporary array of tile starting coordinates + KeyValuePairT* d_tile_carry_pairs, ///< [out] Pointer to the temporary array carry-out dot product row-ids, one per block + int num_merge_tiles) ///< [in] Number of merge tiles + { + int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y; // Current tile index + + if (tile_idx >= num_merge_tiles) + return; + + // Read our starting coordinates + if (threadIdx.x < 2) + { + if (d_tile_coordinates == NULL) + { + // Search our starting coordinates + OffsetT diagonal = (tile_idx + threadIdx.x) * TILE_ITEMS; + CoordinateT tile_coord; + CountingInputIterator nonzero_indices(0); + + // Search the merge path + MergePathSearch( + diagonal, + RowOffsetsSearchIteratorT(spmv_params.d_row_end_offsets), + nonzero_indices, + spmv_params.num_rows, + spmv_params.num_nonzeros, + tile_coord); + + temp_storage.tile_coords[threadIdx.x] = tile_coord; + } + else + { + temp_storage.tile_coords[threadIdx.x] = d_tile_coordinates[tile_idx + threadIdx.x]; + } + } + + CTA_SYNC(); + + CoordinateT tile_start_coord = temp_storage.tile_coords[0]; + CoordinateT tile_end_coord = temp_storage.tile_coords[1]; + + // Consume multi-segment tile + KeyValuePairT tile_carry = ConsumeTile( + tile_idx, + tile_start_coord, + tile_end_coord, + Int2Type()); + + // Output the tile's carry-out + if (threadIdx.x == 0) + { + if (HAS_ALPHA) + tile_carry.value *= spmv_params.alpha; + + tile_carry.key += tile_start_coord.x; + d_tile_carry_pairs[tile_idx] = tile_carry; + } + } + + +}; + + + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/agent/single_pass_scan_operators.cuh b/fastertransformer/cuda/cub/agent/single_pass_scan_operators.cuh new file mode 100644 index 000000000..53409bdee --- /dev/null +++ b/fastertransformer/cuda/cub/agent/single_pass_scan_operators.cuh @@ -0,0 +1,815 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Callback operator types for supplying BlockScan prefixes + */ + +#pragma once + +#include + +#include "../thread/thread_load.cuh" +#include "../thread/thread_store.cuh" +#include "../warp/warp_reduce.cuh" +#include "../util_arch.cuh" +#include "../util_device.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Prefix functor type for maintaining a running prefix while scanning a + * region independent of other thread blocks + ******************************************************************************/ + +/** + * Stateful callback operator type for supplying BlockScan prefixes. + * Maintains a running prefix that can be applied to consecutive + * BlockScan operations. + */ +template < + typename T, ///< BlockScan value type + typename ScanOpT> ///< Wrapped scan operator type +struct BlockScanRunningPrefixOp +{ + ScanOpT op; ///< Wrapped scan operator + T running_total; ///< Running block-wide prefix + + /// Constructor + __device__ __forceinline__ BlockScanRunningPrefixOp(ScanOpT op) + : + op(op) + {} + + /// Constructor + __device__ __forceinline__ BlockScanRunningPrefixOp( + T starting_prefix, + ScanOpT op) + : + op(op), + running_total(starting_prefix) + {} + + /** + * Prefix callback operator. Returns the block-wide running_total in thread-0. + */ + __device__ __forceinline__ T operator()( + const T &block_aggregate) ///< The aggregate sum of the BlockScan inputs + { + T retval = running_total; + running_total = op(running_total, block_aggregate); + return retval; + } +}; + + +/****************************************************************************** + * Generic tile status interface types for block-cooperative scans + ******************************************************************************/ + +/** + * Enumerations of tile status + */ +enum ScanTileStatus +{ + SCAN_TILE_OOB, // Out-of-bounds (e.g., padding) + SCAN_TILE_INVALID = 99, // Not yet processed + SCAN_TILE_PARTIAL, // Tile aggregate is available + SCAN_TILE_INCLUSIVE, // Inclusive tile prefix is available +}; + + +/** + * Tile status interface. + */ +template < + typename T, + bool SINGLE_WORD = Traits::PRIMITIVE> +struct ScanTileState; + + +/** + * Tile status interface specialized for scan status and value types + * that can be combined into one machine word that can be + * read/written coherently in a single access. 
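+ *
+ * \par
+ * For example (illustrative): with \p T = \p int, the 4-byte status word and the
+ * 4-byte value are packed into a single 8-byte \p TxnWord (an \p int2), so a tile's
+ * status and its partial/inclusive value can be read or written with one coherent
+ * memory transaction instead of two separate accesses.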
+ */ +template +struct ScanTileState +{ + // Status word type + typedef typename If<(sizeof(T) == 8), + long long, + typename If<(sizeof(T) == 4), + int, + typename If<(sizeof(T) == 2), + short, + char>::Type>::Type>::Type StatusWord; + + + // Unit word type + typedef typename If<(sizeof(T) == 8), + longlong2, + typename If<(sizeof(T) == 4), + int2, + typename If<(sizeof(T) == 2), + int, + uchar2>::Type>::Type>::Type TxnWord; + + + // Device word type + struct TileDescriptor + { + StatusWord status; + T value; + }; + + + // Constants + enum + { + TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS, + }; + + + // Device storage + TxnWord *d_tile_descriptors; + + /// Constructor + __host__ __device__ __forceinline__ + ScanTileState() + : + d_tile_descriptors(NULL) + {} + + + /// Initializer + __host__ __device__ __forceinline__ + cudaError_t Init( + int /*num_tiles*/, ///< [in] Number of tiles + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t /*temp_storage_bytes*/) ///< [in] Size in bytes of \t d_temp_storage allocation + { + d_tile_descriptors = reinterpret_cast(d_temp_storage); + return cudaSuccess; + } + + + /** + * Compute device memory needed for tile status + */ + __host__ __device__ __forceinline__ + static cudaError_t AllocationSize( + int num_tiles, ///< [in] Number of tiles + size_t &temp_storage_bytes) ///< [out] Size in bytes of \t d_temp_storage allocation + { + temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TileDescriptor); // bytes needed for tile status descriptors + return cudaSuccess; + } + + + /** + * Initialize (from device) + */ + __device__ __forceinline__ void InitializeStatus(int num_tiles) + { + int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x; + + TxnWord val = TxnWord(); + TileDescriptor *descriptor = reinterpret_cast(&val); + + if (tile_idx < num_tiles) + { + // Not-yet-set + descriptor->status = StatusWord(SCAN_TILE_INVALID); + d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val; + } + + if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING)) + { + // Padding + descriptor->status = StatusWord(SCAN_TILE_OOB); + d_tile_descriptors[threadIdx.x] = val; + } + } + + + /** + * Update the specified tile's inclusive value and corresponding status + */ + __device__ __forceinline__ void SetInclusive(int tile_idx, T tile_inclusive) + { + TileDescriptor tile_descriptor; + tile_descriptor.status = SCAN_TILE_INCLUSIVE; + tile_descriptor.value = tile_inclusive; + + TxnWord alias; + *reinterpret_cast(&alias) = tile_descriptor; + ThreadStore(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias); + } + + + /** + * Update the specified tile's partial value and corresponding status + */ + __device__ __forceinline__ void SetPartial(int tile_idx, T tile_partial) + { + TileDescriptor tile_descriptor; + tile_descriptor.status = SCAN_TILE_PARTIAL; + tile_descriptor.value = tile_partial; + + TxnWord alias; + *reinterpret_cast(&alias) = tile_descriptor; + ThreadStore(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias); + } + + /** + * Wait for the corresponding tile to become non-invalid + */ + __device__ __forceinline__ void WaitForValid( + int tile_idx, + StatusWord &status, + T &value) + { + TileDescriptor tile_descriptor; + do + { + __threadfence_block(); // prevent hoisting loads from loop + TxnWord alias = ThreadLoad(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx); + tile_descriptor = 
reinterpret_cast(alias); + + } while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff)); + + status = tile_descriptor.status; + value = tile_descriptor.value; + } + +}; + + + +/** + * Tile status interface specialized for scan status and value types that + * cannot be combined into one machine word. + */ +template +struct ScanTileState +{ + // Status word type + typedef char StatusWord; + + // Constants + enum + { + TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS, + }; + + // Device storage + StatusWord *d_tile_status; + T *d_tile_partial; + T *d_tile_inclusive; + + /// Constructor + __host__ __device__ __forceinline__ + ScanTileState() + : + d_tile_status(NULL), + d_tile_partial(NULL), + d_tile_inclusive(NULL) + {} + + + /// Initializer + __host__ __device__ __forceinline__ + cudaError_t Init( + int num_tiles, ///< [in] Number of tiles + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t temp_storage_bytes) ///< [in] Size in bytes of \t d_temp_storage allocation + { + cudaError_t error = cudaSuccess; + do + { + void* allocations[3]; + size_t allocation_sizes[3]; + + allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord); // bytes needed for tile status descriptors + allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized); // bytes needed for partials + allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized); // bytes needed for inclusives + + // Compute allocation pointers into the single storage blob + if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; + + // Alias the offsets + d_tile_status = reinterpret_cast(allocations[0]); + d_tile_partial = reinterpret_cast(allocations[1]); + d_tile_inclusive = reinterpret_cast(allocations[2]); + } + while (0); + + return error; + } + + + /** + * Compute device memory needed for tile status + */ + __host__ __device__ __forceinline__ + static cudaError_t AllocationSize( + int num_tiles, ///< [in] Number of tiles + size_t &temp_storage_bytes) ///< [out] Size in bytes of \t d_temp_storage allocation + { + // Specify storage allocation requirements + size_t allocation_sizes[3]; + allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord); // bytes needed for tile status descriptors + allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized); // bytes needed for partials + allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized); // bytes needed for inclusives + + // Set the necessary size of the blob + void* allocations[3]; + return CubDebug(AliasTemporaries(NULL, temp_storage_bytes, allocations, allocation_sizes)); + } + + + /** + * Initialize (from device) + */ + __device__ __forceinline__ void InitializeStatus(int num_tiles) + { + int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x; + if (tile_idx < num_tiles) + { + // Not-yet-set + d_tile_status[TILE_STATUS_PADDING + tile_idx] = StatusWord(SCAN_TILE_INVALID); + } + + if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING)) + { + // Padding + d_tile_status[threadIdx.x] = StatusWord(SCAN_TILE_OOB); + } + } + + + /** + * Update the specified tile's inclusive value and corresponding status + */ + __device__ __forceinline__ void SetInclusive(int tile_idx, T tile_inclusive) + { + // Update tile inclusive value + ThreadStore(d_tile_inclusive + 
TILE_STATUS_PADDING + tile_idx, tile_inclusive); + + // Fence + __threadfence(); + + // Update tile status + ThreadStore(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_INCLUSIVE)); + } + + + /** + * Update the specified tile's partial value and corresponding status + */ + __device__ __forceinline__ void SetPartial(int tile_idx, T tile_partial) + { + // Update tile partial value + ThreadStore(d_tile_partial + TILE_STATUS_PADDING + tile_idx, tile_partial); + + // Fence + __threadfence(); + + // Update tile status + ThreadStore(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_PARTIAL)); + } + + /** + * Wait for the corresponding tile to become non-invalid + */ + __device__ __forceinline__ void WaitForValid( + int tile_idx, + StatusWord &status, + T &value) + { + do { + status = ThreadLoad(d_tile_status + TILE_STATUS_PADDING + tile_idx); + + __threadfence(); // prevent hoisting loads from loop or loads below above this one + + } while (status == SCAN_TILE_INVALID); + + if (status == StatusWord(SCAN_TILE_PARTIAL)) + value = ThreadLoad(d_tile_partial + TILE_STATUS_PADDING + tile_idx); + else + value = ThreadLoad(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx); + } +}; + + +/****************************************************************************** + * ReduceByKey tile status interface types for block-cooperative scans + ******************************************************************************/ + +/** + * Tile status interface for reduction by key. + * + */ +template < + typename ValueT, + typename KeyT, + bool SINGLE_WORD = (Traits::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16)> +struct ReduceByKeyScanTileState; + + +/** + * Tile status interface for reduction by key, specialized for scan status and value types that + * cannot be combined into one machine word. + */ +template < + typename ValueT, + typename KeyT> +struct ReduceByKeyScanTileState : + ScanTileState > +{ + typedef ScanTileState > SuperClass; + + /// Constructor + __host__ __device__ __forceinline__ + ReduceByKeyScanTileState() : SuperClass() {} +}; + + +/** + * Tile status interface for reduction by key, specialized for scan status and value types that + * can be combined into one machine word that can be read/written coherently in a single access. 
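+ *
+ * \par
+ * For example (illustrative): with \p ValueT = \p float and \p KeyT = \p int,
+ * PAIR_SIZE is 8 bytes, the transaction word rounds up to 16 bytes (a \p longlong2),
+ * and the remaining 8 bytes hold the status word, so the whole {key, value, status}
+ * descriptor is read and written as a single coherent 16-byte access.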
+ */ +template < + typename ValueT, + typename KeyT> +struct ReduceByKeyScanTileState +{ + typedef KeyValuePairKeyValuePairT; + + // Constants + enum + { + PAIR_SIZE = sizeof(ValueT) + sizeof(KeyT), + TXN_WORD_SIZE = 1 << Log2::VALUE, + STATUS_WORD_SIZE = TXN_WORD_SIZE - PAIR_SIZE, + + TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS, + }; + + // Status word type + typedef typename If<(STATUS_WORD_SIZE == 8), + long long, + typename If<(STATUS_WORD_SIZE == 4), + int, + typename If<(STATUS_WORD_SIZE == 2), + short, + char>::Type>::Type>::Type StatusWord; + + // Status word type + typedef typename If<(TXN_WORD_SIZE == 16), + longlong2, + typename If<(TXN_WORD_SIZE == 8), + long long, + int>::Type>::Type TxnWord; + + // Device word type (for when sizeof(ValueT) == sizeof(KeyT)) + struct TileDescriptorBigStatus + { + KeyT key; + ValueT value; + StatusWord status; + }; + + // Device word type (for when sizeof(ValueT) != sizeof(KeyT)) + struct TileDescriptorLittleStatus + { + ValueT value; + StatusWord status; + KeyT key; + }; + + // Device word type + typedef typename If< + (sizeof(ValueT) == sizeof(KeyT)), + TileDescriptorBigStatus, + TileDescriptorLittleStatus>::Type + TileDescriptor; + + + // Device storage + TxnWord *d_tile_descriptors; + + + /// Constructor + __host__ __device__ __forceinline__ + ReduceByKeyScanTileState() + : + d_tile_descriptors(NULL) + {} + + + /// Initializer + __host__ __device__ __forceinline__ + cudaError_t Init( + int /*num_tiles*/, ///< [in] Number of tiles + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t /*temp_storage_bytes*/) ///< [in] Size in bytes of \t d_temp_storage allocation + { + d_tile_descriptors = reinterpret_cast(d_temp_storage); + return cudaSuccess; + } + + + /** + * Compute device memory needed for tile status + */ + __host__ __device__ __forceinline__ + static cudaError_t AllocationSize( + int num_tiles, ///< [in] Number of tiles + size_t &temp_storage_bytes) ///< [out] Size in bytes of \t d_temp_storage allocation + { + temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TileDescriptor); // bytes needed for tile status descriptors + return cudaSuccess; + } + + + /** + * Initialize (from device) + */ + __device__ __forceinline__ void InitializeStatus(int num_tiles) + { + int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x; + TxnWord val = TxnWord(); + TileDescriptor *descriptor = reinterpret_cast(&val); + + if (tile_idx < num_tiles) + { + // Not-yet-set + descriptor->status = StatusWord(SCAN_TILE_INVALID); + d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val; + } + + if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING)) + { + // Padding + descriptor->status = StatusWord(SCAN_TILE_OOB); + d_tile_descriptors[threadIdx.x] = val; + } + } + + + /** + * Update the specified tile's inclusive value and corresponding status + */ + __device__ __forceinline__ void SetInclusive(int tile_idx, KeyValuePairT tile_inclusive) + { + TileDescriptor tile_descriptor; + tile_descriptor.status = SCAN_TILE_INCLUSIVE; + tile_descriptor.value = tile_inclusive.value; + tile_descriptor.key = tile_inclusive.key; + + TxnWord alias; + *reinterpret_cast(&alias) = tile_descriptor; + ThreadStore(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias); + } + + + /** + * Update the specified tile's partial value and corresponding status + */ + __device__ __forceinline__ void SetPartial(int tile_idx, KeyValuePairT 
tile_partial)
+    {
+        TileDescriptor tile_descriptor;
+        tile_descriptor.status = SCAN_TILE_PARTIAL;
+        tile_descriptor.value  = tile_partial.value;
+        tile_descriptor.key    = tile_partial.key;
+
+        TxnWord alias;
+        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
+        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
+    }
+
+    /**
+     * Wait for the corresponding tile to become non-invalid
+     */
+    __device__ __forceinline__ void WaitForValid(
+        int             tile_idx,
+        StatusWord      &status,
+        KeyValuePairT   &value)
+    {
+//        TxnWord         alias           = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
+//        TileDescriptor  tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
+//
+//        while (tile_descriptor.status == SCAN_TILE_INVALID)
+//        {
+//            __threadfence_block(); // prevent hoisting loads from loop
+//
+//            alias           = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
+//            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
+//        }
+//
+//        status      = tile_descriptor.status;
+//        value.value = tile_descriptor.value;
+//        value.key   = tile_descriptor.key;
+
+        TileDescriptor tile_descriptor;
+        do
+        {
+            __threadfence_block(); // prevent hoisting loads from loop
+            TxnWord alias   = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
+            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
+
+        } while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));
+
+        status      = tile_descriptor.status;
+        value.value = tile_descriptor.value;
+        value.key   = tile_descriptor.key;
+    }
+
+};
+
+
+/******************************************************************************
+ * Prefix call-back operator for coupling local block scan within a
+ * block-cooperative scan
+ ******************************************************************************/
+
+/**
+ * Stateful block-scan prefix functor.  Provides the running prefix for
+ * the current tile by using the call-back warp to wait on
+ * aggregates/prefixes from predecessor tiles to become available.
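+ *
+ * \par
+ * A minimal usage sketch (illustrative; \p tile_state, \p tile_idx, and the item
+ * count below are assumed to be provided by the calling kernel):
+ * \code
+ * typedef cub::BlockScan<int, 128>                             BlockScanT;
+ * typedef cub::ScanTileState<int>                              TileStateT;
+ * typedef cub::TilePrefixCallbackOp<int, cub::Sum, TileStateT> PrefixOpT;
+ *
+ * __shared__ typename BlockScanT::TempStorage scan_storage;
+ * __shared__ typename PrefixOpT::TempStorage  prefix_storage;
+ *
+ * // Obtain this tile's items
+ * int thread_items[4];
+ * ...
+ *
+ * // All threads participate; internally the first warp polls predecessor
+ * // tiles through tile_state until their aggregates/prefixes are published
+ * PrefixOpT prefix_op(tile_state, prefix_storage, cub::Sum(), tile_idx);
+ * BlockScanT(scan_storage).ExclusiveSum(thread_items, thread_items, prefix_op);
+ * \endcode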
+ */ +template < + typename T, + typename ScanOpT, + typename ScanTileStateT, + int PTX_ARCH = CUB_PTX_ARCH> +struct TilePrefixCallbackOp +{ + // Parameterized warp reduce + typedef WarpReduce WarpReduceT; + + // Temporary storage type + struct _TempStorage + { + typename WarpReduceT::TempStorage warp_reduce; + T exclusive_prefix; + T inclusive_prefix; + T block_aggregate; + }; + + // Alias wrapper allowing temporary storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + // Type of status word + typedef typename ScanTileStateT::StatusWord StatusWord; + + // Fields + _TempStorage& temp_storage; ///< Reference to a warp-reduction instance + ScanTileStateT& tile_status; ///< Interface to tile status + ScanOpT scan_op; ///< Binary scan operator + int tile_idx; ///< The current tile index + T exclusive_prefix; ///< Exclusive prefix for the tile + T inclusive_prefix; ///< Inclusive prefix for the tile + + // Constructor + __device__ __forceinline__ + TilePrefixCallbackOp( + ScanTileStateT &tile_status, + TempStorage &temp_storage, + ScanOpT scan_op, + int tile_idx) + : + temp_storage(temp_storage.Alias()), + tile_status(tile_status), + scan_op(scan_op), + tile_idx(tile_idx) {} + + + // Block until all predecessors within the warp-wide window have non-invalid status + __device__ __forceinline__ + void ProcessWindow( + int predecessor_idx, ///< Preceding tile index to inspect + StatusWord &predecessor_status, ///< [out] Preceding tile status + T &window_aggregate) ///< [out] Relevant partial reduction from this window of preceding tiles + { + T value; + tile_status.WaitForValid(predecessor_idx, predecessor_status, value); + + // Perform a segmented reduction to get the prefix for the current window. + // Use the swizzled scan operator because we are now scanning *down* towards thread0. 
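+        // Each lane in the calling warp has inspected one predecessor tile
+        // (predecessor_idx = tile_idx - lane - 1).  Predecessors whose inclusive
+        // prefix is already published are treated as segment tails, so the
+        // tail-segmented reduction below accumulates window values only as far
+        // back as the most recent SCAN_TILE_INCLUSIVE tile.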
+ + int tail_flag = (predecessor_status == StatusWord(SCAN_TILE_INCLUSIVE)); + window_aggregate = WarpReduceT(temp_storage.warp_reduce).TailSegmentedReduce( + value, + tail_flag, + SwizzleScanOp(scan_op)); + } + + + // BlockScan prefix callback functor (called by the first warp) + __device__ __forceinline__ + T operator()(T block_aggregate) + { + + // Update our status with our tile-aggregate + if (threadIdx.x == 0) + { + temp_storage.block_aggregate = block_aggregate; + tile_status.SetPartial(tile_idx, block_aggregate); + } + + int predecessor_idx = tile_idx - threadIdx.x - 1; + StatusWord predecessor_status; + T window_aggregate; + + // Wait for the warp-wide window of predecessor tiles to become valid + ProcessWindow(predecessor_idx, predecessor_status, window_aggregate); + + // The exclusive tile prefix starts out as the current window aggregate + exclusive_prefix = window_aggregate; + + // Keep sliding the window back until we come across a tile whose inclusive prefix is known + while (WARP_ALL((predecessor_status != StatusWord(SCAN_TILE_INCLUSIVE)), 0xffffffff)) + { + predecessor_idx -= CUB_PTX_WARP_THREADS; + + // Update exclusive tile prefix with the window prefix + ProcessWindow(predecessor_idx, predecessor_status, window_aggregate); + exclusive_prefix = scan_op(window_aggregate, exclusive_prefix); + } + + // Compute the inclusive tile prefix and update the status for this tile + if (threadIdx.x == 0) + { + inclusive_prefix = scan_op(exclusive_prefix, block_aggregate); + tile_status.SetInclusive(tile_idx, inclusive_prefix); + + temp_storage.exclusive_prefix = exclusive_prefix; + temp_storage.inclusive_prefix = inclusive_prefix; + } + + // Return exclusive_prefix + return exclusive_prefix; + } + + // Get the exclusive prefix stored in temporary storage + __device__ __forceinline__ + T GetExclusivePrefix() + { + return temp_storage.exclusive_prefix; + } + + // Get the inclusive prefix stored in temporary storage + __device__ __forceinline__ + T GetInclusivePrefix() + { + return temp_storage.inclusive_prefix; + } + + // Get the block aggregate stored in temporary storage + __device__ __forceinline__ + T GetBlockAggregate() + { + return temp_storage.block_aggregate; + } + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/block_adjacent_difference.cuh b/fastertransformer/cuda/cub/block/block_adjacent_difference.cuh new file mode 100644 index 000000000..acef9f056 --- /dev/null +++ b/fastertransformer/cuda/cub/block/block_adjacent_difference.cuh @@ -0,0 +1,596 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * The cub::BlockDiscontinuity class provides [collective](index.html#sec0) methods for flagging discontinuities within an ordered set of items partitioned across a CUDA thread block. + */ + +#pragma once + +#include "../util_type.cuh" +#include "../util_ptx.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +template < + typename T, + int BLOCK_DIM_X, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = CUB_PTX_ARCH> +class BlockAdjacentDifference +{ +private: + + /****************************************************************************** + * Constants and type definitions + ******************************************************************************/ + + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + }; + + + /// Shared memory storage layout type (last element from each thread's input) + struct _TempStorage + { + T first_items[BLOCK_THREADS]; + T last_items[BLOCK_THREADS]; + }; + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /// Internal storage allocator + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + + /// Specialization for when FlagOp has third index param + template ::HAS_PARAM> + struct ApplyOp + { + // Apply flag operator + static __device__ __forceinline__ T FlagT(FlagOp flag_op, const T &a, const T &b, int idx) + { + return flag_op(b, a, idx); + } + }; + + /// Specialization for when FlagOp does not have a third index param + template + struct ApplyOp + { + // Apply flag operator + static __device__ __forceinline__ T FlagT(FlagOp flag_op, const T &a, const T &b, int /*idx*/) + { + return flag_op(b, a); + } + }; + + /// Templated unrolling of item comparison (inductive case) + template + struct Iterate + { + // Head flags + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + static __device__ __forceinline__ void FlagHeads( + int linear_tid, + FlagT (&flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&preds)[ITEMS_PER_THREAD], ///< [out] Calling thread's predecessor items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + preds[ITERATION] = input[ITERATION - 1]; + + flags[ITERATION] = ApplyOp::FlagT( + flag_op, + preds[ITERATION], + input[ITERATION], + (linear_tid * 
ITEMS_PER_THREAD) + ITERATION); + + Iterate::FlagHeads(linear_tid, flags, input, preds, flag_op); + } + + // Tail flags + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + static __device__ __forceinline__ void FlagTails( + int linear_tid, + FlagT (&flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + flags[ITERATION] = ApplyOp::FlagT( + flag_op, + input[ITERATION], + input[ITERATION + 1], + (linear_tid * ITEMS_PER_THREAD) + ITERATION + 1); + + Iterate::FlagTails(linear_tid, flags, input, flag_op); + } + + }; + + /// Templated unrolling of item comparison (termination case) + template + struct Iterate + { + // Head flags + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + static __device__ __forceinline__ void FlagHeads( + int /*linear_tid*/, + FlagT (&/*flags*/)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&/*input*/)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&/*preds*/)[ITEMS_PER_THREAD], ///< [out] Calling thread's predecessor items + FlagOp /*flag_op*/) ///< [in] Binary boolean flag predicate + {} + + // Tail flags + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + static __device__ __forceinline__ void FlagTails( + int /*linear_tid*/, + FlagT (&/*flags*/)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&/*input*/)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp /*flag_op*/) ///< [in] Binary boolean flag predicate + {} + }; + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + + +public: + + /// \smemstorage{BlockDiscontinuity} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using a private static allocation of shared memory as temporary storage. + */ + __device__ __forceinline__ BlockAdjacentDifference() + : + temp_storage(PrivateStorage()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. 
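+ *
+ * \par
+ * A minimal usage sketch (illustrative; the block and item sizes are placeholders),
+ * following the same pattern as the other block-wide collectives in this directory:
+ * \code
+ * typedef cub::BlockAdjacentDifference<int, 128> BlockAdjacentDifferenceT;
+ *
+ * __shared__ typename BlockAdjacentDifferenceT::TempStorage temp_storage;
+ *
+ * int thread_data[4];
+ * int head_flags[4];
+ * ...
+ * BlockAdjacentDifferenceT(temp_storage).FlagHeads(head_flags, thread_data, cub::Inequality());
+ * \endcode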
+ */ + __device__ __forceinline__ BlockAdjacentDifference( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + //@} end member group + /******************************************************************//** + * \name Head flag operations + *********************************************************************/ + //@{ + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeads( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&preds)[ITEMS_PER_THREAD], ///< [out] Calling thread's predecessor items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + // Share last item + temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + if (linear_tid == 0) + { + // Set flag for first thread-item (preds[0] is undefined) + head_flags[0] = 1; + } + else + { + preds[0] = temp_storage.last_items[linear_tid - 1]; + head_flags[0] = ApplyOp::FlagT(flag_op, preds[0], input[0], linear_tid * ITEMS_PER_THREAD); + } + + // Set head_flags for remaining items + Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op); + } + + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeads( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&preds)[ITEMS_PER_THREAD], ///< [out] Calling thread's predecessor items + FlagOp flag_op, ///< [in] Binary boolean flag predicate + T tile_predecessor_item) ///< [in] [thread0 only] Item with which to compare the first tile item (input0 from thread0). + { + // Share last item + temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + // Set flag for first thread-item + preds[0] = (linear_tid == 0) ? + tile_predecessor_item : // First thread + temp_storage.last_items[linear_tid - 1]; + + head_flags[0] = ApplyOp::FlagT(flag_op, preds[0], input[0], linear_tid * ITEMS_PER_THREAD); + + // Set head_flags for remaining items + Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op); + } + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeads( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + T preds[ITEMS_PER_THREAD]; + FlagHeads(head_flags, input, preds, flag_op); + } + + + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeads( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op, ///< [in] Binary boolean flag predicate + T tile_predecessor_item) ///< [in] [thread0 only] Item with which to compare the first tile item (input0 from thread0). 
+ { + T preds[ITEMS_PER_THREAD]; + FlagHeads(head_flags, input, preds, flag_op, tile_predecessor_item); + } + + + + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagTails( + FlagT (&tail_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity tail_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + // Share first item + temp_storage.first_items[linear_tid] = input[0]; + + CTA_SYNC(); + + // Set flag for last thread-item + tail_flags[ITEMS_PER_THREAD - 1] = (linear_tid == BLOCK_THREADS - 1) ? + 1 : // Last thread + ApplyOp::FlagT( + flag_op, + input[ITEMS_PER_THREAD - 1], + temp_storage.first_items[linear_tid + 1], + (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD); + + // Set tail_flags for remaining items + Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op); + } + + + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagTails( + FlagT (&tail_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity tail_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op, ///< [in] Binary boolean flag predicate + T tile_successor_item) ///< [in] [threadBLOCK_THREADS-1 only] Item with which to compare the last tile item (inputITEMS_PER_THREAD-1 from threadBLOCK_THREADS-1). + { + // Share first item + temp_storage.first_items[linear_tid] = input[0]; + + CTA_SYNC(); + + // Set flag for last thread-item + T successor_item = (linear_tid == BLOCK_THREADS - 1) ? + tile_successor_item : // Last thread + temp_storage.first_items[linear_tid + 1]; + + tail_flags[ITEMS_PER_THREAD - 1] = ApplyOp::FlagT( + flag_op, + input[ITEMS_PER_THREAD - 1], + successor_item, + (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD); + + // Set tail_flags for remaining items + Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op); + } + + + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeadsAndTails( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + FlagT (&tail_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity tail_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + // Share first and last items + temp_storage.first_items[linear_tid] = input[0]; + temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + T preds[ITEMS_PER_THREAD]; + + // Set flag for first thread-item + preds[0] = temp_storage.last_items[linear_tid - 1]; + if (linear_tid == 0) + { + head_flags[0] = 1; + } + else + { + head_flags[0] = ApplyOp::FlagT( + flag_op, + preds[0], + input[0], + linear_tid * ITEMS_PER_THREAD); + } + + + // Set flag for last thread-item + tail_flags[ITEMS_PER_THREAD - 1] = (linear_tid == BLOCK_THREADS - 1) ? 
+ 1 : // Last thread + ApplyOp::FlagT( + flag_op, + input[ITEMS_PER_THREAD - 1], + temp_storage.first_items[linear_tid + 1], + (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD); + + // Set head_flags for remaining items + Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op); + + // Set tail_flags for remaining items + Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op); + } + + + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeadsAndTails( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + FlagT (&tail_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity tail_flags + T tile_successor_item, ///< [in] [threadBLOCK_THREADS-1 only] Item with which to compare the last tile item (inputITEMS_PER_THREAD-1 from threadBLOCK_THREADS-1). + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + // Share first and last items + temp_storage.first_items[linear_tid] = input[0]; + temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + T preds[ITEMS_PER_THREAD]; + + // Set flag for first thread-item + if (linear_tid == 0) + { + head_flags[0] = 1; + } + else + { + preds[0] = temp_storage.last_items[linear_tid - 1]; + head_flags[0] = ApplyOp::FlagT( + flag_op, + preds[0], + input[0], + linear_tid * ITEMS_PER_THREAD); + } + + // Set flag for last thread-item + T successor_item = (linear_tid == BLOCK_THREADS - 1) ? + tile_successor_item : // Last thread + temp_storage.first_items[linear_tid + 1]; + + tail_flags[ITEMS_PER_THREAD - 1] = ApplyOp::FlagT( + flag_op, + input[ITEMS_PER_THREAD - 1], + successor_item, + (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD); + + // Set head_flags for remaining items + Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op); + + // Set tail_flags for remaining items + Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op); + } + + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeadsAndTails( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T tile_predecessor_item, ///< [in] [thread0 only] Item with which to compare the first tile item (input0 from thread0). + FlagT (&tail_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity tail_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + // Share first and last items + temp_storage.first_items[linear_tid] = input[0]; + temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + T preds[ITEMS_PER_THREAD]; + + // Set flag for first thread-item + preds[0] = (linear_tid == 0) ? + tile_predecessor_item : // First thread + temp_storage.last_items[linear_tid - 1]; + + head_flags[0] = ApplyOp::FlagT( + flag_op, + preds[0], + input[0], + linear_tid * ITEMS_PER_THREAD); + + // Set flag for last thread-item + tail_flags[ITEMS_PER_THREAD - 1] = (linear_tid == BLOCK_THREADS - 1) ? 
+ 1 : // Last thread + ApplyOp::FlagT( + flag_op, + input[ITEMS_PER_THREAD - 1], + temp_storage.first_items[linear_tid + 1], + (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD); + + // Set head_flags for remaining items + Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op); + + // Set tail_flags for remaining items + Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op); + } + + + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeadsAndTails( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T tile_predecessor_item, ///< [in] [thread0 only] Item with which to compare the first tile item (input0 from thread0). + FlagT (&tail_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity tail_flags + T tile_successor_item, ///< [in] [threadBLOCK_THREADS-1 only] Item with which to compare the last tile item (inputITEMS_PER_THREAD-1 from threadBLOCK_THREADS-1). + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + // Share first and last items + temp_storage.first_items[linear_tid] = input[0]; + temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + T preds[ITEMS_PER_THREAD]; + + // Set flag for first thread-item + preds[0] = (linear_tid == 0) ? + tile_predecessor_item : // First thread + temp_storage.last_items[linear_tid - 1]; + + head_flags[0] = ApplyOp::FlagT( + flag_op, + preds[0], + input[0], + linear_tid * ITEMS_PER_THREAD); + + // Set flag for last thread-item + T successor_item = (linear_tid == BLOCK_THREADS - 1) ? + tile_successor_item : // Last thread + temp_storage.first_items[linear_tid + 1]; + + tail_flags[ITEMS_PER_THREAD - 1] = ApplyOp::FlagT( + flag_op, + input[ITEMS_PER_THREAD - 1], + successor_item, + (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD); + + // Set head_flags for remaining items + Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op); + + // Set tail_flags for remaining items + Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op); + } + + + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/block/block_discontinuity.cuh b/fastertransformer/cuda/cub/block/block_discontinuity.cuh new file mode 100644 index 000000000..503e3e0b0 --- /dev/null +++ b/fastertransformer/cuda/cub/block/block_discontinuity.cuh @@ -0,0 +1,1148 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * The cub::BlockDiscontinuity class provides [collective](index.html#sec0) methods for flagging discontinuities within an ordered set of items partitioned across a CUDA thread block. + */ + +#pragma once + +#include "../util_type.cuh" +#include "../util_ptx.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief The BlockDiscontinuity class provides [collective](index.html#sec0) methods for flagging discontinuities within an ordered set of items partitioned across a CUDA thread block. ![](discont_logo.png) + * \ingroup BlockModule + * + * \tparam T The data type to be flagged. + * \tparam BLOCK_DIM_X The thread block length in threads along the X dimension + * \tparam BLOCK_DIM_Y [optional] The thread block length in threads along the Y dimension (default: 1) + * \tparam BLOCK_DIM_Z [optional] The thread block length in threads along the Z dimension (default: 1) + * \tparam PTX_ARCH [optional] \ptxversion + * + * \par Overview + * - A set of "head flags" (or "tail flags") is often used to indicate corresponding items + * that differ from their predecessors (or successors). For example, head flags are convenient + * for demarcating disjoint data segments as part of a segmented scan or reduction. + * - \blocked + * + * \par Performance Considerations + * - \granularity + * + * \par A Simple Example + * \blockcollective{BlockDiscontinuity} + * \par + * The code snippet below illustrates the head flagging of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockDiscontinuity for a 1D block of 128 threads on type int + * typedef cub::BlockDiscontinuity BlockDiscontinuity; + * + * // Allocate shared memory for BlockDiscontinuity + * __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively compute head flags for discontinuities in the segment + * int head_flags[4]; + * BlockDiscontinuity(temp_storage).FlagHeads(head_flags, thread_data, cub::Inequality()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * { [0,0,1,1], [1,1,1,1], [2,3,3,3], [3,4,4,4], ... }. + * The corresponding output \p head_flags in those threads will be + * { [1,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }. 
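As a concrete reference for the snippet above, a minimal, self-contained version with the include path and template arguments written out (this is a sketch assuming the cub::BlockDiscontinuity<int, 128> specialization from the example, a hypothetical d_data input pointer, and cub::LoadDirectBlocked used only to populate the blocked segment):

    #include <cub/cub.cuh>   // or equivalently <cub/block/block_discontinuity.cuh>

    __global__ void ExampleKernel(int *d_data)
    {
        // Specialize BlockDiscontinuity for a 1D block of 128 threads on type int
        typedef cub::BlockDiscontinuity<int, 128> BlockDiscontinuity;

        // Allocate shared memory for BlockDiscontinuity
        __shared__ typename BlockDiscontinuity::TempStorage temp_storage;

        // Obtain a segment of 4 consecutive items that are blocked across threads
        // (assumption: loaded directly from d_data for illustration)
        int thread_data[4];
        cub::LoadDirectBlocked(threadIdx.x, d_data, thread_data);

        // Collectively compute head flags for discontinuities in the segment
        int head_flags[4];
        BlockDiscontinuity(temp_storage).FlagHeads(head_flags, thread_data, cub::Inequality());
    }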
+ * + * \par Performance Considerations + * - Incurs zero bank conflicts for most types + * + */ +template < + typename T, + int BLOCK_DIM_X, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = CUB_PTX_ARCH> +class BlockDiscontinuity +{ +private: + + /****************************************************************************** + * Constants and type definitions + ******************************************************************************/ + + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + }; + + + /// Shared memory storage layout type (last element from each thread's input) + struct _TempStorage + { + T first_items[BLOCK_THREADS]; + T last_items[BLOCK_THREADS]; + }; + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /// Internal storage allocator + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + + /// Specialization for when FlagOp has third index param + template ::HAS_PARAM> + struct ApplyOp + { + // Apply flag operator + static __device__ __forceinline__ bool FlagT(FlagOp flag_op, const T &a, const T &b, int idx) + { + return flag_op(a, b, idx); + } + }; + + /// Specialization for when FlagOp does not have a third index param + template + struct ApplyOp + { + // Apply flag operator + static __device__ __forceinline__ bool FlagT(FlagOp flag_op, const T &a, const T &b, int /*idx*/) + { + return flag_op(a, b); + } + }; + + /// Templated unrolling of item comparison (inductive case) + template + struct Iterate + { + // Head flags + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + static __device__ __forceinline__ void FlagHeads( + int linear_tid, + FlagT (&flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&preds)[ITEMS_PER_THREAD], ///< [out] Calling thread's predecessor items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + preds[ITERATION] = input[ITERATION - 1]; + + flags[ITERATION] = ApplyOp::FlagT( + flag_op, + preds[ITERATION], + input[ITERATION], + (linear_tid * ITEMS_PER_THREAD) + ITERATION); + + Iterate::FlagHeads(linear_tid, flags, input, preds, flag_op); + } + + // Tail flags + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + static __device__ __forceinline__ void FlagTails( + int linear_tid, + FlagT (&flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + flags[ITERATION] = ApplyOp::FlagT( + flag_op, + input[ITERATION], + input[ITERATION + 1], + (linear_tid * ITEMS_PER_THREAD) + ITERATION + 1); + + Iterate::FlagTails(linear_tid, flags, input, flag_op); + } + + }; + + /// Templated unrolling of item comparison (termination case) + template + struct Iterate + { + // Head flags + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + static __device__ __forceinline__ void FlagHeads( + int /*linear_tid*/, + FlagT (&/*flags*/)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&/*input*/)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&/*preds*/)[ITEMS_PER_THREAD], ///< [out] 
Calling thread's predecessor items + FlagOp /*flag_op*/) ///< [in] Binary boolean flag predicate + {} + + // Tail flags + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + static __device__ __forceinline__ void FlagTails( + int /*linear_tid*/, + FlagT (&/*flags*/)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&/*input*/)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp /*flag_op*/) ///< [in] Binary boolean flag predicate + {} + }; + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + + +public: + + /// \smemstorage{BlockDiscontinuity} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using a private static allocation of shared memory as temporary storage. + */ + __device__ __forceinline__ BlockDiscontinuity() + : + temp_storage(PrivateStorage()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. + */ + __device__ __forceinline__ BlockDiscontinuity( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + //@} end member group + /******************************************************************//** + * \name Head flag operations + *********************************************************************/ + //@{ + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeads( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&preds)[ITEMS_PER_THREAD], ///< [out] Calling thread's predecessor items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + // Share last item + temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + if (linear_tid == 0) + { + // Set flag for first thread-item (preds[0] is undefined) + head_flags[0] = 1; + } + else + { + preds[0] = temp_storage.last_items[linear_tid - 1]; + head_flags[0] = ApplyOp::FlagT(flag_op, preds[0], input[0], linear_tid * ITEMS_PER_THREAD); + } + + // Set head_flags for remaining items + Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op); + } + + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeads( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&preds)[ITEMS_PER_THREAD], ///< [out] Calling thread's predecessor items + FlagOp flag_op, ///< [in] Binary boolean flag predicate + T tile_predecessor_item) ///< [in] [thread0 only] Item with which to compare the first tile item (input0 from thread0). 
+ { + // Share last item + temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + // Set flag for first thread-item + preds[0] = (linear_tid == 0) ? + tile_predecessor_item : // First thread + temp_storage.last_items[linear_tid - 1]; + + head_flags[0] = ApplyOp::FlagT(flag_op, preds[0], input[0], linear_tid * ITEMS_PER_THREAD); + + // Set head_flags for remaining items + Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op); + } + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + + /** + * \brief Sets head flags indicating discontinuities between items partitioned across the thread block, for which the first item has no reference and is always flagged. + * + * \par + * - The flag head_flagsi is set for item + * inputi when + * flag_op(previous-item, inputi) + * returns \p true (where previous-item is either the preceding item + * in the same thread or the last item in the previous thread). + * - For thread0, item input0 is always flagged. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the head-flagging of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockDiscontinuity for a 1D block of 128 threads on type int + * typedef cub::BlockDiscontinuity BlockDiscontinuity; + * + * // Allocate shared memory for BlockDiscontinuity + * __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively compute head flags for discontinuities in the segment + * int head_flags[4]; + * BlockDiscontinuity(temp_storage).FlagHeads(head_flags, thread_data, cub::Inequality()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * { [0,0,1,1], [1,1,1,1], [2,3,3,3], [3,4,4,4], ... }. + * The corresponding output \p head_flags in those threads will be + * { [1,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam FlagT [inferred] The flag type (must be an integer type) + * \tparam FlagOp [inferred] Binary predicate functor type having member T operator()(const T &a, const T &b) or member T operator()(const T &a, const T &b, unsigned int b_index), and returning \p true if a discontinuity exists between \p a and \p b, otherwise \p false. \p b_index is the rank of b in the aggregate tile of data. + */ + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeads( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + T preds[ITEMS_PER_THREAD]; + FlagHeads(head_flags, input, preds, flag_op); + } + + + /** + * \brief Sets head flags indicating discontinuities between items partitioned across the thread block. + * + * \par + * - The flag head_flagsi is set for item + * inputi when + * flag_op(previous-item, inputi) + * returns \p true (where previous-item is either the preceding item + * in the same thread or the last item in the previous thread). 
+ * - For thread0, item input0 is compared + * against \p tile_predecessor_item. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the head-flagging of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockDiscontinuity for a 1D block of 128 threads on type int + * typedef cub::BlockDiscontinuity BlockDiscontinuity; + * + * // Allocate shared memory for BlockDiscontinuity + * __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Have thread0 obtain the predecessor item for the entire tile + * int tile_predecessor_item; + * if (threadIdx.x == 0) tile_predecessor_item == ... + * + * // Collectively compute head flags for discontinuities in the segment + * int head_flags[4]; + * BlockDiscontinuity(temp_storage).FlagHeads( + * head_flags, thread_data, cub::Inequality(), tile_predecessor_item); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * { [0,0,1,1], [1,1,1,1], [2,3,3,3], [3,4,4,4], ... }, + * and that \p tile_predecessor_item is \p 0. The corresponding output \p head_flags in those threads will be + * { [0,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam FlagT [inferred] The flag type (must be an integer type) + * \tparam FlagOp [inferred] Binary predicate functor type having member T operator()(const T &a, const T &b) or member T operator()(const T &a, const T &b, unsigned int b_index), and returning \p true if a discontinuity exists between \p a and \p b, otherwise \p false. \p b_index is the rank of b in the aggregate tile of data. + */ + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeads( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op, ///< [in] Binary boolean flag predicate + T tile_predecessor_item) ///< [in] [thread0 only] Item with which to compare the first tile item (input0 from thread0). + { + T preds[ITEMS_PER_THREAD]; + FlagHeads(head_flags, input, preds, flag_op, tile_predecessor_item); + } + + + + //@} end member group + /******************************************************************//** + * \name Tail flag operations + *********************************************************************/ + //@{ + + + /** + * \brief Sets tail flags indicating discontinuities between items partitioned across the thread block, for which the last item has no reference and is always flagged. + * + * \par + * - The flag tail_flagsi is set for item + * inputi when + * flag_op(inputi, next-item) + * returns \p true (where next-item is either the next item + * in the same thread or the first item in the next thread). + * - For threadBLOCK_THREADS-1, item + * inputITEMS_PER_THREAD-1 is always flagged. 
+ * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the tail-flagging of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockDiscontinuity for a 1D block of 128 threads on type int + * typedef cub::BlockDiscontinuity BlockDiscontinuity; + * + * // Allocate shared memory for BlockDiscontinuity + * __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively compute tail flags for discontinuities in the segment + * int tail_flags[4]; + * BlockDiscontinuity(temp_storage).FlagTails(tail_flags, thread_data, cub::Inequality()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * { [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] }. + * The corresponding output \p tail_flags in those threads will be + * { [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,1] }. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam FlagT [inferred] The flag type (must be an integer type) + * \tparam FlagOp [inferred] Binary predicate functor type having member T operator()(const T &a, const T &b) or member T operator()(const T &a, const T &b, unsigned int b_index), and returning \p true if a discontinuity exists between \p a and \p b, otherwise \p false. \p b_index is the rank of b in the aggregate tile of data. + */ + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagTails( + FlagT (&tail_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity tail_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + // Share first item + temp_storage.first_items[linear_tid] = input[0]; + + CTA_SYNC(); + + // Set flag for last thread-item + tail_flags[ITEMS_PER_THREAD - 1] = (linear_tid == BLOCK_THREADS - 1) ? + 1 : // Last thread + ApplyOp::FlagT( + flag_op, + input[ITEMS_PER_THREAD - 1], + temp_storage.first_items[linear_tid + 1], + (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD); + + // Set tail_flags for remaining items + Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op); + } + + + /** + * \brief Sets tail flags indicating discontinuities between items partitioned across the thread block. + * + * \par + * - The flag tail_flagsi is set for item + * inputi when + * flag_op(inputi, next-item) + * returns \p true (where next-item is either the next item + * in the same thread or the first item in the next thread). + * - For threadBLOCK_THREADS-1, item + * inputITEMS_PER_THREAD-1 is compared + * against \p tile_successor_item. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the tail-flagging of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) 
+ * { + * // Specialize BlockDiscontinuity for a 1D block of 128 threads on type int + * typedef cub::BlockDiscontinuity BlockDiscontinuity; + * + * // Allocate shared memory for BlockDiscontinuity + * __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Have thread127 obtain the successor item for the entire tile + * int tile_successor_item; + * if (threadIdx.x == 127) tile_successor_item == ... + * + * // Collectively compute tail flags for discontinuities in the segment + * int tail_flags[4]; + * BlockDiscontinuity(temp_storage).FlagTails( + * tail_flags, thread_data, cub::Inequality(), tile_successor_item); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * { [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] } + * and that \p tile_successor_item is \p 125. The corresponding output \p tail_flags in those threads will be + * { [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,0] }. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam FlagT [inferred] The flag type (must be an integer type) + * \tparam FlagOp [inferred] Binary predicate functor type having member T operator()(const T &a, const T &b) or member T operator()(const T &a, const T &b, unsigned int b_index), and returning \p true if a discontinuity exists between \p a and \p b, otherwise \p false. \p b_index is the rank of b in the aggregate tile of data. + */ + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagTails( + FlagT (&tail_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity tail_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op, ///< [in] Binary boolean flag predicate + T tile_successor_item) ///< [in] [threadBLOCK_THREADS-1 only] Item with which to compare the last tile item (inputITEMS_PER_THREAD-1 from threadBLOCK_THREADS-1). + { + // Share first item + temp_storage.first_items[linear_tid] = input[0]; + + CTA_SYNC(); + + // Set flag for last thread-item + T successor_item = (linear_tid == BLOCK_THREADS - 1) ? + tile_successor_item : // Last thread + temp_storage.first_items[linear_tid + 1]; + + tail_flags[ITEMS_PER_THREAD - 1] = ApplyOp::FlagT( + flag_op, + input[ITEMS_PER_THREAD - 1], + successor_item, + (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD); + + // Set tail_flags for remaining items + Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op); + } + + + //@} end member group + /******************************************************************//** + * \name Head & tail flag operations + *********************************************************************/ + //@{ + + + /** + * \brief Sets both head and tail flags indicating discontinuities between items partitioned across the thread block. + * + * \par + * - The flag head_flagsi is set for item + * inputi when + * flag_op(previous-item, inputi) + * returns \p true (where previous-item is either the preceding item + * in the same thread or the last item in the previous thread). + * - For thread0, item input0 is always flagged. + * - The flag tail_flagsi is set for item + * inputi when + * flag_op(inputi, next-item) + * returns \p true (where next-item is either the next item + * in the same thread or the first item in the next thread). 
+ * - For threadBLOCK_THREADS-1, item + * inputITEMS_PER_THREAD-1 is always flagged. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the head- and tail-flagging of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockDiscontinuity for a 1D block of 128 threads on type int + * typedef cub::BlockDiscontinuity BlockDiscontinuity; + * + * // Allocate shared memory for BlockDiscontinuity + * __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively compute head and flags for discontinuities in the segment + * int head_flags[4]; + * int tail_flags[4]; + * BlockDiscontinuity(temp_storage).FlagTails( + * head_flags, tail_flags, thread_data, cub::Inequality()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * { [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] } + * and that the tile_successor_item is \p 125. The corresponding output \p head_flags + * in those threads will be { [1,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }. + * and the corresponding output \p tail_flags in those threads will be + * { [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,1] }. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam FlagT [inferred] The flag type (must be an integer type) + * \tparam FlagOp [inferred] Binary predicate functor type having member T operator()(const T &a, const T &b) or member T operator()(const T &a, const T &b, unsigned int b_index), and returning \p true if a discontinuity exists between \p a and \p b, otherwise \p false. \p b_index is the rank of b in the aggregate tile of data. + */ + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeadsAndTails( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + FlagT (&tail_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity tail_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + // Share first and last items + temp_storage.first_items[linear_tid] = input[0]; + temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + T preds[ITEMS_PER_THREAD]; + + // Set flag for first thread-item + preds[0] = temp_storage.last_items[linear_tid - 1]; + if (linear_tid == 0) + { + head_flags[0] = 1; + } + else + { + head_flags[0] = ApplyOp::FlagT( + flag_op, + preds[0], + input[0], + linear_tid * ITEMS_PER_THREAD); + } + + + // Set flag for last thread-item + tail_flags[ITEMS_PER_THREAD - 1] = (linear_tid == BLOCK_THREADS - 1) ? 
+ 1 : // Last thread + ApplyOp::FlagT( + flag_op, + input[ITEMS_PER_THREAD - 1], + temp_storage.first_items[linear_tid + 1], + (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD); + + // Set head_flags for remaining items + Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op); + + // Set tail_flags for remaining items + Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op); + } + + + /** + * \brief Sets both head and tail flags indicating discontinuities between items partitioned across the thread block. + * + * \par + * - The flag head_flagsi is set for item + * inputi when + * flag_op(previous-item, inputi) + * returns \p true (where previous-item is either the preceding item + * in the same thread or the last item in the previous thread). + * - For thread0, item input0 is always flagged. + * - The flag tail_flagsi is set for item + * inputi when + * flag_op(inputi, next-item) + * returns \p true (where next-item is either the next item + * in the same thread or the first item in the next thread). + * - For threadBLOCK_THREADS-1, item + * inputITEMS_PER_THREAD-1 is compared + * against \p tile_predecessor_item. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the head- and tail-flagging of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockDiscontinuity for a 1D block of 128 threads on type int + * typedef cub::BlockDiscontinuity BlockDiscontinuity; + * + * // Allocate shared memory for BlockDiscontinuity + * __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Have thread127 obtain the successor item for the entire tile + * int tile_successor_item; + * if (threadIdx.x == 127) tile_successor_item == ... + * + * // Collectively compute head and flags for discontinuities in the segment + * int head_flags[4]; + * int tail_flags[4]; + * BlockDiscontinuity(temp_storage).FlagTails( + * head_flags, tail_flags, tile_successor_item, thread_data, cub::Inequality()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * { [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] } + * and that the tile_successor_item is \p 125. The corresponding output \p head_flags + * in those threads will be { [1,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }. + * and the corresponding output \p tail_flags in those threads will be + * { [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,0] }. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam FlagT [inferred] The flag type (must be an integer type) + * \tparam FlagOp [inferred] Binary predicate functor type having member T operator()(const T &a, const T &b) or member T operator()(const T &a, const T &b, unsigned int b_index), and returning \p true if a discontinuity exists between \p a and \p b, otherwise \p false. \p b_index is the rank of b in the aggregate tile of data. 
+ */ + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeadsAndTails( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + FlagT (&tail_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity tail_flags + T tile_successor_item, ///< [in] [threadBLOCK_THREADS-1 only] Item with which to compare the last tile item (inputITEMS_PER_THREAD-1 from threadBLOCK_THREADS-1). + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + // Share first and last items + temp_storage.first_items[linear_tid] = input[0]; + temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + T preds[ITEMS_PER_THREAD]; + + // Set flag for first thread-item + if (linear_tid == 0) + { + head_flags[0] = 1; + } + else + { + preds[0] = temp_storage.last_items[linear_tid - 1]; + head_flags[0] = ApplyOp::FlagT( + flag_op, + preds[0], + input[0], + linear_tid * ITEMS_PER_THREAD); + } + + // Set flag for last thread-item + T successor_item = (linear_tid == BLOCK_THREADS - 1) ? + tile_successor_item : // Last thread + temp_storage.first_items[linear_tid + 1]; + + tail_flags[ITEMS_PER_THREAD - 1] = ApplyOp::FlagT( + flag_op, + input[ITEMS_PER_THREAD - 1], + successor_item, + (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD); + + // Set head_flags for remaining items + Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op); + + // Set tail_flags for remaining items + Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op); + } + + + /** + * \brief Sets both head and tail flags indicating discontinuities between items partitioned across the thread block. + * + * \par + * - The flag head_flagsi is set for item + * inputi when + * flag_op(previous-item, inputi) + * returns \p true (where previous-item is either the preceding item + * in the same thread or the last item in the previous thread). + * - For thread0, item input0 is compared + * against \p tile_predecessor_item. + * - The flag tail_flagsi is set for item + * inputi when + * flag_op(inputi, next-item) + * returns \p true (where next-item is either the next item + * in the same thread or the first item in the next thread). + * - For threadBLOCK_THREADS-1, item + * inputITEMS_PER_THREAD-1 is always flagged. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the head- and tail-flagging of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockDiscontinuity for a 1D block of 128 threads on type int + * typedef cub::BlockDiscontinuity BlockDiscontinuity; + * + * // Allocate shared memory for BlockDiscontinuity + * __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Have thread0 obtain the predecessor item for the entire tile + * int tile_predecessor_item; + * if (threadIdx.x == 0) tile_predecessor_item == ... + * + * // Have thread127 obtain the successor item for the entire tile + * int tile_successor_item; + * if (threadIdx.x == 127) tile_successor_item == ... 
+ * + * // Collectively compute head and flags for discontinuities in the segment + * int head_flags[4]; + * int tail_flags[4]; + * BlockDiscontinuity(temp_storage).FlagTails( + * head_flags, tile_predecessor_item, tail_flags, tile_successor_item, + * thread_data, cub::Inequality()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * { [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] }, + * that the \p tile_predecessor_item is \p 0, and that the + * \p tile_successor_item is \p 125. The corresponding output \p head_flags + * in those threads will be { [0,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }. + * and the corresponding output \p tail_flags in those threads will be + * { [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,1] }. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam FlagT [inferred] The flag type (must be an integer type) + * \tparam FlagOp [inferred] Binary predicate functor type having member T operator()(const T &a, const T &b) or member T operator()(const T &a, const T &b, unsigned int b_index), and returning \p true if a discontinuity exists between \p a and \p b, otherwise \p false. \p b_index is the rank of b in the aggregate tile of data. + */ + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeadsAndTails( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T tile_predecessor_item, ///< [in] [thread0 only] Item with which to compare the first tile item (input0 from thread0). + FlagT (&tail_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity tail_flags + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + // Share first and last items + temp_storage.first_items[linear_tid] = input[0]; + temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + T preds[ITEMS_PER_THREAD]; + + // Set flag for first thread-item + preds[0] = (linear_tid == 0) ? + tile_predecessor_item : // First thread + temp_storage.last_items[linear_tid - 1]; + + head_flags[0] = ApplyOp::FlagT( + flag_op, + preds[0], + input[0], + linear_tid * ITEMS_PER_THREAD); + + // Set flag for last thread-item + tail_flags[ITEMS_PER_THREAD - 1] = (linear_tid == BLOCK_THREADS - 1) ? + 1 : // Last thread + ApplyOp::FlagT( + flag_op, + input[ITEMS_PER_THREAD - 1], + temp_storage.first_items[linear_tid + 1], + (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD); + + // Set head_flags for remaining items + Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op); + + // Set tail_flags for remaining items + Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op); + } + + + /** + * \brief Sets both head and tail flags indicating discontinuities between items partitioned across the thread block. + * + * \par + * - The flag head_flagsi is set for item + * inputi when + * flag_op(previous-item, inputi) + * returns \p true (where previous-item is either the preceding item + * in the same thread or the last item in the previous thread). + * - For thread0, item input0 is compared + * against \p tile_predecessor_item. 
+ * - The flag tail_flagsi is set for item + * inputi when + * flag_op(inputi, next-item) + * returns \p true (where next-item is either the next item + * in the same thread or the first item in the next thread). + * - For threadBLOCK_THREADS-1, item + * inputITEMS_PER_THREAD-1 is compared + * against \p tile_successor_item. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the head- and tail-flagging of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockDiscontinuity for a 1D block of 128 threads on type int + * typedef cub::BlockDiscontinuity BlockDiscontinuity; + * + * // Allocate shared memory for BlockDiscontinuity + * __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Have thread0 obtain the predecessor item for the entire tile + * int tile_predecessor_item; + * if (threadIdx.x == 0) tile_predecessor_item == ... + * + * // Have thread127 obtain the successor item for the entire tile + * int tile_successor_item; + * if (threadIdx.x == 127) tile_successor_item == ... + * + * // Collectively compute head and flags for discontinuities in the segment + * int head_flags[4]; + * int tail_flags[4]; + * BlockDiscontinuity(temp_storage).FlagTails( + * head_flags, tile_predecessor_item, tail_flags, tile_successor_item, + * thread_data, cub::Inequality()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * { [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] }, + * that the \p tile_predecessor_item is \p 0, and that the + * \p tile_successor_item is \p 125. The corresponding output \p head_flags + * in those threads will be { [0,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }. + * and the corresponding output \p tail_flags in those threads will be + * { [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,0] }. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam FlagT [inferred] The flag type (must be an integer type) + * \tparam FlagOp [inferred] Binary predicate functor type having member T operator()(const T &a, const T &b) or member T operator()(const T &a, const T &b, unsigned int b_index), and returning \p true if a discontinuity exists between \p a and \p b, otherwise \p false. \p b_index is the rank of b in the aggregate tile of data. + */ + template < + int ITEMS_PER_THREAD, + typename FlagT, + typename FlagOp> + __device__ __forceinline__ void FlagHeadsAndTails( + FlagT (&head_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags + T tile_predecessor_item, ///< [in] [thread0 only] Item with which to compare the first tile item (input0 from thread0). + FlagT (&tail_flags)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity tail_flags + T tile_successor_item, ///< [in] [threadBLOCK_THREADS-1 only] Item with which to compare the last tile item (inputITEMS_PER_THREAD-1 from threadBLOCK_THREADS-1). 
+ T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + FlagOp flag_op) ///< [in] Binary boolean flag predicate + { + // Share first and last items + temp_storage.first_items[linear_tid] = input[0]; + temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + T preds[ITEMS_PER_THREAD]; + + // Set flag for first thread-item + preds[0] = (linear_tid == 0) ? + tile_predecessor_item : // First thread + temp_storage.last_items[linear_tid - 1]; + + head_flags[0] = ApplyOp::FlagT( + flag_op, + preds[0], + input[0], + linear_tid * ITEMS_PER_THREAD); + + // Set flag for last thread-item + T successor_item = (linear_tid == BLOCK_THREADS - 1) ? + tile_successor_item : // Last thread + temp_storage.first_items[linear_tid + 1]; + + tail_flags[ITEMS_PER_THREAD - 1] = ApplyOp::FlagT( + flag_op, + input[ITEMS_PER_THREAD - 1], + successor_item, + (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD); + + // Set head_flags for remaining items + Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op); + + // Set tail_flags for remaining items + Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op); + } + + + + + //@} end member group + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/block/block_exchange.cuh b/fastertransformer/cuda/cub/block/block_exchange.cuh new file mode 100644 index 000000000..3ae993439 --- /dev/null +++ b/fastertransformer/cuda/cub/block/block_exchange.cuh @@ -0,0 +1,1248 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * The cub::BlockExchange class provides [collective](index.html#sec0) methods for rearranging data partitioned across a CUDA thread block. 
+ */ + +#pragma once + +#include "../util_ptx.cuh" +#include "../util_arch.cuh" +#include "../util_macro.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief The BlockExchange class provides [collective](index.html#sec0) methods for rearranging data partitioned across a CUDA thread block. ![](transpose_logo.png) + * \ingroup BlockModule + * + * \tparam T The data type to be exchanged. + * \tparam BLOCK_DIM_X The thread block length in threads along the X dimension + * \tparam ITEMS_PER_THREAD The number of items partitioned onto each thread. + * \tparam WARP_TIME_SLICING [optional] When \p true, only use enough shared memory for a single warp's worth of tile data, time-slicing the block-wide exchange over multiple synchronized rounds. Yields a smaller memory footprint at the expense of decreased parallelism. (Default: false) + * \tparam BLOCK_DIM_Y [optional] The thread block length in threads along the Y dimension (default: 1) + * \tparam BLOCK_DIM_Z [optional] The thread block length in threads along the Z dimension (default: 1) + * \tparam PTX_ARCH [optional] \ptxversion + * + * \par Overview + * - It is commonplace for blocks of threads to rearrange data items between + * threads. For example, the device-accessible memory subsystem prefers access patterns + * where data items are "striped" across threads (where consecutive threads access consecutive items), + * yet most block-wide operations prefer a "blocked" partitioning of items across threads + * (where consecutive items belong to a single thread). + * - BlockExchange supports the following types of data exchanges: + * - Transposing between [blocked](index.html#sec5sec3) and [striped](index.html#sec5sec3) arrangements + * - Transposing between [blocked](index.html#sec5sec3) and [warp-striped](index.html#sec5sec3) arrangements + * - Scattering ranked items to a [blocked arrangement](index.html#sec5sec3) + * - Scattering ranked items to a [striped arrangement](index.html#sec5sec3) + * - \rowmajor + * + * \par A Simple Example + * \blockcollective{BlockExchange} + * \par + * The code snippet below illustrates the conversion from a "blocked" to a "striped" arrangement + * of 512 integer items partitioned across 128 threads where each thread owns 4 items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int *d_data, ...) + * { + * // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockExchange BlockExchange; + * + * // Allocate shared memory for BlockExchange + * __shared__ typename BlockExchange::TempStorage temp_storage; + * + * // Load a tile of data striped across threads + * int thread_data[4]; + * cub::LoadDirectStriped<128>(threadIdx.x, d_data, thread_data); + * + * // Collectively exchange data into a blocked arrangement across threads + * BlockExchange(temp_storage).StripedToBlocked(thread_data); + * + * \endcode + * \par + * Suppose the set of striped input \p thread_data across the block of threads is + * { [0,128,256,384], [1,129,257,385], ..., [127,255,383,511] }. + * The corresponding output \p thread_data in those threads will be + * { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. + * + * \par Performance Considerations + * - Proper device-specific padding ensures zero bank conflicts for most types. 
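The striped-to-blocked conversion described above, written out as a minimal sketch with the template parameters spelled out (assuming the cub::BlockExchange<int, 128, 4> specialization from the example and a hypothetical d_data pointer):

    #include <cub/cub.cuh>   // or equivalently <cub/block/block_exchange.cuh>

    __global__ void ExampleKernel(int *d_data)
    {
        // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each
        typedef cub::BlockExchange<int, 128, 4> BlockExchange;

        // Allocate shared memory for BlockExchange
        __shared__ typename BlockExchange::TempStorage temp_storage;

        // Load a tile of data striped across threads (coalesced global reads)
        int thread_data[4];
        cub::LoadDirectStriped<128>(threadIdx.x, d_data, thread_data);

        // Collectively exchange data into a blocked arrangement across threads
        BlockExchange(temp_storage).StripedToBlocked(thread_data);
    }

Striped loads keep global-memory accesses coalesced, and the exchange then restores the blocked arrangement that most block-wide primitives expect.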
+ * + */ +template < + typename InputT, + int BLOCK_DIM_X, + int ITEMS_PER_THREAD, + bool WARP_TIME_SLICING = false, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = CUB_PTX_ARCH> +class BlockExchange +{ +private: + + /****************************************************************************** + * Constants + ******************************************************************************/ + + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + + LOG_WARP_THREADS = CUB_LOG_WARP_THREADS(PTX_ARCH), + WARP_THREADS = 1 << LOG_WARP_THREADS, + WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS, + + LOG_SMEM_BANKS = CUB_LOG_SMEM_BANKS(PTX_ARCH), + SMEM_BANKS = 1 << LOG_SMEM_BANKS, + + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + + TIME_SLICES = (WARP_TIME_SLICING) ? WARPS : 1, + + TIME_SLICED_THREADS = (WARP_TIME_SLICING) ? CUB_MIN(BLOCK_THREADS, WARP_THREADS) : BLOCK_THREADS, + TIME_SLICED_ITEMS = TIME_SLICED_THREADS * ITEMS_PER_THREAD, + + WARP_TIME_SLICED_THREADS = CUB_MIN(BLOCK_THREADS, WARP_THREADS), + WARP_TIME_SLICED_ITEMS = WARP_TIME_SLICED_THREADS * ITEMS_PER_THREAD, + + // Insert padding to avoid bank conflicts during raking when items per thread is a power of two and > 4 (otherwise we can typically use 128b loads) + INSERT_PADDING = (ITEMS_PER_THREAD > 4) && (PowerOfTwo::VALUE), + PADDING_ITEMS = (INSERT_PADDING) ? (TIME_SLICED_ITEMS >> LOG_SMEM_BANKS) : 0, + }; + + /****************************************************************************** + * Type definitions + ******************************************************************************/ + + /// Shared memory storage layout type + struct __align__(16) _TempStorage + { + InputT buff[TIME_SLICED_ITEMS + PADDING_ITEMS]; + }; + +public: + + /// \smemstorage{BlockExchange} + struct TempStorage : Uninitialized<_TempStorage> {}; + +private: + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + unsigned int lane_id; + unsigned int warp_id; + unsigned int warp_offset; + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /// Internal storage allocator + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + + /** + * Transposes data items from blocked arrangement to striped arrangement. Specialized for no timeslicing. + */ + template + __device__ __forceinline__ void BlockedToStriped( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between blocked and striped arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items to exchange, converting between blocked and striped arrangements. 
+ Int2Type /*time_slicing*/) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = (linear_tid * ITEMS_PER_THREAD) + ITEM; + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + temp_storage.buff[item_offset] = input_items[ITEM]; + } + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = int(ITEM * BLOCK_THREADS) + linear_tid; + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + output_items[ITEM] = temp_storage.buff[item_offset]; + } + } + + + /** + * Transposes data items from blocked arrangement to striped arrangement. Specialized for warp-timeslicing. + */ + template + __device__ __forceinline__ void BlockedToStriped( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between blocked and striped arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items to exchange, converting between blocked and striped arrangements. + Int2Type /*time_slicing*/) + { + InputT temp_items[ITEMS_PER_THREAD]; + + #pragma unroll + for (int SLICE = 0; SLICE < TIME_SLICES; SLICE++) + { + const int SLICE_OFFSET = SLICE * TIME_SLICED_ITEMS; + const int SLICE_OOB = SLICE_OFFSET + TIME_SLICED_ITEMS; + + CTA_SYNC(); + + if (warp_id == SLICE) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = (lane_id * ITEMS_PER_THREAD) + ITEM; + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + temp_storage.buff[item_offset] = input_items[ITEM]; + } + } + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + // Read a strip of items + const int STRIP_OFFSET = ITEM * BLOCK_THREADS; + const int STRIP_OOB = STRIP_OFFSET + BLOCK_THREADS; + + if ((SLICE_OFFSET < STRIP_OOB) && (SLICE_OOB > STRIP_OFFSET)) + { + int item_offset = STRIP_OFFSET + linear_tid - SLICE_OFFSET; + if ((item_offset >= 0) && (item_offset < TIME_SLICED_ITEMS)) + { + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + temp_items[ITEM] = temp_storage.buff[item_offset]; + } + } + } + } + + // Copy + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + output_items[ITEM] = temp_items[ITEM]; + } + } + + + /** + * Transposes data items from blocked arrangement to warp-striped arrangement. Specialized for no timeslicing + */ + template + __device__ __forceinline__ void BlockedToWarpStriped( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between blocked and striped arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items to exchange, converting between blocked and striped arrangements. + Int2Type /*time_slicing*/) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = warp_offset + ITEM + (lane_id * ITEMS_PER_THREAD); + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + temp_storage.buff[item_offset] = input_items[ITEM]; + } + + WARP_SYNC(0xffffffff); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = warp_offset + (ITEM * WARP_TIME_SLICED_THREADS) + lane_id; + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + output_items[ITEM] = temp_storage.buff[item_offset]; + } + } + + /** + * Transposes data items from blocked arrangement to warp-striped arrangement. 
Specialized for warp-timeslicing + */ + template + __device__ __forceinline__ void BlockedToWarpStriped( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between blocked and striped arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items to exchange, converting between blocked and striped arrangements. + Int2Type /*time_slicing*/) + { + if (warp_id == 0) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = ITEM + (lane_id * ITEMS_PER_THREAD); + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + temp_storage.buff[item_offset] = input_items[ITEM]; + } + + WARP_SYNC(0xffffffff); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = (ITEM * WARP_TIME_SLICED_THREADS) + lane_id; + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + output_items[ITEM] = temp_storage.buff[item_offset]; + } + } + + #pragma unroll + for (unsigned int SLICE = 1; SLICE < TIME_SLICES; ++SLICE) + { + CTA_SYNC(); + + if (warp_id == SLICE) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = ITEM + (lane_id * ITEMS_PER_THREAD); + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + temp_storage.buff[item_offset] = input_items[ITEM]; + } + + WARP_SYNC(0xffffffff); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = (ITEM * WARP_TIME_SLICED_THREADS) + lane_id; + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + output_items[ITEM] = temp_storage.buff[item_offset]; + } + } + } + } + + + /** + * Transposes data items from striped arrangement to blocked arrangement. Specialized for no timeslicing. + */ + template + __device__ __forceinline__ void StripedToBlocked( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between blocked and striped arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items to exchange, converting between blocked and striped arrangements. + Int2Type /*time_slicing*/) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = int(ITEM * BLOCK_THREADS) + linear_tid; + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + temp_storage.buff[item_offset] = input_items[ITEM]; + } + + CTA_SYNC(); + + // No timeslicing + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = (linear_tid * ITEMS_PER_THREAD) + ITEM; + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + output_items[ITEM] = temp_storage.buff[item_offset]; + } + } + + + /** + * Transposes data items from striped arrangement to blocked arrangement. Specialized for warp-timeslicing. + */ + template + __device__ __forceinline__ void StripedToBlocked( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between blocked and striped arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items to exchange, converting between blocked and striped arrangements. 
+ Int2Type /*time_slicing*/) + { + // Warp time-slicing + InputT temp_items[ITEMS_PER_THREAD]; + + #pragma unroll + for (int SLICE = 0; SLICE < TIME_SLICES; SLICE++) + { + const int SLICE_OFFSET = SLICE * TIME_SLICED_ITEMS; + const int SLICE_OOB = SLICE_OFFSET + TIME_SLICED_ITEMS; + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + // Write a strip of items + const int STRIP_OFFSET = ITEM * BLOCK_THREADS; + const int STRIP_OOB = STRIP_OFFSET + BLOCK_THREADS; + + if ((SLICE_OFFSET < STRIP_OOB) && (SLICE_OOB > STRIP_OFFSET)) + { + int item_offset = STRIP_OFFSET + linear_tid - SLICE_OFFSET; + if ((item_offset >= 0) && (item_offset < TIME_SLICED_ITEMS)) + { + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + temp_storage.buff[item_offset] = input_items[ITEM]; + } + } + } + + CTA_SYNC(); + + if (warp_id == SLICE) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = (lane_id * ITEMS_PER_THREAD) + ITEM; + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + temp_items[ITEM] = temp_storage.buff[item_offset]; + } + } + } + + // Copy + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + output_items[ITEM] = temp_items[ITEM]; + } + } + + + /** + * Transposes data items from warp-striped arrangement to blocked arrangement. Specialized for no timeslicing + */ + template + __device__ __forceinline__ void WarpStripedToBlocked( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between blocked and striped arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items to exchange, converting between blocked and striped arrangements. + Int2Type /*time_slicing*/) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = warp_offset + (ITEM * WARP_TIME_SLICED_THREADS) + lane_id; + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + temp_storage.buff[item_offset] = input_items[ITEM]; + } + + WARP_SYNC(0xffffffff); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = warp_offset + ITEM + (lane_id * ITEMS_PER_THREAD); + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + output_items[ITEM] = temp_storage.buff[item_offset]; + } + } + + + /** + * Transposes data items from warp-striped arrangement to blocked arrangement. Specialized for warp-timeslicing + */ + template + __device__ __forceinline__ void WarpStripedToBlocked( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between blocked and striped arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items to exchange, converting between blocked and striped arrangements. 
+ Int2Type /*time_slicing*/) + { + #pragma unroll + for (unsigned int SLICE = 0; SLICE < TIME_SLICES; ++SLICE) + { + CTA_SYNC(); + + if (warp_id == SLICE) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = (ITEM * WARP_TIME_SLICED_THREADS) + lane_id; + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + temp_storage.buff[item_offset] = input_items[ITEM]; + } + + WARP_SYNC(0xffffffff); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = ITEM + (lane_id * ITEMS_PER_THREAD); + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + output_items[ITEM] = temp_storage.buff[item_offset]; + } + } + } + } + + + /** + * Exchanges data items annotated by rank into blocked arrangement. Specialized for no timeslicing. + */ + template + __device__ __forceinline__ void ScatterToBlocked( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between blocked and striped arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items to exchange, converting between blocked and striped arrangements. + OffsetT ranks[ITEMS_PER_THREAD], ///< [in] Corresponding scatter ranks + Int2Type /*time_slicing*/) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = ranks[ITEM]; + if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); + temp_storage.buff[item_offset] = input_items[ITEM]; + } + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = (linear_tid * ITEMS_PER_THREAD) + ITEM; + if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); + output_items[ITEM] = temp_storage.buff[item_offset]; + } + } + + /** + * Exchanges data items annotated by rank into blocked arrangement. Specialized for warp-timeslicing. + */ + template + __device__ __forceinline__ void ScatterToBlocked( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between blocked and striped arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items to exchange, converting between blocked and striped arrangements. + OffsetT ranks[ITEMS_PER_THREAD], ///< [in] Corresponding scatter ranks + Int2Type /*time_slicing*/) + { + InputT temp_items[ITEMS_PER_THREAD]; + + #pragma unroll + for (int SLICE = 0; SLICE < TIME_SLICES; SLICE++) + { + CTA_SYNC(); + + const int SLICE_OFFSET = TIME_SLICED_ITEMS * SLICE; + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = ranks[ITEM] - SLICE_OFFSET; + if ((item_offset >= 0) && (item_offset < WARP_TIME_SLICED_ITEMS)) + { + if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); + temp_storage.buff[item_offset] = input_items[ITEM]; + } + } + + CTA_SYNC(); + + if (warp_id == SLICE) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = (lane_id * ITEMS_PER_THREAD) + ITEM; + if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); + temp_items[ITEM] = temp_storage.buff[item_offset]; + } + } + } + + // Copy + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + output_items[ITEM] = temp_items[ITEM]; + } + } + + + /** + * Exchanges data items annotated by rank into striped arrangement. Specialized for no timeslicing. 
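+     *
+     * As a small worked illustration (hypothetical sizes chosen only for this note, and
+     * ignoring any bank-padding adjustment): with BLOCK_THREADS = 4 and ITEMS_PER_THREAD = 2,
+     * a thread holding items {a, b} with ranks {5, 0} writes a to buff[5] and b to buff[0];
+     * after the block-wide barrier, thread t reads back buff[t] and buff[4 + t]
+     * (offsets ITEM * BLOCK_THREADS + linear_tid), so each thread ends up with a striped
+     * slice of the rank-ordered data.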
+ */ + template + __device__ __forceinline__ void ScatterToStriped( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between blocked and striped arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items to exchange, converting between blocked and striped arrangements. + OffsetT ranks[ITEMS_PER_THREAD], ///< [in] Corresponding scatter ranks + Int2Type /*time_slicing*/) + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = ranks[ITEM]; + if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); + temp_storage.buff[item_offset] = input_items[ITEM]; + } + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = int(ITEM * BLOCK_THREADS) + linear_tid; + if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); + output_items[ITEM] = temp_storage.buff[item_offset]; + } + } + + + /** + * Exchanges data items annotated by rank into striped arrangement. Specialized for warp-timeslicing. + */ + template + __device__ __forceinline__ void ScatterToStriped( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between blocked and striped arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items to exchange, converting between blocked and striped arrangements. + OffsetT ranks[ITEMS_PER_THREAD], ///< [in] Corresponding scatter ranks + Int2Type /*time_slicing*/) + { + InputT temp_items[ITEMS_PER_THREAD]; + + #pragma unroll + for (int SLICE = 0; SLICE < TIME_SLICES; SLICE++) + { + const int SLICE_OFFSET = SLICE * TIME_SLICED_ITEMS; + const int SLICE_OOB = SLICE_OFFSET + TIME_SLICED_ITEMS; + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = ranks[ITEM] - SLICE_OFFSET; + if ((item_offset >= 0) && (item_offset < WARP_TIME_SLICED_ITEMS)) + { + if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); + temp_storage.buff[item_offset] = input_items[ITEM]; + } + } + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + // Read a strip of items + const int STRIP_OFFSET = ITEM * BLOCK_THREADS; + const int STRIP_OOB = STRIP_OFFSET + BLOCK_THREADS; + + if ((SLICE_OFFSET < STRIP_OOB) && (SLICE_OOB > STRIP_OFFSET)) + { + int item_offset = STRIP_OFFSET + linear_tid - SLICE_OFFSET; + if ((item_offset >= 0) && (item_offset < TIME_SLICED_ITEMS)) + { + if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; + temp_items[ITEM] = temp_storage.buff[item_offset]; + } + } + } + } + + // Copy + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + output_items[ITEM] = temp_items[ITEM]; + } + } + + +public: + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using a private static allocation of shared memory as temporary storage. + */ + __device__ __forceinline__ BlockExchange() + : + temp_storage(PrivateStorage()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)), + warp_id((WARPS == 1) ? 0 : linear_tid / WARP_THREADS), + lane_id(LaneId()), + warp_offset(warp_id * WARP_TIME_SLICED_ITEMS) + {} + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. 
+ */ + __device__ __forceinline__ BlockExchange( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)), + lane_id(LaneId()), + warp_id((WARPS == 1) ? 0 : linear_tid / WARP_THREADS), + warp_offset(warp_id * WARP_TIME_SLICED_ITEMS) + {} + + + //@} end member group + /******************************************************************//** + * \name Structured exchanges + *********************************************************************/ + //@{ + + /** + * \brief Transposes data items from striped arrangement to blocked arrangement. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the conversion from a "striped" to a "blocked" arrangement + * of 512 integer items partitioned across 128 threads where each thread owns 4 items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int *d_data, ...) + * { + * // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockExchange BlockExchange; + * + * // Allocate shared memory for BlockExchange + * __shared__ typename BlockExchange::TempStorage temp_storage; + * + * // Load a tile of ordered data into a striped arrangement across block threads + * int thread_data[4]; + * cub::LoadDirectStriped<128>(threadIdx.x, d_data, thread_data); + * + * // Collectively exchange data into a blocked arrangement across threads + * BlockExchange(temp_storage).StripedToBlocked(thread_data, thread_data); + * + * \endcode + * \par + * Suppose the set of striped input \p thread_data across the block of threads is + * { [0,128,256,384], [1,129,257,385], ..., [127,255,383,511] } after loading from device-accessible memory. + * The corresponding output \p thread_data in those threads will be + * { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. + * + */ + template + __device__ __forceinline__ void StripedToBlocked( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between striped and blocked arrangements. + OutputT output_items[ITEMS_PER_THREAD]) ///< [out] Items from exchange, converting between striped and blocked arrangements. + { + StripedToBlocked(input_items, output_items, Int2Type()); + } + + + /** + * \brief Transposes data items from blocked arrangement to striped arrangement. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the conversion from a "blocked" to a "striped" arrangement + * of 512 integer items partitioned across 128 threads where each thread owns 4 items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int *d_data, ...) + * { + * // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockExchange BlockExchange; + * + * // Allocate shared memory for BlockExchange + * __shared__ typename BlockExchange::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... 
+ * + * // Collectively exchange data into a striped arrangement across threads + * BlockExchange(temp_storage).BlockedToStriped(thread_data, thread_data); + * + * // Store data striped across block threads into an ordered tile + * cub::StoreDirectStriped(threadIdx.x, d_data, thread_data); + * + * \endcode + * \par + * Suppose the set of blocked input \p thread_data across the block of threads is + * { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. + * The corresponding output \p thread_data in those threads will be + * { [0,128,256,384], [1,129,257,385], ..., [127,255,383,511] } in + * preparation for storing to device-accessible memory. + * + */ + template + __device__ __forceinline__ void BlockedToStriped( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between striped and blocked arrangements. + OutputT output_items[ITEMS_PER_THREAD]) ///< [out] Items from exchange, converting between striped and blocked arrangements. + { + BlockedToStriped(input_items, output_items, Int2Type()); + } + + + + /** + * \brief Transposes data items from warp-striped arrangement to blocked arrangement. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the conversion from a "warp-striped" to a "blocked" arrangement + * of 512 integer items partitioned across 128 threads where each thread owns 4 items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int *d_data, ...) + * { + * // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockExchange BlockExchange; + * + * // Allocate shared memory for BlockExchange + * __shared__ typename BlockExchange::TempStorage temp_storage; + * + * // Load a tile of ordered data into a warp-striped arrangement across warp threads + * int thread_data[4]; + * cub::LoadSWarptriped(threadIdx.x, d_data, thread_data); + * + * // Collectively exchange data into a blocked arrangement across threads + * BlockExchange(temp_storage).WarpStripedToBlocked(thread_data); + * + * \endcode + * \par + * Suppose the set of warp-striped input \p thread_data across the block of threads is + * { [0,32,64,96], [1,33,65,97], [2,34,66,98], ..., [415,447,479,511] } + * after loading from device-accessible memory. (The first 128 items are striped across + * the first warp of 32 threads, the second 128 items are striped across the second warp, etc.) + * The corresponding output \p thread_data in those threads will be + * { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. + * + */ + template + __device__ __forceinline__ void WarpStripedToBlocked( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between striped and blocked arrangements. + OutputT output_items[ITEMS_PER_THREAD]) ///< [out] Items from exchange, converting between striped and blocked arrangements. + { + WarpStripedToBlocked(input_items, output_items, Int2Type()); + } + + + + /** + * \brief Transposes data items from blocked arrangement to warp-striped arrangement. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the conversion from a "blocked" to a "warp-striped" arrangement + * of 512 integer items partitioned across 128 threads where each thread owns 4 items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int *d_data, ...) 
+ * { + * // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockExchange BlockExchange; + * + * // Allocate shared memory for BlockExchange + * __shared__ typename BlockExchange::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively exchange data into a warp-striped arrangement across threads + * BlockExchange(temp_storage).BlockedToWarpStriped(thread_data, thread_data); + * + * // Store data striped across warp threads into an ordered tile + * cub::StoreDirectStriped(threadIdx.x, d_data, thread_data); + * + * \endcode + * \par + * Suppose the set of blocked input \p thread_data across the block of threads is + * { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. + * The corresponding output \p thread_data in those threads will be + * { [0,32,64,96], [1,33,65,97], [2,34,66,98], ..., [415,447,479,511] } + * in preparation for storing to device-accessible memory. (The first 128 items are striped across + * the first warp of 32 threads, the second 128 items are striped across the second warp, etc.) + * + */ + template + __device__ __forceinline__ void BlockedToWarpStriped( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between striped and blocked arrangements. + OutputT output_items[ITEMS_PER_THREAD]) ///< [out] Items from exchange, converting between striped and blocked arrangements. + { + BlockedToWarpStriped(input_items, output_items, Int2Type()); + } + + + + //@} end member group + /******************************************************************//** + * \name Scatter exchanges + *********************************************************************/ + //@{ + + + /** + * \brief Exchanges data items annotated by rank into blocked arrangement. + * + * \par + * - \smemreuse + * + * \tparam OffsetT [inferred] Signed integer type for local offsets + */ + template + __device__ __forceinline__ void ScatterToBlocked( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between striped and blocked arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items from exchange, converting between striped and blocked arrangements. + OffsetT ranks[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks + { + ScatterToBlocked(input_items, output_items, ranks, Int2Type()); + } + + + + /** + * \brief Exchanges data items annotated by rank into striped arrangement. + * + * \par + * - \smemreuse + * + * \tparam OffsetT [inferred] Signed integer type for local offsets + */ + template + __device__ __forceinline__ void ScatterToStriped( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between striped and blocked arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items from exchange, converting between striped and blocked arrangements. + OffsetT ranks[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks + { + ScatterToStriped(input_items, output_items, ranks, Int2Type()); + } + + + + /** + * \brief Exchanges data items annotated by rank into striped arrangement. Items with rank -1 are not exchanged. + * + * \par + * - \smemreuse + * + * \tparam OffsetT [inferred] Signed integer type for local offsets + */ + template + __device__ __forceinline__ void ScatterToStripedGuarded( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between striped and blocked arrangements. 
+ OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items from exchange, converting between striped and blocked arrangements. + OffsetT ranks[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = ranks[ITEM]; + if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); + if (ranks[ITEM] >= 0) + temp_storage.buff[item_offset] = input_items[ITEM]; + } + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = int(ITEM * BLOCK_THREADS) + linear_tid; + if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); + output_items[ITEM] = temp_storage.buff[item_offset]; + } + } + + + + + /** + * \brief Exchanges valid data items annotated by rank into striped arrangement. + * + * \par + * - \smemreuse + * + * \tparam OffsetT [inferred] Signed integer type for local offsets + * \tparam ValidFlag [inferred] FlagT type denoting which items are valid + */ + template + __device__ __forceinline__ void ScatterToStripedFlagged( + InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between striped and blocked arrangements. + OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items from exchange, converting between striped and blocked arrangements. + OffsetT ranks[ITEMS_PER_THREAD], ///< [in] Corresponding scatter ranks + ValidFlag is_valid[ITEMS_PER_THREAD]) ///< [in] Corresponding flag denoting item validity + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = ranks[ITEM]; + if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); + if (is_valid[ITEM]) + temp_storage.buff[item_offset] = input_items[ITEM]; + } + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = int(ITEM * BLOCK_THREADS) + linear_tid; + if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); + output_items[ITEM] = temp_storage.buff[item_offset]; + } + } + + + //@} end member group + + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + + __device__ __forceinline__ void StripedToBlocked( + InputT items[ITEMS_PER_THREAD]) ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + { + StripedToBlocked(items, items); + } + + __device__ __forceinline__ void BlockedToStriped( + InputT items[ITEMS_PER_THREAD]) ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + { + BlockedToStriped(items, items); + } + + __device__ __forceinline__ void WarpStripedToBlocked( + InputT items[ITEMS_PER_THREAD]) ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + { + WarpStripedToBlocked(items, items); + } + + __device__ __forceinline__ void BlockedToWarpStriped( + InputT items[ITEMS_PER_THREAD]) ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + { + BlockedToWarpStriped(items, items); + } + + template + __device__ __forceinline__ void ScatterToBlocked( + InputT items[ITEMS_PER_THREAD], ///< [in-out] Items to exchange, converting between striped and blocked arrangements. 
+ OffsetT ranks[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks + { + ScatterToBlocked(items, items, ranks); + } + + template + __device__ __forceinline__ void ScatterToStriped( + InputT items[ITEMS_PER_THREAD], ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + OffsetT ranks[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks + { + ScatterToStriped(items, items, ranks); + } + + template + __device__ __forceinline__ void ScatterToStripedGuarded( + InputT items[ITEMS_PER_THREAD], ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + OffsetT ranks[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks + { + ScatterToStripedGuarded(items, items, ranks); + } + + template + __device__ __forceinline__ void ScatterToStripedFlagged( + InputT items[ITEMS_PER_THREAD], ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + OffsetT ranks[ITEMS_PER_THREAD], ///< [in] Corresponding scatter ranks + ValidFlag is_valid[ITEMS_PER_THREAD]) ///< [in] Corresponding flag denoting item validity + { + ScatterToStriped(items, items, ranks, is_valid); + } + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + +}; + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + +template < + typename T, + int ITEMS_PER_THREAD, + int LOGICAL_WARP_THREADS = CUB_PTX_WARP_THREADS, + int PTX_ARCH = CUB_PTX_ARCH> +class WarpExchange +{ +private: + + /****************************************************************************** + * Constants + ******************************************************************************/ + + /// Constants + enum + { + // Whether the logical warp size and the PTX warp size coincide + IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)), + + WARP_ITEMS = (ITEMS_PER_THREAD * LOGICAL_WARP_THREADS) + 1, + + LOG_SMEM_BANKS = CUB_LOG_SMEM_BANKS(PTX_ARCH), + SMEM_BANKS = 1 << LOG_SMEM_BANKS, + + // Insert padding if the number of items per thread is a power of two and > 4 (otherwise we can typically use 128b loads) + INSERT_PADDING = (ITEMS_PER_THREAD > 4) && (PowerOfTwo::VALUE), + PADDING_ITEMS = (INSERT_PADDING) ? (WARP_ITEMS >> LOG_SMEM_BANKS) : 0, + }; + + /****************************************************************************** + * Type definitions + ******************************************************************************/ + + /// Shared memory storage layout type + struct _TempStorage + { + T buff[WARP_ITEMS + PADDING_ITEMS]; + }; + +public: + + /// \smemstorage{WarpExchange} + struct TempStorage : Uninitialized<_TempStorage> {}; + +private: + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + _TempStorage &temp_storage; + int lane_id; + +public: + + /****************************************************************************** + * Construction + ******************************************************************************/ + + /// Constructor + __device__ __forceinline__ WarpExchange( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + lane_id(IS_ARCH_WARP ? + LaneId() : + LaneId() % LOGICAL_WARP_THREADS) + {} + + + /****************************************************************************** + * Interface + ******************************************************************************/ + + /** + * \brief Exchanges valid data items annotated by rank into striped arrangement. 
+ * + * \par + * - \smemreuse + * + * \tparam OffsetT [inferred] Signed integer type for local offsets + */ + template + __device__ __forceinline__ void ScatterToStriped( + T items[ITEMS_PER_THREAD], ///< [in-out] Items to exchange + OffsetT ranks[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks + { + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + if (INSERT_PADDING) ranks[ITEM] = SHR_ADD(ranks[ITEM], LOG_SMEM_BANKS, ranks[ITEM]); + temp_storage.buff[ranks[ITEM]] = items[ITEM]; + } + + WARP_SYNC(0xffffffff); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + int item_offset = (ITEM * LOGICAL_WARP_THREADS) + lane_id; + if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); + items[ITEM] = temp_storage.buff[item_offset]; + } + } + +}; + + + + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + + + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/block_histogram.cuh b/fastertransformer/cuda/cub/block/block_histogram.cuh new file mode 100644 index 000000000..b7cb9700e --- /dev/null +++ b/fastertransformer/cuda/cub/block/block_histogram.cuh @@ -0,0 +1,415 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * The cub::BlockHistogram class provides [collective](index.html#sec0) methods for constructing block-wide histograms from data samples partitioned across a CUDA thread block. 
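+ *
+ * A minimal multi-tile sketch (the kernel name, tile count, and sizes below are hypothetical
+ * and shown only to illustrate how InitHistogram() and Composite() can be combined; the class
+ * documentation below gives the canonical single-tile example):
+ *
+ * \code
+ * #include <cub/block/block_histogram.cuh>
+ *
+ * __global__ void MultiTileHistogramKernel(unsigned char *d_samples, int num_tiles)
+ * {
+ *     // 256-bin histogram for a 1D block of 128 threads, 4 samples per thread per tile
+ *     typedef cub::BlockHistogram<unsigned char, 128, 4, 256> BlockHistogram;
+ *     __shared__ typename BlockHistogram::TempStorage temp_storage;
+ *     __shared__ unsigned int smem_histogram[256];
+ *
+ *     // Zero the bin counts once, then composite one tile at a time
+ *     BlockHistogram(temp_storage).InitHistogram(smem_histogram);
+ *     __syncthreads();
+ *
+ *     for (int tile = 0; tile < num_tiles; ++tile)
+ *     {
+ *         unsigned char thread_samples[4];
+ *         ...                              // load this tile's samples
+ *         BlockHistogram(temp_storage).Composite(thread_samples, smem_histogram);
+ *         __syncthreads();                 // make temp_storage safe to reuse for the next tile
+ *     }
+ * }
+ * \endcode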
+ */ + +#pragma once + +#include "specializations/block_histogram_sort.cuh" +#include "specializations/block_histogram_atomic.cuh" +#include "../util_ptx.cuh" +#include "../util_arch.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Algorithmic variants + ******************************************************************************/ + +/** + * \brief BlockHistogramAlgorithm enumerates alternative algorithms for the parallel construction of block-wide histograms. + */ +enum BlockHistogramAlgorithm +{ + + /** + * \par Overview + * Sorting followed by differentiation. Execution is comprised of two phases: + * -# Sort the data using efficient radix sort + * -# Look for "runs" of same-valued keys by detecting discontinuities; the run-lengths are histogram bin counts. + * + * \par Performance Considerations + * Delivers consistent throughput regardless of sample bin distribution. + */ + BLOCK_HISTO_SORT, + + + /** + * \par Overview + * Use atomic addition to update byte counts directly + * + * \par Performance Considerations + * Performance is strongly tied to the hardware implementation of atomic + * addition, and may be significantly degraded for non uniformly-random + * input distributions where many concurrent updates are likely to be + * made to the same bin counter. + */ + BLOCK_HISTO_ATOMIC, +}; + + + +/****************************************************************************** + * Block histogram + ******************************************************************************/ + + +/** + * \brief The BlockHistogram class provides [collective](index.html#sec0) methods for constructing block-wide histograms from data samples partitioned across a CUDA thread block. ![](histogram_logo.png) + * \ingroup BlockModule + * + * \tparam T The sample type being histogrammed (must be castable to an integer bin identifier) + * \tparam BLOCK_DIM_X The thread block length in threads along the X dimension + * \tparam ITEMS_PER_THREAD The number of items per thread + * \tparam BINS The number bins within the histogram + * \tparam ALGORITHM [optional] cub::BlockHistogramAlgorithm enumerator specifying the underlying algorithm to use (default: cub::BLOCK_HISTO_SORT) + * \tparam BLOCK_DIM_Y [optional] The thread block length in threads along the Y dimension (default: 1) + * \tparam BLOCK_DIM_Z [optional] The thread block length in threads along the Z dimension (default: 1) + * \tparam PTX_ARCH [optional] \ptxversion + * + * \par Overview + * - A histogram + * counts the number of observations that fall into each of the disjoint categories (known as bins). + * - BlockHistogram can be optionally specialized to use different algorithms: + * -# cub::BLOCK_HISTO_SORT. Sorting followed by differentiation. [More...](\ref cub::BlockHistogramAlgorithm) + * -# cub::BLOCK_HISTO_ATOMIC. Use atomic addition to update byte counts directly. [More...](\ref cub::BlockHistogramAlgorithm) + * + * \par Performance Considerations + * - \granularity + * + * \par A Simple Example + * \blockcollective{BlockHistogram} + * \par + * The code snippet below illustrates a 256-bin histogram of 512 integer samples that + * are partitioned across 128 threads where each thread owns 4 samples. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) 
+ * { + * // Specialize a 256-bin BlockHistogram type for a 1D block of 128 threads having 4 character samples each + * typedef cub::BlockHistogram BlockHistogram; + * + * // Allocate shared memory for BlockHistogram + * __shared__ typename BlockHistogram::TempStorage temp_storage; + * + * // Allocate shared memory for block-wide histogram bin counts + * __shared__ unsigned int smem_histogram[256]; + * + * // Obtain input samples per thread + * unsigned char data[4]; + * ... + * + * // Compute the block-wide histogram + * BlockHistogram(temp_storage).Histogram(data, smem_histogram); + * + * \endcode + * + * \par Performance and Usage Considerations + * - The histogram output can be constructed in shared or device-accessible memory + * - See cub::BlockHistogramAlgorithm for performance details regarding algorithmic alternatives + * + */ +template < + typename T, + int BLOCK_DIM_X, + int ITEMS_PER_THREAD, + int BINS, + BlockHistogramAlgorithm ALGORITHM = BLOCK_HISTO_SORT, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = CUB_PTX_ARCH> +class BlockHistogram +{ +private: + + /****************************************************************************** + * Constants and type definitions + ******************************************************************************/ + + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + }; + + /** + * Ensure the template parameterization meets the requirements of the + * targeted device architecture. BLOCK_HISTO_ATOMIC can only be used + * on version SM120 or later. Otherwise BLOCK_HISTO_SORT is used + * regardless. + */ + static const BlockHistogramAlgorithm SAFE_ALGORITHM = + ((ALGORITHM == BLOCK_HISTO_ATOMIC) && (PTX_ARCH < 120)) ? + BLOCK_HISTO_SORT : + ALGORITHM; + + /// Internal specialization. + typedef typename If<(SAFE_ALGORITHM == BLOCK_HISTO_SORT), + BlockHistogramSort, + BlockHistogramAtomic >::Type InternalBlockHistogram; + + /// Shared memory storage layout type for BlockHistogram + typedef typename InternalBlockHistogram::TempStorage _TempStorage; + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /// Internal storage allocator + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + +public: + + /// \smemstorage{BlockHistogram} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using a private static allocation of shared memory as temporary storage. + */ + __device__ __forceinline__ BlockHistogram() + : + temp_storage(PrivateStorage()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. 
+ */ + __device__ __forceinline__ BlockHistogram( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + //@} end member group + /******************************************************************//** + * \name Histogram operations + *********************************************************************/ + //@{ + + + /** + * \brief Initialize the shared histogram counters to zero. + * + * \par Snippet + * The code snippet below illustrates a the initialization and update of a + * histogram of 512 integer samples that are partitioned across 128 threads + * where each thread owns 4 samples. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize a 256-bin BlockHistogram type for a 1D block of 128 threads having 4 character samples each + * typedef cub::BlockHistogram BlockHistogram; + * + * // Allocate shared memory for BlockHistogram + * __shared__ typename BlockHistogram::TempStorage temp_storage; + * + * // Allocate shared memory for block-wide histogram bin counts + * __shared__ unsigned int smem_histogram[256]; + * + * // Obtain input samples per thread + * unsigned char thread_samples[4]; + * ... + * + * // Initialize the block-wide histogram + * BlockHistogram(temp_storage).InitHistogram(smem_histogram); + * + * // Update the block-wide histogram + * BlockHistogram(temp_storage).Composite(thread_samples, smem_histogram); + * + * \endcode + * + * \tparam CounterT [inferred] Histogram counter type + */ + template + __device__ __forceinline__ void InitHistogram(CounterT histogram[BINS]) + { + // Initialize histogram bin counts to zeros + int histo_offset = 0; + + #pragma unroll + for(; histo_offset + BLOCK_THREADS <= BINS; histo_offset += BLOCK_THREADS) + { + histogram[histo_offset + linear_tid] = 0; + } + // Finish up with guarded initialization if necessary + if ((BINS % BLOCK_THREADS != 0) && (histo_offset + linear_tid < BINS)) + { + histogram[histo_offset + linear_tid] = 0; + } + } + + + /** + * \brief Constructs a block-wide histogram in shared/device-accessible memory. Each thread contributes an array of input elements. + * + * \par + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a 256-bin histogram of 512 integer samples that + * are partitioned across 128 threads where each thread owns 4 samples. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize a 256-bin BlockHistogram type for a 1D block of 128 threads having 4 character samples each + * typedef cub::BlockHistogram BlockHistogram; + * + * // Allocate shared memory for BlockHistogram + * __shared__ typename BlockHistogram::TempStorage temp_storage; + * + * // Allocate shared memory for block-wide histogram bin counts + * __shared__ unsigned int smem_histogram[256]; + * + * // Obtain input samples per thread + * unsigned char thread_samples[4]; + * ... 
+ * + * // Compute the block-wide histogram + * BlockHistogram(temp_storage).Histogram(thread_samples, smem_histogram); + * + * \endcode + * + * \tparam CounterT [inferred] Histogram counter type + */ + template < + typename CounterT > + __device__ __forceinline__ void Histogram( + T (&items)[ITEMS_PER_THREAD], ///< [in] Calling thread's input values to histogram + CounterT histogram[BINS]) ///< [out] Reference to shared/device-accessible memory histogram + { + // Initialize histogram bin counts to zeros + InitHistogram(histogram); + + CTA_SYNC(); + + // Composite the histogram + InternalBlockHistogram(temp_storage).Composite(items, histogram); + } + + + + /** + * \brief Updates an existing block-wide histogram in shared/device-accessible memory. Each thread composites an array of input elements. + * + * \par + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a the initialization and update of a + * histogram of 512 integer samples that are partitioned across 128 threads + * where each thread owns 4 samples. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize a 256-bin BlockHistogram type for a 1D block of 128 threads having 4 character samples each + * typedef cub::BlockHistogram BlockHistogram; + * + * // Allocate shared memory for BlockHistogram + * __shared__ typename BlockHistogram::TempStorage temp_storage; + * + * // Allocate shared memory for block-wide histogram bin counts + * __shared__ unsigned int smem_histogram[256]; + * + * // Obtain input samples per thread + * unsigned char thread_samples[4]; + * ... + * + * // Initialize the block-wide histogram + * BlockHistogram(temp_storage).InitHistogram(smem_histogram); + * + * // Update the block-wide histogram + * BlockHistogram(temp_storage).Composite(thread_samples, smem_histogram); + * + * \endcode + * + * \tparam CounterT [inferred] Histogram counter type + */ + template < + typename CounterT > + __device__ __forceinline__ void Composite( + T (&items)[ITEMS_PER_THREAD], ///< [in] Calling thread's input values to histogram + CounterT histogram[BINS]) ///< [out] Reference to shared/device-accessible memory histogram + { + InternalBlockHistogram(temp_storage).Composite(items, histogram); + } + +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/block_load.cuh b/fastertransformer/cuda/cub/block/block_load.cuh new file mode 100644 index 000000000..217f52123 --- /dev/null +++ b/fastertransformer/cuda/cub/block/block_load.cuh @@ -0,0 +1,1241 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2016, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Operations for reading linear tiles of data into the CUDA thread block. + */ + +#pragma once + +#include + +#include "block_exchange.cuh" +#include "../iterator/cache_modified_input_iterator.cuh" +#include "../util_ptx.cuh" +#include "../util_macro.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \addtogroup UtilIo + * @{ + */ + + +/******************************************************************//** + * \name Blocked arrangement I/O (direct) + *********************************************************************/ +//@{ + + +/** + * \brief Load a linear segment of items into a blocked arrangement across the thread block. + * + * \blocked + * + * \tparam T [inferred] The data type to load. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam InputIteratorT [inferred] The random-access iterator type for input \iterator. + */ +template < + typename InputT, + int ITEMS_PER_THREAD, + typename InputIteratorT> +__device__ __forceinline__ void LoadDirectBlocked( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load +{ + InputIteratorT thread_itr = block_itr + (linear_tid * ITEMS_PER_THREAD); + + // Load directly in thread-blocked order + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + items[ITEM] = thread_itr[ITEM]; + } +} + + +/** + * \brief Load a linear segment of items into a blocked arrangement across the thread block, guarded by range. + * + * \blocked + * + * \tparam T [inferred] The data type to load. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam InputIteratorT [inferred] The random-access iterator type for input \iterator. 
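+ *
+ * \par
+ * For example (a minimal sketch; the kernel name and the 128-thread, 4-item tile sizing are
+ * hypothetical), the partial last tile of an input can be loaded safely by passing the number
+ * of remaining valid items. Items past \p valid_items are left unmodified; the overload below
+ * that also takes \p oob_default can be used to give them a fall-back value instead:
+ * \code
+ * #include <cub/block/block_load.cuh>
+ *
+ * __global__ void ExampleKernel(int *d_data, int num_items)
+ * {
+ *     int thread_data[4];
+ *     int block_offset = blockIdx.x * 128 * 4;        // this block's 512-item tile
+ *     int valid_items  = num_items - block_offset;    // may be < 512 for the last block
+ *     cub::LoadDirectBlocked(threadIdx.x, d_data + block_offset, thread_data, valid_items);
+ * }
+ * \endcode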
+ */ +template < + typename InputT, + int ITEMS_PER_THREAD, + typename InputIteratorT> +__device__ __forceinline__ void LoadDirectBlocked( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items) ///< [in] Number of valid items to load +{ + InputIteratorT thread_itr = block_itr + (linear_tid * ITEMS_PER_THREAD); + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + if ((linear_tid * ITEMS_PER_THREAD) + ITEM < valid_items) + { + items[ITEM] = thread_itr[ITEM]; + } + } +} + + +/** + * \brief Load a linear segment of items into a blocked arrangement across the thread block, guarded by range, with a fall-back assignment of out-of-bound elements.. + * + * \blocked + * + * \tparam T [inferred] The data type to load. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam InputIteratorT [inferred] The random-access iterator type for input \iterator. + */ +template < + typename InputT, + typename DefaultT, + int ITEMS_PER_THREAD, + typename InputIteratorT> +__device__ __forceinline__ void LoadDirectBlocked( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items, ///< [in] Number of valid items to load + DefaultT oob_default) ///< [in] Default value to assign out-of-bound items +{ + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + items[ITEM] = oob_default; + + LoadDirectBlocked(linear_tid, block_itr, items, valid_items); +} + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +/** + * Internal implementation for load vectorization + */ +template < + CacheLoadModifier MODIFIER, + typename T, + int ITEMS_PER_THREAD> +__device__ __forceinline__ void InternalLoadDirectBlockedVectorized( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + T *block_ptr, ///< [in] Input pointer for loading from + T (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load +{ + // Biggest memory access word that T is a whole multiple of + typedef typename UnitWord::DeviceWord DeviceWord; + + enum + { + TOTAL_WORDS = sizeof(items) / sizeof(DeviceWord), + + VECTOR_SIZE = (TOTAL_WORDS % 4 == 0) ? + 4 : + (TOTAL_WORDS % 2 == 0) ? + 2 : + 1, + + VECTORS_PER_THREAD = TOTAL_WORDS / VECTOR_SIZE, + }; + + // Vector type + typedef typename CubVector::Type Vector; + + // Vector items + Vector vec_items[VECTORS_PER_THREAD]; + + // Aliased input ptr + Vector* vec_ptr = reinterpret_cast(block_ptr) + (linear_tid * VECTORS_PER_THREAD); + + // Load directly in thread-blocked order + #pragma unroll + for (int ITEM = 0; ITEM < VECTORS_PER_THREAD; ITEM++) + { + vec_items[ITEM] = ThreadLoad(vec_ptr + ITEM); + } + + // Copy + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + items[ITEM] = *(reinterpret_cast(vec_items) + ITEM); + } +} + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + +/** + * \brief Load a linear segment of items into a blocked arrangement across the thread block. 
+ * + * \blocked + * + * The input offset (\p block_ptr + \p block_offset) must be quad-item aligned + * + * The following conditions will prevent vectorization and loading will fall back to cub::BLOCK_LOAD_DIRECT: + * - \p ITEMS_PER_THREAD is odd + * - The data type \p T is not a built-in primitive or CUDA vector type (e.g., \p short, \p int2, \p double, \p float2, etc.) + * + * \tparam T [inferred] The data type to load. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + */ +template < + typename T, + int ITEMS_PER_THREAD> +__device__ __forceinline__ void LoadDirectBlockedVectorized( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + T *block_ptr, ///< [in] Input pointer for loading from + T (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load +{ + InternalLoadDirectBlockedVectorized(linear_tid, block_ptr, items); +} + + +//@} end member group +/******************************************************************//** + * \name Striped arrangement I/O (direct) + *********************************************************************/ +//@{ + + +/** + * \brief Load a linear segment of items into a striped arrangement across the thread block. + * + * \striped + * + * \tparam BLOCK_THREADS The thread block size in threads + * \tparam T [inferred] The data type to load. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam InputIteratorT [inferred] The random-access iterator type for input \iterator. + */ +template < + int BLOCK_THREADS, + typename InputT, + int ITEMS_PER_THREAD, + typename InputIteratorT> +__device__ __forceinline__ void LoadDirectStriped( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load +{ + InputIteratorT thread_itr = block_itr + linear_tid; + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + items[ITEM] = thread_itr[ITEM * BLOCK_THREADS]; + } +} + + +/** + * \brief Load a linear segment of items into a striped arrangement across the thread block, guarded by range + * + * \striped + * + * \tparam BLOCK_THREADS The thread block size in threads + * \tparam T [inferred] The data type to load. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam InputIteratorT [inferred] The random-access iterator type for input \iterator. 
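+ *
+ * \par
+ * For example (a minimal sketch with hypothetical kernel name and sizes), thread \p i of a
+ * 128-thread block reads tile elements i, i+128, i+256, and i+384, skipping any index that
+ * falls at or beyond \p valid_items:
+ * \code
+ * #include <cub/block/block_load.cuh>
+ *
+ * __global__ void ExampleKernel(int *d_data, int num_items)
+ * {
+ *     int thread_data[4];
+ *     int block_offset = blockIdx.x * 128 * 4;
+ *     int valid_items  = num_items - block_offset;    // may be < 512 for the last block
+ *     cub::LoadDirectStriped<128>(threadIdx.x, d_data + block_offset, thread_data, valid_items);
+ * }
+ * \endcode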
+ */ +template < + int BLOCK_THREADS, + typename InputT, + int ITEMS_PER_THREAD, + typename InputIteratorT> +__device__ __forceinline__ void LoadDirectStriped( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items) ///< [in] Number of valid items to load +{ + InputIteratorT thread_itr = block_itr + linear_tid; + + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + if (linear_tid + (ITEM * BLOCK_THREADS) < valid_items) + { + items[ITEM] = thread_itr[ITEM * BLOCK_THREADS]; + } + } +} + + +/** + * \brief Load a linear segment of items into a striped arrangement across the thread block, guarded by range, with a fall-back assignment of out-of-bound elements. + * + * \striped + * + * \tparam BLOCK_THREADS The thread block size in threads + * \tparam T [inferred] The data type to load. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam InputIteratorT [inferred] The random-access iterator type for input \iterator. + */ +template < + int BLOCK_THREADS, + typename InputT, + typename DefaultT, + int ITEMS_PER_THREAD, + typename InputIteratorT> +__device__ __forceinline__ void LoadDirectStriped( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items, ///< [in] Number of valid items to load + DefaultT oob_default) ///< [in] Default value to assign out-of-bound items +{ + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + items[ITEM] = oob_default; + + LoadDirectStriped(linear_tid, block_itr, items, valid_items); +} + + + +//@} end member group +/******************************************************************//** + * \name Warp-striped arrangement I/O (direct) + *********************************************************************/ +//@{ + + +/** + * \brief Load a linear segment of items into a warp-striped arrangement across the thread block. + * + * \warpstriped + * + * \par Usage Considerations + * The number of threads in the thread block must be a multiple of the architecture's warp size. + * + * \tparam T [inferred] The data type to load. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam InputIteratorT [inferred] The random-access iterator type for input \iterator. 
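+ *
+ * \par
+ * As a concrete illustration (assuming 32-thread warps and, hypothetically, ITEMS_PER_THREAD = 4):
+ * lane \p l of warp \p w reads tile elements w*128 + l, w*128 + l + 32, w*128 + l + 64, and
+ * w*128 + l + 96, i.e. each warp owns a contiguous 128-element span of the tile that is striped
+ * across its lanes.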
+ */ +template < + typename InputT, + int ITEMS_PER_THREAD, + typename InputIteratorT> +__device__ __forceinline__ void LoadDirectWarpStriped( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load +{ + int tid = linear_tid & (CUB_PTX_WARP_THREADS - 1); + int wid = linear_tid >> CUB_PTX_LOG_WARP_THREADS; + int warp_offset = wid * CUB_PTX_WARP_THREADS * ITEMS_PER_THREAD; + + InputIteratorT thread_itr = block_itr + warp_offset + tid ; + + // Load directly in warp-striped order + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + items[ITEM] = thread_itr[(ITEM * CUB_PTX_WARP_THREADS)]; + } +} + + +/** + * \brief Load a linear segment of items into a warp-striped arrangement across the thread block, guarded by range + * + * \warpstriped + * + * \par Usage Considerations + * The number of threads in the thread block must be a multiple of the architecture's warp size. + * + * \tparam T [inferred] The data type to load. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam InputIteratorT [inferred] The random-access iterator type for input \iterator. + */ +template < + typename InputT, + int ITEMS_PER_THREAD, + typename InputIteratorT> +__device__ __forceinline__ void LoadDirectWarpStriped( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items) ///< [in] Number of valid items to load +{ + int tid = linear_tid & (CUB_PTX_WARP_THREADS - 1); + int wid = linear_tid >> CUB_PTX_LOG_WARP_THREADS; + int warp_offset = wid * CUB_PTX_WARP_THREADS * ITEMS_PER_THREAD; + + InputIteratorT thread_itr = block_itr + warp_offset + tid ; + + // Load directly in warp-striped order + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + if (warp_offset + tid + (ITEM * CUB_PTX_WARP_THREADS) < valid_items) + { + items[ITEM] = thread_itr[(ITEM * CUB_PTX_WARP_THREADS)]; + } + } +} + + +/** + * \brief Load a linear segment of items into a warp-striped arrangement across the thread block, guarded by range, with a fall-back assignment of out-of-bound elements. + * + * \warpstriped + * + * \par Usage Considerations + * The number of threads in the thread block must be a multiple of the architecture's warp size. + * + * \tparam T [inferred] The data type to load. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam InputIteratorT [inferred] The random-access iterator type for input \iterator. 
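+ *
+ * \par
+ * For example (a minimal sketch; \p d_data, \p block_offset, and \p valid_items are assumed to
+ * be set up as in the earlier guarded-load examples), out-of-range items can be filled with an
+ * identity value so that a later block-wide reduction needs no special casing:
+ * \code
+ * int thread_data[4];
+ * cub::LoadDirectWarpStriped(threadIdx.x, d_data + block_offset, thread_data, valid_items, 0);
+ * \endcode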
+ */
+template <
+    typename        InputT,
+    typename        DefaultT,
+    int             ITEMS_PER_THREAD,
+    typename        InputIteratorT>
+__device__ __forceinline__ void LoadDirectWarpStriped(
+    int             linear_tid,                 ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks)
+    InputIteratorT  block_itr,                  ///< [in] The thread block's base input iterator for loading from
+    InputT          (&items)[ITEMS_PER_THREAD], ///< [out] Data to load
+    int             valid_items,                ///< [in] Number of valid items to load
+    DefaultT        oob_default)                ///< [in] Default value to assign out-of-bound items
+{
+    // Load directly in warp-striped order
+    #pragma unroll
+    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
+        items[ITEM] = oob_default;
+
+    LoadDirectWarpStriped(linear_tid, block_itr, items, valid_items);
+}
+
+
+
+//@}  end member group
+
+/** @} */       // end group UtilIo
+
+
+
+//-----------------------------------------------------------------------------
+// Generic BlockLoad abstraction
+//-----------------------------------------------------------------------------
+
+/**
+ * \brief cub::BlockLoadAlgorithm enumerates alternative algorithms for cub::BlockLoad to read a linear segment of data from memory into a blocked arrangement across a CUDA thread block.
+ */
+enum BlockLoadAlgorithm
+{
+    /**
+     * \par Overview
+     *
+     * A [blocked arrangement](index.html#sec5sec3) of data is read
+     * directly from memory.
+     *
+     * \par Performance Considerations
+     * - The utilization of memory transactions (coalescing) decreases as the
+     *   access stride between threads increases (i.e., the number of items per thread).
+     */
+    BLOCK_LOAD_DIRECT,
+
+    /**
+     * \par Overview
+     *
+     * A [blocked arrangement](index.html#sec5sec3) of data is read
+     * from memory using CUDA's built-in vectorized loads as a coalescing optimization.
+     * For example, ld.global.v4.s32 instructions will be generated
+     * when \p T = \p int and \p ITEMS_PER_THREAD % 4 == 0.
+     *
+     * \par Performance Considerations
+     * - The utilization of memory transactions (coalescing) remains high until the
+     *   access stride between threads (i.e., the number of items per thread) exceeds the
+     *   maximum vector load width (typically 4 items or 64B, whichever is lower).
+     * - The following conditions will prevent vectorization and loading will fall back to cub::BLOCK_LOAD_DIRECT:
+     *   - \p ITEMS_PER_THREAD is odd
+     *   - The \p InputIteratorT is not a simple pointer type
+     *   - The block input offset is not quadword-aligned
+     *   - The data type \p T is not a built-in primitive or CUDA vector type (e.g., \p short, \p int2, \p double, \p float2, etc.)
+     */
+    BLOCK_LOAD_VECTORIZE,
+
+    /**
+     * \par Overview
+     *
+     * A [striped arrangement](index.html#sec5sec3) of data is read
+     * efficiently from memory and then locally transposed into a
+     * [blocked arrangement](index.html#sec5sec3).
+     *
+     * \par Performance Considerations
+     * - The utilization of memory transactions (coalescing) remains high regardless
+     *   of items loaded per thread.
+     * - The local reordering incurs slightly longer latencies and lower throughput than the
+     *   direct cub::BLOCK_LOAD_DIRECT and cub::BLOCK_LOAD_VECTORIZE alternatives.
+ */ + BLOCK_LOAD_TRANSPOSE, + + + /** + * \par Overview + * + * A [warp-striped arrangement](index.html#sec5sec3) of data is + * read efficiently from memory and then locally transposed into a + * [blocked arrangement](index.html#sec5sec3). + * + * \par Usage Considerations + * - BLOCK_THREADS must be a multiple of WARP_THREADS + * + * \par Performance Considerations + * - The utilization of memory transactions (coalescing) remains high regardless + * of items loaded per thread. + * - The local reordering incurs slightly larger latencies than the + * direct cub::BLOCK_LOAD_DIRECT and cub::BLOCK_LOAD_VECTORIZE alternatives. + * - Provisions more shared storage, but incurs smaller latencies than the + * BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED alternative. + */ + BLOCK_LOAD_WARP_TRANSPOSE, + + + /** + * \par Overview + * + * Like \p BLOCK_LOAD_WARP_TRANSPOSE, a [warp-striped arrangement](index.html#sec5sec3) + * of data is read directly from memory and then is locally transposed into a + * [blocked arrangement](index.html#sec5sec3). To reduce the shared memory + * requirement, only one warp's worth of shared memory is provisioned and is + * subsequently time-sliced among warps. + * + * \par Usage Considerations + * - BLOCK_THREADS must be a multiple of WARP_THREADS + * + * \par Performance Considerations + * - The utilization of memory transactions (coalescing) remains high regardless + * of items loaded per thread. + * - Provisions less shared memory temporary storage, but incurs larger + * latencies than the BLOCK_LOAD_WARP_TRANSPOSE alternative. + */ + BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED, +}; + + +/** + * \brief The BlockLoad class provides [collective](index.html#sec0) data movement methods for loading a linear segment of items from memory into a [blocked arrangement](index.html#sec5sec3) across a CUDA thread block. ![](block_load_logo.png) + * \ingroup BlockModule + * \ingroup UtilIo + * + * \tparam InputT The data type to read into (which must be convertible from the input iterator's value type). + * \tparam BLOCK_DIM_X The thread block length in threads along the X dimension + * \tparam ITEMS_PER_THREAD The number of consecutive items partitioned onto each thread. + * \tparam ALGORITHM [optional] cub::BlockLoadAlgorithm tuning policy. default: cub::BLOCK_LOAD_DIRECT. + * \tparam WARP_TIME_SLICING [optional] Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any load-related data transpositions (versus each warp having its own storage). (default: false) + * \tparam BLOCK_DIM_Y [optional] The thread block length in threads along the Y dimension (default: 1) + * \tparam BLOCK_DIM_Z [optional] The thread block length in threads along the Z dimension (default: 1) + * \tparam PTX_ARCH [optional] \ptxversion + * + * \par Overview + * - The BlockLoad class provides a single data movement abstraction that can be specialized + * to implement different cub::BlockLoadAlgorithm strategies. This facilitates different + * performance policies for different architectures, data types, granularity sizes, etc. + * - BlockLoad can be optionally specialized by different data movement strategies: + * -# cub::BLOCK_LOAD_DIRECT. A [blocked arrangement](index.html#sec5sec3) + * of data is read directly from memory. [More...](\ref cub::BlockLoadAlgorithm) + * -# cub::BLOCK_LOAD_VECTORIZE. A [blocked arrangement](index.html#sec5sec3) + * of data is read directly from memory using CUDA's built-in vectorized loads as a + * coalescing optimization. 
[More...](\ref cub::BlockLoadAlgorithm)
+ *   -# cub::BLOCK_LOAD_TRANSPOSE.  A [striped arrangement](index.html#sec5sec3)
+ *      of data is read directly from memory and is then locally transposed into a
+ *      [blocked arrangement](index.html#sec5sec3).  [More...](\ref cub::BlockLoadAlgorithm)
+ *   -# cub::BLOCK_LOAD_WARP_TRANSPOSE.  A [warp-striped arrangement](index.html#sec5sec3)
+ *      of data is read directly from memory and is then locally transposed into a
+ *      [blocked arrangement](index.html#sec5sec3).  [More...](\ref cub::BlockLoadAlgorithm)
+ *   -# cub::BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED.  A [warp-striped arrangement](index.html#sec5sec3)
+ *      of data is read directly from memory and is then locally transposed into a
+ *      [blocked arrangement](index.html#sec5sec3) one warp at a time.  [More...](\ref cub::BlockLoadAlgorithm)
+ * - \rowmajor
+ *
+ * \par A Simple Example
+ * \blockcollective{BlockLoad}
+ * \par
+ * The code snippet below illustrates the loading of a linear
+ * segment of 512 integers into a "blocked" arrangement across 128 threads where each
+ * thread owns 4 consecutive items.  The load is specialized for \p BLOCK_LOAD_WARP_TRANSPOSE,
+ * meaning memory references are efficiently coalesced using a warp-striped access
+ * pattern (after which items are locally reordered among threads).
+ * \par
+ * \code
+ * #include <cub/block/block_load.cuh>   // or equivalently <cub/cub.cuh>
+ *
+ * __global__ void ExampleKernel(int *d_data, ...)
+ * {
+ *     // Specialize BlockLoad for a 1D block of 128 threads owning 4 integer items each
+ *     typedef cub::BlockLoad<int, 128, 4, BLOCK_LOAD_WARP_TRANSPOSE> BlockLoad;
+ *
+ *     // Allocate shared memory for BlockLoad
+ *     __shared__ typename BlockLoad::TempStorage temp_storage;
+ *
+ *     // Load a segment of consecutive items that are blocked across threads
+ *     int thread_data[4];
+ *     BlockLoad(temp_storage).Load(d_data, thread_data);
+ *
+ * \endcode
+ * \par
+ * Suppose the input \p d_data is 0, 1, 2, 3, 4, 5, ....
+ * The set of \p thread_data across the block of threads will be
+ * { [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }.
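+ * \par
+ * (Editorial sketch, not upstream documentation; \p valid_items below is an
+ * illustrative variable.)  When the segment may be shorter than 512 items, the
+ * guarded overloads documented later in this file can be used instead, for
+ * example filling out-of-range items with a default value:
+ * \par
+ * \code
+ *     // Load only the first valid_items elements, padding the remainder with -1
+ *     BlockLoad(temp_storage).Load(d_data, thread_data, valid_items, -1);
+ * \endcode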
+ * + */ +template < + typename InputT, + int BLOCK_DIM_X, + int ITEMS_PER_THREAD, + BlockLoadAlgorithm ALGORITHM = BLOCK_LOAD_DIRECT, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = CUB_PTX_ARCH> +class BlockLoad +{ +private: + + /****************************************************************************** + * Constants and typed definitions + ******************************************************************************/ + + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + }; + + + /****************************************************************************** + * Algorithmic variants + ******************************************************************************/ + + /// Load helper + template + struct LoadInternal; + + + /** + * BLOCK_LOAD_DIRECT specialization of load helper + */ + template + struct LoadInternal + { + /// Shared memory storage layout type + typedef NullType TempStorage; + + /// Linear thread-id + int linear_tid; + + /// Constructor + __device__ __forceinline__ LoadInternal( + TempStorage &/*temp_storage*/, + int linear_tid) + : + linear_tid(linear_tid) + {} + + /// Load a linear segment of items from memory + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load + { + LoadDirectBlocked(linear_tid, block_itr, items); + } + + /// Load a linear segment of items from memory, guarded by range + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items) ///< [in] Number of valid items to load + { + LoadDirectBlocked(linear_tid, block_itr, items, valid_items); + } + + /// Load a linear segment of items from memory, guarded by range, with a fall-back assignment of out-of-bound elements + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items, ///< [in] Number of valid items to load + DefaultT oob_default) ///< [in] Default value to assign out-of-bound items + { + LoadDirectBlocked(linear_tid, block_itr, items, valid_items, oob_default); + } + + }; + + + /** + * BLOCK_LOAD_VECTORIZE specialization of load helper + */ + template + struct LoadInternal + { + /// Shared memory storage layout type + typedef NullType TempStorage; + + /// Linear thread-id + int linear_tid; + + /// Constructor + __device__ __forceinline__ LoadInternal( + TempStorage &/*temp_storage*/, + int linear_tid) + : + linear_tid(linear_tid) + {} + + /// Load a linear segment of items from memory, specialized for native pointer types (attempts vectorization) + template + __device__ __forceinline__ void Load( + InputT *block_ptr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load + { + InternalLoadDirectBlockedVectorized(linear_tid, block_ptr, items); + } + + /// Load a linear segment of items from memory, specialized for native pointer types (attempts vectorization) + template + __device__ __forceinline__ void Load( + const InputT *block_ptr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load + { + 
InternalLoadDirectBlockedVectorized(linear_tid, block_ptr, items); + } + + /// Load a linear segment of items from memory, specialized for native pointer types (attempts vectorization) + template < + CacheLoadModifier MODIFIER, + typename ValueType, + typename OffsetT> + __device__ __forceinline__ void Load( + CacheModifiedInputIterator block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load + { + InternalLoadDirectBlockedVectorized(linear_tid, block_itr.ptr, items); + } + + /// Load a linear segment of items from memory, specialized for opaque input iterators (skips vectorization) + template + __device__ __forceinline__ void Load( + _InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load + { + LoadDirectBlocked(linear_tid, block_itr, items); + } + + /// Load a linear segment of items from memory, guarded by range (skips vectorization) + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items) ///< [in] Number of valid items to load + { + LoadDirectBlocked(linear_tid, block_itr, items, valid_items); + } + + /// Load a linear segment of items from memory, guarded by range, with a fall-back assignment of out-of-bound elements (skips vectorization) + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items, ///< [in] Number of valid items to load + DefaultT oob_default) ///< [in] Default value to assign out-of-bound items + { + LoadDirectBlocked(linear_tid, block_itr, items, valid_items, oob_default); + } + + }; + + + /** + * BLOCK_LOAD_TRANSPOSE specialization of load helper + */ + template + struct LoadInternal + { + // BlockExchange utility type for keys + typedef BlockExchange BlockExchange; + + /// Shared memory storage layout type + struct _TempStorage : BlockExchange::TempStorage + {}; + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + /// Thread reference to shared storage + _TempStorage &temp_storage; + + /// Linear thread-id + int linear_tid; + + /// Constructor + __device__ __forceinline__ LoadInternal( + TempStorage &temp_storage, + int linear_tid) + : + temp_storage(temp_storage.Alias()), + linear_tid(linear_tid) + {} + + /// Load a linear segment of items from memory + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load{ + { + LoadDirectStriped(linear_tid, block_itr, items); + BlockExchange(temp_storage).StripedToBlocked(items, items); + } + + /// Load a linear segment of items from memory, guarded by range + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items) ///< [in] Number of valid items to load + { + LoadDirectStriped(linear_tid, block_itr, items, valid_items); + BlockExchange(temp_storage).StripedToBlocked(items, items); + } + + /// Load a linear segment of items from memory, guarded by range, 
with a fall-back assignment of out-of-bound elements + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items, ///< [in] Number of valid items to load + DefaultT oob_default) ///< [in] Default value to assign out-of-bound items + { + LoadDirectStriped(linear_tid, block_itr, items, valid_items, oob_default); + BlockExchange(temp_storage).StripedToBlocked(items, items); + } + + }; + + + /** + * BLOCK_LOAD_WARP_TRANSPOSE specialization of load helper + */ + template + struct LoadInternal + { + enum + { + WARP_THREADS = CUB_WARP_THREADS(PTX_ARCH) + }; + + // Assert BLOCK_THREADS must be a multiple of WARP_THREADS + CUB_STATIC_ASSERT((BLOCK_THREADS % WARP_THREADS == 0), "BLOCK_THREADS must be a multiple of WARP_THREADS"); + + // BlockExchange utility type for keys + typedef BlockExchange BlockExchange; + + /// Shared memory storage layout type + struct _TempStorage : BlockExchange::TempStorage + {}; + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + /// Thread reference to shared storage + _TempStorage &temp_storage; + + /// Linear thread-id + int linear_tid; + + /// Constructor + __device__ __forceinline__ LoadInternal( + TempStorage &temp_storage, + int linear_tid) + : + temp_storage(temp_storage.Alias()), + linear_tid(linear_tid) + {} + + /// Load a linear segment of items from memory + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load{ + { + LoadDirectWarpStriped(linear_tid, block_itr, items); + BlockExchange(temp_storage).WarpStripedToBlocked(items, items); + } + + /// Load a linear segment of items from memory, guarded by range + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items) ///< [in] Number of valid items to load + { + LoadDirectWarpStriped(linear_tid, block_itr, items, valid_items); + BlockExchange(temp_storage).WarpStripedToBlocked(items, items); + } + + + /// Load a linear segment of items from memory, guarded by range, with a fall-back assignment of out-of-bound elements + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items, ///< [in] Number of valid items to load + DefaultT oob_default) ///< [in] Default value to assign out-of-bound items + { + LoadDirectWarpStriped(linear_tid, block_itr, items, valid_items, oob_default); + BlockExchange(temp_storage).WarpStripedToBlocked(items, items); + } + }; + + + /** + * BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED specialization of load helper + */ + template + struct LoadInternal + { + enum + { + WARP_THREADS = CUB_WARP_THREADS(PTX_ARCH) + }; + + // Assert BLOCK_THREADS must be a multiple of WARP_THREADS + CUB_STATIC_ASSERT((BLOCK_THREADS % WARP_THREADS == 0), "BLOCK_THREADS must be a multiple of WARP_THREADS"); + + // BlockExchange utility type for keys + typedef BlockExchange BlockExchange; + + /// Shared memory storage layout type + struct _TempStorage : BlockExchange::TempStorage + {}; + + /// Alias wrapper allowing storage to be 
unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + /// Thread reference to shared storage + _TempStorage &temp_storage; + + /// Linear thread-id + int linear_tid; + + /// Constructor + __device__ __forceinline__ LoadInternal( + TempStorage &temp_storage, + int linear_tid) + : + temp_storage(temp_storage.Alias()), + linear_tid(linear_tid) + {} + + /// Load a linear segment of items from memory + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load{ + { + LoadDirectWarpStriped(linear_tid, block_itr, items); + BlockExchange(temp_storage).WarpStripedToBlocked(items, items); + } + + /// Load a linear segment of items from memory, guarded by range + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items) ///< [in] Number of valid items to load + { + LoadDirectWarpStriped(linear_tid, block_itr, items, valid_items); + BlockExchange(temp_storage).WarpStripedToBlocked(items, items); + } + + + /// Load a linear segment of items from memory, guarded by range, with a fall-back assignment of out-of-bound elements + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items, ///< [in] Number of valid items to load + DefaultT oob_default) ///< [in] Default value to assign out-of-bound items + { + LoadDirectWarpStriped(linear_tid, block_itr, items, valid_items, oob_default); + BlockExchange(temp_storage).WarpStripedToBlocked(items, items); + } + }; + + + /****************************************************************************** + * Type definitions + ******************************************************************************/ + + /// Internal load implementation to use + typedef LoadInternal InternalLoad; + + + /// Shared memory storage layout type + typedef typename InternalLoad::TempStorage _TempStorage; + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /// Internal storage allocator + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Thread reference to shared storage + _TempStorage &temp_storage; + + /// Linear thread-id + int linear_tid; + +public: + + /// \smemstorage{BlockLoad} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using a private static allocation of shared memory as temporary storage. + */ + __device__ __forceinline__ BlockLoad() + : + temp_storage(PrivateStorage()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. 
+ */ + __device__ __forceinline__ BlockLoad( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + + + //@} end member group + /******************************************************************//** + * \name Data movement + *********************************************************************/ + //@{ + + + /** + * \brief Load a linear segment of items from memory. + * + * \par + * - \blocked + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the loading of a linear + * segment of 512 integers into a "blocked" arrangement across 128 threads where each + * thread owns 4 consecutive items. The load is specialized for \p BLOCK_LOAD_WARP_TRANSPOSE, + * meaning memory references are efficiently coalesced using a warp-striped access + * pattern (after which items are locally reordered among threads). + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int *d_data, ...) + * { + * // Specialize BlockLoad for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockLoad BlockLoad; + * + * // Allocate shared memory for BlockLoad + * __shared__ typename BlockLoad::TempStorage temp_storage; + * + * // Load a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * BlockLoad(temp_storage).Load(d_data, thread_data); + * + * \endcode + * \par + * Suppose the input \p d_data is 0, 1, 2, 3, 4, 5, .... + * The set of \p thread_data across the block of threads in those threads will be + * { [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }. + * + */ + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load + { + InternalLoad(temp_storage, linear_tid).Load(block_itr, items); + } + + + /** + * \brief Load a linear segment of items from memory, guarded by range. + * + * \par + * - \blocked + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the guarded loading of a linear + * segment of 512 integers into a "blocked" arrangement across 128 threads where each + * thread owns 4 consecutive items. The load is specialized for \p BLOCK_LOAD_WARP_TRANSPOSE, + * meaning memory references are efficiently coalesced using a warp-striped access + * pattern (after which items are locally reordered among threads). + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int *d_data, int valid_items, ...) + * { + * // Specialize BlockLoad for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockLoad BlockLoad; + * + * // Allocate shared memory for BlockLoad + * __shared__ typename BlockLoad::TempStorage temp_storage; + * + * // Load a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * BlockLoad(temp_storage).Load(d_data, thread_data, valid_items); + * + * \endcode + * \par + * Suppose the input \p d_data is 0, 1, 2, 3, 4, 5, 6... and \p valid_items is \p 5. + * The set of \p thread_data across the block of threads in those threads will be + * { [0,1,2,3], [4,?,?,?], ..., [?,?,?,?] }, with only the first two threads + * being unmasked to load portions of valid data (and other items remaining unassigned). 
+ * + */ + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items) ///< [in] Number of valid items to load + { + InternalLoad(temp_storage, linear_tid).Load(block_itr, items, valid_items); + } + + + /** + * \brief Load a linear segment of items from memory, guarded by range, with a fall-back assignment of out-of-bound elements + * + * \par + * - \blocked + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the guarded loading of a linear + * segment of 512 integers into a "blocked" arrangement across 128 threads where each + * thread owns 4 consecutive items. The load is specialized for \p BLOCK_LOAD_WARP_TRANSPOSE, + * meaning memory references are efficiently coalesced using a warp-striped access + * pattern (after which items are locally reordered among threads). + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int *d_data, int valid_items, ...) + * { + * // Specialize BlockLoad for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockLoad BlockLoad; + * + * // Allocate shared memory for BlockLoad + * __shared__ typename BlockLoad::TempStorage temp_storage; + * + * // Load a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * BlockLoad(temp_storage).Load(d_data, thread_data, valid_items, -1); + * + * \endcode + * \par + * Suppose the input \p d_data is 0, 1, 2, 3, 4, 5, 6..., + * \p valid_items is \p 5, and the out-of-bounds default is \p -1. + * The set of \p thread_data across the block of threads in those threads will be + * { [0,1,2,3], [4,-1,-1,-1], ..., [-1,-1,-1,-1] }, with only the first two threads + * being unmasked to load portions of valid data (and other items are assigned \p -1) + * + */ + template + __device__ __forceinline__ void Load( + InputIteratorT block_itr, ///< [in] The thread block's base input iterator for loading from + InputT (&items)[ITEMS_PER_THREAD], ///< [out] Data to load + int valid_items, ///< [in] Number of valid items to load + DefaultT oob_default) ///< [in] Default value to assign out-of-bound items + { + InternalLoad(temp_storage, linear_tid).Load(block_itr, items, valid_items, oob_default); + } + + + //@} end member group + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/block_radix_rank.cuh b/fastertransformer/cuda/cub/block/block_radix_rank.cuh new file mode 100644 index 000000000..c26451c66 --- /dev/null +++ b/fastertransformer/cuda/cub/block/block_radix_rank.cuh @@ -0,0 +1,696 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block + */ + +#pragma once + +#include + +#include "../thread/thread_reduce.cuh" +#include "../thread/thread_scan.cuh" +#include "../block/block_scan.cuh" +#include "../util_ptx.cuh" +#include "../util_arch.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block. + * \ingroup BlockModule + * + * \tparam BLOCK_DIM_X The thread block length in threads along the X dimension + * \tparam RADIX_BITS The number of radix bits per digit place + * \tparam IS_DESCENDING Whether or not the sorted-order is high-to-low + * \tparam MEMOIZE_OUTER_SCAN [optional] Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure (default: true for architectures SM35 and newer, false otherwise). See BlockScanAlgorithm::BLOCK_SCAN_RAKING_MEMOIZE for more details. + * \tparam INNER_SCAN_ALGORITHM [optional] The cub::BlockScanAlgorithm algorithm to use (default: cub::BLOCK_SCAN_WARP_SCANS) + * \tparam SMEM_CONFIG [optional] Shared memory bank mode (default: \p cudaSharedMemBankSizeFourByte) + * \tparam BLOCK_DIM_Y [optional] The thread block length in threads along the Y dimension (default: 1) + * \tparam BLOCK_DIM_Z [optional] The thread block length in threads along the Z dimension (default: 1) + * \tparam PTX_ARCH [optional] \ptxversion + * + * \par Overview + * Blah... + * - Keys must be in a form suitable for radix ranking (i.e., unsigned bits). + * - \blocked + * + * \par Performance Considerations + * - \granularity + * + * \par Examples + * \par + * - Example 1: Simple radix rank of 32-bit integer keys + * \code + * #include + * + * template + * __global__ void ExampleKernel(...) + * { + * + * \endcode + */ +template < + int BLOCK_DIM_X, + int RADIX_BITS, + bool IS_DESCENDING, + bool MEMOIZE_OUTER_SCAN = (CUB_PTX_ARCH >= 350) ? 
true : false, + BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS, + cudaSharedMemConfig SMEM_CONFIG = cudaSharedMemBankSizeFourByte, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = CUB_PTX_ARCH> +class BlockRadixRank +{ +private: + + /****************************************************************************** + * Type definitions and constants + ******************************************************************************/ + + // Integer type for digit counters (to be packed into words of type PackedCounters) + typedef unsigned short DigitCounter; + + // Integer type for packing DigitCounters into columns of shared memory banks + typedef typename If<(SMEM_CONFIG == cudaSharedMemBankSizeEightByte), + unsigned long long, + unsigned int>::Type PackedCounter; + + enum + { + // The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + + RADIX_DIGITS = 1 << RADIX_BITS, + + LOG_WARP_THREADS = CUB_LOG_WARP_THREADS(PTX_ARCH), + WARP_THREADS = 1 << LOG_WARP_THREADS, + WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS, + + BYTES_PER_COUNTER = sizeof(DigitCounter), + LOG_BYTES_PER_COUNTER = Log2::VALUE, + + PACKING_RATIO = sizeof(PackedCounter) / sizeof(DigitCounter), + LOG_PACKING_RATIO = Log2::VALUE, + + LOG_COUNTER_LANES = CUB_MAX((RADIX_BITS - LOG_PACKING_RATIO), 0), // Always at least one lane + COUNTER_LANES = 1 << LOG_COUNTER_LANES, + + // The number of packed counters per thread (plus one for padding) + PADDED_COUNTER_LANES = COUNTER_LANES + 1, + RAKING_SEGMENT = PADDED_COUNTER_LANES, + }; + +public: + + enum + { + /// Number of bin-starting offsets tracked per thread + BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS), + }; + +private: + + + /// BlockScan type + typedef BlockScan< + PackedCounter, + BLOCK_DIM_X, + INNER_SCAN_ALGORITHM, + BLOCK_DIM_Y, + BLOCK_DIM_Z, + PTX_ARCH> + BlockScan; + + + /// Shared memory storage layout type for BlockRadixRank + struct __align__(16) _TempStorage + { + union Aliasable + { + DigitCounter digit_counters[PADDED_COUNTER_LANES][BLOCK_THREADS][PACKING_RATIO]; + PackedCounter raking_grid[BLOCK_THREADS][RAKING_SEGMENT]; + + } aliasable; + + // Storage for scanning local ranks + typename BlockScan::TempStorage block_scan; + }; + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + + /// Copy of raking segment, promoted to registers + PackedCounter cached_segment[RAKING_SEGMENT]; + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /** + * Internal storage allocator + */ + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + + /** + * Performs upsweep raking reduction, returning the aggregate + */ + __device__ __forceinline__ PackedCounter Upsweep() + { + PackedCounter *smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid]; + PackedCounter *raking_ptr; + + if (MEMOIZE_OUTER_SCAN) + { + // Copy data into registers + #pragma unroll + for (int i = 0; i < RAKING_SEGMENT; i++) + { + cached_segment[i] = smem_raking_ptr[i]; + } + raking_ptr = cached_segment; + } + else + { + 
raking_ptr = smem_raking_ptr; + } + + return internal::ThreadReduce(raking_ptr, Sum()); + } + + + /// Performs exclusive downsweep raking scan + __device__ __forceinline__ void ExclusiveDownsweep( + PackedCounter raking_partial) + { + PackedCounter *smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid]; + + PackedCounter *raking_ptr = (MEMOIZE_OUTER_SCAN) ? + cached_segment : + smem_raking_ptr; + + // Exclusive raking downsweep scan + internal::ThreadScanExclusive(raking_ptr, raking_ptr, Sum(), raking_partial); + + if (MEMOIZE_OUTER_SCAN) + { + // Copy data back to smem + #pragma unroll + for (int i = 0; i < RAKING_SEGMENT; i++) + { + smem_raking_ptr[i] = cached_segment[i]; + } + } + } + + + /** + * Reset shared memory digit counters + */ + __device__ __forceinline__ void ResetCounters() + { + // Reset shared memory digit counters + #pragma unroll + for (int LANE = 0; LANE < PADDED_COUNTER_LANES; LANE++) + { + *((PackedCounter*) temp_storage.aliasable.digit_counters[LANE][linear_tid]) = 0; + } + } + + + /** + * Block-scan prefix callback + */ + struct PrefixCallBack + { + __device__ __forceinline__ PackedCounter operator()(PackedCounter block_aggregate) + { + PackedCounter block_prefix = 0; + + // Propagate totals in packed fields + #pragma unroll + for (int PACKED = 1; PACKED < PACKING_RATIO; PACKED++) + { + block_prefix += block_aggregate << (sizeof(DigitCounter) * 8 * PACKED); + } + + return block_prefix; + } + }; + + + /** + * Scan shared memory digit counters. + */ + __device__ __forceinline__ void ScanCounters() + { + // Upsweep scan + PackedCounter raking_partial = Upsweep(); + + // Compute exclusive sum + PackedCounter exclusive_partial; + PrefixCallBack prefix_call_back; + BlockScan(temp_storage.block_scan).ExclusiveSum(raking_partial, exclusive_partial, prefix_call_back); + + // Downsweep scan with exclusive partial + ExclusiveDownsweep(exclusive_partial); + } + +public: + + /// \smemstorage{BlockScan} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using a private static allocation of shared memory as temporary storage. + */ + __device__ __forceinline__ BlockRadixRank() + : + temp_storage(PrivateStorage()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. + */ + __device__ __forceinline__ BlockRadixRank( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + //@} end member group + /******************************************************************//** + * \name Raking + *********************************************************************/ + //@{ + + /** + * \brief Rank keys. 
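+ *
+ * \par Snippet
+ * An illustrative sketch (editorial addition, not upstream documentation; all
+ * identifiers are hypothetical) of ranking the low 4 bits of 4 keys per thread
+ * across a 1D block of 128 threads:
+ * \par
+ * \code
+ *     typedef cub::BlockRadixRank<128, 4, false> BlockRadixRank;
+ *     __shared__ typename BlockRadixRank::TempStorage temp_storage;
+ *
+ *     unsigned int thread_keys[4];
+ *     int          thread_ranks[4];
+ *     ...
+ *     // Rank on the digit occupying bit positions [0, 4)
+ *     BlockRadixRank(temp_storage).RankKeys(thread_keys, thread_ranks, 0, 4);
+ * \endcode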
+ */ + template < + typename UnsignedBits, + int KEYS_PER_THREAD> + __device__ __forceinline__ void RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile + int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile + int current_bit, ///< [in] The least-significant bit position of the current digit to extract + int num_bits) ///< [in] The number of bits in the current digit + { + DigitCounter thread_prefixes[KEYS_PER_THREAD]; // For each key, the count of previous keys in this tile having the same digit + DigitCounter* digit_counters[KEYS_PER_THREAD]; // For each key, the byte-offset of its corresponding digit counter in smem + + // Reset shared memory digit counters + ResetCounters(); + + #pragma unroll + for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM) + { + // Get digit + unsigned int digit = BFE(keys[ITEM], current_bit, num_bits); + + // Get sub-counter + unsigned int sub_counter = digit >> LOG_COUNTER_LANES; + + // Get counter lane + unsigned int counter_lane = digit & (COUNTER_LANES - 1); + + if (IS_DESCENDING) + { + sub_counter = PACKING_RATIO - 1 - sub_counter; + counter_lane = COUNTER_LANES - 1 - counter_lane; + } + + // Pointer to smem digit counter + digit_counters[ITEM] = &temp_storage.aliasable.digit_counters[counter_lane][linear_tid][sub_counter]; + + // Load thread-exclusive prefix + thread_prefixes[ITEM] = *digit_counters[ITEM]; + + // Store inclusive prefix + *digit_counters[ITEM] = thread_prefixes[ITEM] + 1; + } + + CTA_SYNC(); + + // Scan shared memory counters + ScanCounters(); + + CTA_SYNC(); + + // Extract the local ranks of each key + for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM) + { + // Add in thread block exclusive prefix + ranks[ITEM] = thread_prefixes[ITEM] + *digit_counters[ITEM]; + } + } + + + /** + * \brief Rank keys. For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread. + */ + template < + typename UnsignedBits, + int KEYS_PER_THREAD> + __device__ __forceinline__ void RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile + int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile (out parameter) + int current_bit, ///< [in] The least-significant bit position of the current digit to extract + int num_bits, ///< [in] The number of bits in the current digit + int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1] + { + // Rank keys + RankKeys(keys, ranks, current_bit, num_bits); + + // Get the inclusive and exclusive digit totals corresponding to the calling thread. + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track; + + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + { + if (IS_DESCENDING) + bin_idx = RADIX_DIGITS - bin_idx - 1; + + // Obtain ex/inclusive digit counts. (Unfortunately these all reside in the + // first counter column, resulting in unavoidable bank conflicts.) 
+ unsigned int counter_lane = (bin_idx & (COUNTER_LANES - 1)); + unsigned int sub_counter = bin_idx >> (LOG_COUNTER_LANES); + + exclusive_digit_prefix[track] = temp_storage.aliasable.digit_counters[counter_lane][0][sub_counter]; + } + } + } +}; + + + + + +/** + * Radix-rank using match.any + */ +template < + int BLOCK_DIM_X, + int RADIX_BITS, + bool IS_DESCENDING, + BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = CUB_PTX_ARCH> +class BlockRadixRankMatch +{ +private: + + /****************************************************************************** + * Type definitions and constants + ******************************************************************************/ + + typedef int32_t RankT; + typedef int32_t DigitCounterT; + + enum + { + // The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + + RADIX_DIGITS = 1 << RADIX_BITS, + + LOG_WARP_THREADS = CUB_LOG_WARP_THREADS(PTX_ARCH), + WARP_THREADS = 1 << LOG_WARP_THREADS, + WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS, + + PADDED_WARPS = ((WARPS & 0x1) == 0) ? + WARPS + 1 : + WARPS, + + COUNTERS = PADDED_WARPS * RADIX_DIGITS, + RAKING_SEGMENT = (COUNTERS + BLOCK_THREADS - 1) / BLOCK_THREADS, + PADDED_RAKING_SEGMENT = ((RAKING_SEGMENT & 0x1) == 0) ? + RAKING_SEGMENT + 1 : + RAKING_SEGMENT, + }; + +public: + + enum + { + /// Number of bin-starting offsets tracked per thread + BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS), + }; + +private: + + /// BlockScan type + typedef BlockScan< + DigitCounterT, + BLOCK_THREADS, + INNER_SCAN_ALGORITHM, + BLOCK_DIM_Y, + BLOCK_DIM_Z, + PTX_ARCH> + BlockScanT; + + + /// Shared memory storage layout type for BlockRadixRank + struct __align__(16) _TempStorage + { + typename BlockScanT::TempStorage block_scan; + + union __align__(16) Aliasable + { + volatile DigitCounterT warp_digit_counters[RADIX_DIGITS][PADDED_WARPS]; + DigitCounterT raking_grid[BLOCK_THREADS][PADDED_RAKING_SEGMENT]; + + } aliasable; + }; + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + + + +public: + + /// \smemstorage{BlockScan} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. + */ + __device__ __forceinline__ BlockRadixRankMatch( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + //@} end member group + /******************************************************************//** + * \name Raking + *********************************************************************/ + //@{ + + /** + * \brief Rank keys. 
+ */ + template < + typename UnsignedBits, + int KEYS_PER_THREAD> + __device__ __forceinline__ void RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile + int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile + int current_bit, ///< [in] The least-significant bit position of the current digit to extract + int num_bits) ///< [in] The number of bits in the current digit + { + // Initialize shared digit counters + + #pragma unroll + for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM) + temp_storage.aliasable.raking_grid[linear_tid][ITEM] = 0; + + CTA_SYNC(); + + // Each warp will strip-mine its section of input, one strip at a time + + volatile DigitCounterT *digit_counters[KEYS_PER_THREAD]; + uint32_t warp_id = linear_tid >> LOG_WARP_THREADS; + uint32_t lane_mask_lt = LaneMaskLt(); + + #pragma unroll + for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM) + { + // My digit + uint32_t digit = BFE(keys[ITEM], current_bit, num_bits); + + if (IS_DESCENDING) + digit = RADIX_DIGITS - digit - 1; + + // Mask of peers who have same digit as me + uint32_t peer_mask = MatchAny(digit); + + // Pointer to smem digit counter for this key + digit_counters[ITEM] = &temp_storage.aliasable.warp_digit_counters[digit][warp_id]; + + // Number of occurrences in previous strips + DigitCounterT warp_digit_prefix = *digit_counters[ITEM]; + + // Warp-sync + WARP_SYNC(0xFFFFFFFF); + + // Number of peers having same digit as me + int32_t digit_count = __popc(peer_mask); + + // Number of lower-ranked peers having same digit seen so far + int32_t peer_digit_prefix = __popc(peer_mask & lane_mask_lt); + + if (peer_digit_prefix == 0) + { + // First thread for each digit updates the shared warp counter + *digit_counters[ITEM] = DigitCounterT(warp_digit_prefix + digit_count); + } + + // Warp-sync + WARP_SYNC(0xFFFFFFFF); + + // Number of prior keys having same digit + ranks[ITEM] = warp_digit_prefix + DigitCounterT(peer_digit_prefix); + } + + CTA_SYNC(); + + // Scan warp counters + + DigitCounterT scan_counters[PADDED_RAKING_SEGMENT]; + + #pragma unroll + for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM) + scan_counters[ITEM] = temp_storage.aliasable.raking_grid[linear_tid][ITEM]; + + BlockScanT(temp_storage.block_scan).ExclusiveSum(scan_counters, scan_counters); + + #pragma unroll + for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM) + temp_storage.aliasable.raking_grid[linear_tid][ITEM] = scan_counters[ITEM]; + + CTA_SYNC(); + + // Seed ranks with counter values from previous warps + #pragma unroll + for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM) + ranks[ITEM] += *digit_counters[ITEM]; + } + + + /** + * \brief Rank keys. For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread. + */ + template < + typename UnsignedBits, + int KEYS_PER_THREAD> + __device__ __forceinline__ void RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile + int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile (out parameter) + int current_bit, ///< [in] The least-significant bit position of the current digit to extract + int num_bits, ///< [in] The number of bits in the current digit + int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... 
(threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1] + { + RankKeys(keys, ranks, current_bit, num_bits); + + // Get exclusive count for each digit + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track; + + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + { + if (IS_DESCENDING) + bin_idx = RADIX_DIGITS - bin_idx - 1; + + exclusive_digit_prefix[track] = temp_storage.aliasable.warp_digit_counters[bin_idx][0]; + } + } + } +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/block/block_radix_sort.cuh b/fastertransformer/cuda/cub/block/block_radix_sort.cuh new file mode 100644 index 000000000..ac0c9f85b --- /dev/null +++ b/fastertransformer/cuda/cub/block/block_radix_sort.cuh @@ -0,0 +1,863 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * The cub::BlockRadixSort class provides [collective](index.html#sec0) methods for radix sorting of items partitioned across a CUDA thread block. + */ + + +#pragma once + +#include "block_exchange.cuh" +#include "block_radix_rank.cuh" +#include "../util_ptx.cuh" +#include "../util_arch.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief The BlockRadixSort class provides [collective](index.html#sec0) methods for sorting items partitioned across a CUDA thread block using a radix sorting method. 
![](sorting_logo.png) + * \ingroup BlockModule + * + * \tparam KeyT KeyT type + * \tparam BLOCK_DIM_X The thread block length in threads along the X dimension + * \tparam ITEMS_PER_THREAD The number of items per thread + * \tparam ValueT [optional] ValueT type (default: cub::NullType, which indicates a keys-only sort) + * \tparam RADIX_BITS [optional] The number of radix bits per digit place (default: 4 bits) + * \tparam MEMOIZE_OUTER_SCAN [optional] Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure (default: true for architectures SM35 and newer, false otherwise). + * \tparam INNER_SCAN_ALGORITHM [optional] The cub::BlockScanAlgorithm algorithm to use (default: cub::BLOCK_SCAN_WARP_SCANS) + * \tparam SMEM_CONFIG [optional] Shared memory bank mode (default: \p cudaSharedMemBankSizeFourByte) + * \tparam BLOCK_DIM_Y [optional] The thread block length in threads along the Y dimension (default: 1) + * \tparam BLOCK_DIM_Z [optional] The thread block length in threads along the Z dimension (default: 1) + * \tparam PTX_ARCH [optional] \ptxversion + * + * \par Overview + * - The [radix sorting method](http://en.wikipedia.org/wiki/Radix_sort) arranges + * items into ascending order. It relies upon a positional representation for + * keys, i.e., each key is comprised of an ordered sequence of symbols (e.g., digits, + * characters, etc.) specified from least-significant to most-significant. For a + * given input sequence of keys and a set of rules specifying a total ordering + * of the symbolic alphabet, the radix sorting method produces a lexicographic + * ordering of those keys. + * - BlockRadixSort can sort all of the built-in C++ numeric primitive types + * (unsigned char, \p int, \p double, etc.) as well as CUDA's \p __half + * half-precision floating-point type. Within each key, the implementation treats fixed-length + * bit-sequences of \p RADIX_BITS as radix digit places. Although the direct radix sorting + * method can only be applied to unsigned integral types, BlockRadixSort + * is able to sort signed and floating-point types via simple bit-wise transformations + * that ensure lexicographic key ordering. + * - \rowmajor + * + * \par Performance Considerations + * - \granularity + * + * \par A Simple Example + * \blockcollective{BlockRadixSort} + * \par + * The code snippet below illustrates a sort of 512 integer keys that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockRadixSort BlockRadixSort; + * + * // Allocate shared memory for BlockRadixSort + * __shared__ typename BlockRadixSort::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_keys[4]; + * ... + * + * // Collectively sort the keys + * BlockRadixSort(temp_storage).Sort(thread_keys); + * + * ... + * \endcode + * \par + * Suppose the set of input \p thread_keys across the block of threads is + * { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. The + * corresponding output \p thread_keys in those threads will be + * { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. 
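+ * \par
+ * (Editorial sketch, not upstream documentation; the identifiers below are
+ * illustrative.)  Values can be sorted along with the keys by naming a value
+ * type; for example, a key-value sort of int keys paired with float values
+ * might be specialized as:
+ * \par
+ * \code
+ *     typedef cub::BlockRadixSort<int, 128, 4, float> BlockRadixSort;
+ *     __shared__ typename BlockRadixSort::TempStorage temp_storage;
+ *
+ *     int   thread_keys[4];
+ *     float thread_values[4];
+ *     ...
+ *     BlockRadixSort(temp_storage).Sort(thread_keys, thread_values);
+ * \endcode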
+ * + */ +template < + typename KeyT, + int BLOCK_DIM_X, + int ITEMS_PER_THREAD, + typename ValueT = NullType, + int RADIX_BITS = 4, + bool MEMOIZE_OUTER_SCAN = (CUB_PTX_ARCH >= 350) ? true : false, + BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS, + cudaSharedMemConfig SMEM_CONFIG = cudaSharedMemBankSizeFourByte, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = CUB_PTX_ARCH> +class BlockRadixSort +{ +private: + + /****************************************************************************** + * Constants and type definitions + ******************************************************************************/ + + enum + { + // The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + + // Whether or not there are values to be trucked along with keys + KEYS_ONLY = Equals::VALUE, + }; + + // KeyT traits and unsigned bits type + typedef Traits KeyTraits; + typedef typename KeyTraits::UnsignedBits UnsignedBits; + + /// Ascending BlockRadixRank utility type + typedef BlockRadixRank< + BLOCK_DIM_X, + RADIX_BITS, + false, + MEMOIZE_OUTER_SCAN, + INNER_SCAN_ALGORITHM, + SMEM_CONFIG, + BLOCK_DIM_Y, + BLOCK_DIM_Z, + PTX_ARCH> + AscendingBlockRadixRank; + + /// Descending BlockRadixRank utility type + typedef BlockRadixRank< + BLOCK_DIM_X, + RADIX_BITS, + true, + MEMOIZE_OUTER_SCAN, + INNER_SCAN_ALGORITHM, + SMEM_CONFIG, + BLOCK_DIM_Y, + BLOCK_DIM_Z, + PTX_ARCH> + DescendingBlockRadixRank; + + /// BlockExchange utility type for keys + typedef BlockExchange BlockExchangeKeys; + + /// BlockExchange utility type for values + typedef BlockExchange BlockExchangeValues; + + /// Shared memory storage layout type + union _TempStorage + { + typename AscendingBlockRadixRank::TempStorage asending_ranking_storage; + typename DescendingBlockRadixRank::TempStorage descending_ranking_storage; + typename BlockExchangeKeys::TempStorage exchange_keys; + typename BlockExchangeValues::TempStorage exchange_values; + }; + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /// Internal storage allocator + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + /// Rank keys (specialized for ascending sort) + __device__ __forceinline__ void RankKeys( + UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD], + int (&ranks)[ITEMS_PER_THREAD], + int begin_bit, + int pass_bits, + Int2Type /*is_descending*/) + { + AscendingBlockRadixRank(temp_storage.asending_ranking_storage).RankKeys( + unsigned_keys, + ranks, + begin_bit, + pass_bits); + } + + /// Rank keys (specialized for descending sort) + __device__ __forceinline__ void RankKeys( + UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD], + int (&ranks)[ITEMS_PER_THREAD], + int begin_bit, + int pass_bits, + Int2Type /*is_descending*/) + { + DescendingBlockRadixRank(temp_storage.descending_ranking_storage).RankKeys( + unsigned_keys, + ranks, + begin_bit, + pass_bits); + } + + /// ExchangeValues (specialized for key-value sort, to-blocked arrangement) + __device__ __forceinline__ void ExchangeValues( + ValueT 
(&values)[ITEMS_PER_THREAD], + int (&ranks)[ITEMS_PER_THREAD], + Int2Type /*is_keys_only*/, + Int2Type /*is_blocked*/) + { + CTA_SYNC(); + + // Exchange values through shared memory in blocked arrangement + BlockExchangeValues(temp_storage.exchange_values).ScatterToBlocked(values, ranks); + } + + /// ExchangeValues (specialized for key-value sort, to-striped arrangement) + __device__ __forceinline__ void ExchangeValues( + ValueT (&values)[ITEMS_PER_THREAD], + int (&ranks)[ITEMS_PER_THREAD], + Int2Type /*is_keys_only*/, + Int2Type /*is_blocked*/) + { + CTA_SYNC(); + + // Exchange values through shared memory in blocked arrangement + BlockExchangeValues(temp_storage.exchange_values).ScatterToStriped(values, ranks); + } + + /// ExchangeValues (specialized for keys-only sort) + template + __device__ __forceinline__ void ExchangeValues( + ValueT (&/*values*/)[ITEMS_PER_THREAD], + int (&/*ranks*/)[ITEMS_PER_THREAD], + Int2Type /*is_keys_only*/, + Int2Type /*is_blocked*/) + {} + + /// Sort blocked arrangement + template + __device__ __forceinline__ void SortBlocked( + KeyT (&keys)[ITEMS_PER_THREAD], ///< Keys to sort + ValueT (&values)[ITEMS_PER_THREAD], ///< Values to sort + int begin_bit, ///< The beginning (least-significant) bit index needed for key comparison + int end_bit, ///< The past-the-end (most-significant) bit index needed for key comparison + Int2Type is_descending, ///< Tag whether is a descending-order sort + Int2Type is_keys_only) ///< Tag whether is keys-only sort + { + UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD] = + reinterpret_cast(keys); + + // Twiddle bits if necessary + #pragma unroll + for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) + { + unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]); + } + + // Radix sorting passes + while (true) + { + int pass_bits = CUB_MIN(RADIX_BITS, end_bit - begin_bit); + + // Rank the blocked keys + int ranks[ITEMS_PER_THREAD]; + RankKeys(unsigned_keys, ranks, begin_bit, pass_bits, is_descending); + begin_bit += RADIX_BITS; + + CTA_SYNC(); + + // Exchange keys through shared memory in blocked arrangement + BlockExchangeKeys(temp_storage.exchange_keys).ScatterToBlocked(keys, ranks); + + // Exchange values through shared memory in blocked arrangement + ExchangeValues(values, ranks, is_keys_only, Int2Type()); + + // Quit if done + if (begin_bit >= end_bit) break; + + CTA_SYNC(); + } + + // Untwiddle bits if necessary + #pragma unroll + for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) + { + unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]); + } + } + +public: + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + /// Sort blocked -> striped arrangement + template + __device__ __forceinline__ void SortBlockedToStriped( + KeyT (&keys)[ITEMS_PER_THREAD], ///< Keys to sort + ValueT (&values)[ITEMS_PER_THREAD], ///< Values to sort + int begin_bit, ///< The beginning (least-significant) bit index needed for key comparison + int end_bit, ///< The past-the-end (most-significant) bit index needed for key comparison + Int2Type is_descending, ///< Tag whether is a descending-order sort + Int2Type is_keys_only) ///< Tag whether is keys-only sort + { + UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD] = + reinterpret_cast(keys); + + // Twiddle bits if necessary + #pragma unroll + for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) + { + unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]); + } + + // Radix sorting passes + while (true) + { + int pass_bits = CUB_MIN(RADIX_BITS, end_bit - begin_bit); + + // Rank the blocked 
keys + int ranks[ITEMS_PER_THREAD]; + RankKeys(unsigned_keys, ranks, begin_bit, pass_bits, is_descending); + begin_bit += RADIX_BITS; + + CTA_SYNC(); + + // Check if this is the last pass + if (begin_bit >= end_bit) + { + // Last pass exchanges keys through shared memory in striped arrangement + BlockExchangeKeys(temp_storage.exchange_keys).ScatterToStriped(keys, ranks); + + // Last pass exchanges through shared memory in striped arrangement + ExchangeValues(values, ranks, is_keys_only, Int2Type()); + + // Quit + break; + } + + // Exchange keys through shared memory in blocked arrangement + BlockExchangeKeys(temp_storage.exchange_keys).ScatterToBlocked(keys, ranks); + + // Exchange values through shared memory in blocked arrangement + ExchangeValues(values, ranks, is_keys_only, Int2Type()); + + CTA_SYNC(); + } + + // Untwiddle bits if necessary + #pragma unroll + for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) + { + unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]); + } + } + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + /// \smemstorage{BlockRadixSort} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using a private static allocation of shared memory as temporary storage. + */ + __device__ __forceinline__ BlockRadixSort() + : + temp_storage(PrivateStorage()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. + */ + __device__ __forceinline__ BlockRadixSort( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + //@} end member group + /******************************************************************//** + * \name Sorting (blocked arrangements) + *********************************************************************/ + //@{ + + /** + * \brief Performs an ascending block-wide radix sort over a [blocked arrangement](index.html#sec5sec3) of keys. + * + * \par + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a sort of 512 integer keys that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive keys. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys each + * typedef cub::BlockRadixSort BlockRadixSort; + * + * // Allocate shared memory for BlockRadixSort + * __shared__ typename BlockRadixSort::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_keys[4]; + * ... + * + * // Collectively sort the keys + * BlockRadixSort(temp_storage).Sort(thread_keys); + * + * \endcode + * \par + * Suppose the set of input \p thread_keys across the block of threads is + * { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. + * The corresponding output \p thread_keys in those threads will be + * { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. 
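Editor's note: the two collective constructors declared above differ only in where the temporary storage lives. The default constructor allocates a private static shared-memory region internally, while the TempStorage overload lets the caller place the storage explicitly (and reuse it for later collectives after a barrier). A minimal sketch, assuming the same <int, 128, 4> specialization as in the examples; loads are elided for brevity.

```cuda
// Sketch contrasting the two collective constructors (assumed <int, 128, 4> specialization).
#include <cub/cub.cuh>   // or equivalently <cub/block/block_radix_sort.cuh>

__global__ void PrivateStorageKernel(const int *d_in, int *d_out)
{
    int thread_keys[4];
    for (int i = 0; i < 4; ++i)
        thread_keys[i] = d_in[blockIdx.x * 512 + threadIdx.x * 4 + i];

    // Default constructor: temporary storage is a private static shared allocation
    cub::BlockRadixSort<int, 128, 4>().Sort(thread_keys);

    for (int i = 0; i < 4; ++i)
        d_out[blockIdx.x * 512 + threadIdx.x * 4 + i] = thread_keys[i];
}

__global__ void ExplicitStorageKernel(const int *d_in, int *d_out)
{
    typedef cub::BlockRadixSort<int, 128, 4> BlockRadixSort;

    // Caller-provided storage; it may be reused by later collectives
    // once a __syncthreads() barrier has been issued
    __shared__ typename BlockRadixSort::TempStorage temp_storage;

    int thread_keys[4];
    for (int i = 0; i < 4; ++i)
        thread_keys[i] = d_in[blockIdx.x * 512 + threadIdx.x * 4 + i];

    BlockRadixSort(temp_storage).Sort(thread_keys);

    for (int i = 0; i < 4; ++i)
        d_out[blockIdx.x * 512 + threadIdx.x * 4 + i] = thread_keys[i];
}
```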
+ */ + __device__ __forceinline__ void Sort( + KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + int begin_bit = 0, ///< [in] [optional] The beginning (least-significant) bit index needed for key comparison + int end_bit = sizeof(KeyT) * 8) ///< [in] [optional] The past-the-end (most-significant) bit index needed for key comparison + { + NullType values[ITEMS_PER_THREAD]; + + SortBlocked(keys, values, begin_bit, end_bit, Int2Type(), Int2Type()); + } + + + /** + * \brief Performs an ascending block-wide radix sort across a [blocked arrangement](index.html#sec5sec3) of keys and values. + * + * \par + * - BlockRadixSort can only accommodate one associated tile of values. To "truck along" + * more than one tile of values, simply perform a key-value sort of the keys paired + * with a temporary value array that enumerates the key indices. The reordered indices + * can then be used as a gather-vector for exchanging other associated tile data through + * shared memory. + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a sort of 512 integer keys and values that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive pairs. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys and values each + * typedef cub::BlockRadixSort BlockRadixSort; + * + * // Allocate shared memory for BlockRadixSort + * __shared__ typename BlockRadixSort::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_keys[4]; + * int thread_values[4]; + * ... + * + * // Collectively sort the keys and values among block threads + * BlockRadixSort(temp_storage).Sort(thread_keys, thread_values); + * + * \endcode + * \par + * Suppose the set of input \p thread_keys across the block of threads is + * { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. The + * corresponding output \p thread_keys in those threads will be + * { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. + * + */ + __device__ __forceinline__ void Sort( + KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + ValueT (&values)[ITEMS_PER_THREAD], ///< [in-out] Values to sort + int begin_bit = 0, ///< [in] [optional] The beginning (least-significant) bit index needed for key comparison + int end_bit = sizeof(KeyT) * 8) ///< [in] [optional] The past-the-end (most-significant) bit index needed for key comparison + { + SortBlocked(keys, values, begin_bit, end_bit, Int2Type(), Int2Type()); + } + + /** + * \brief Performs a descending block-wide radix sort over a [blocked arrangement](index.html#sec5sec3) of keys. + * + * \par + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a sort of 512 integer keys that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive keys. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) 
+ * { + * // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys each + * typedef cub::BlockRadixSort BlockRadixSort; + * + * // Allocate shared memory for BlockRadixSort + * __shared__ typename BlockRadixSort::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_keys[4]; + * ... + * + * // Collectively sort the keys + * BlockRadixSort(temp_storage).Sort(thread_keys); + * + * \endcode + * \par + * Suppose the set of input \p thread_keys across the block of threads is + * { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. + * The corresponding output \p thread_keys in those threads will be + * { [511,510,509,508], [11,10,9,8], [7,6,5,4], ..., [3,2,1,0] }. + */ + __device__ __forceinline__ void SortDescending( + KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + int begin_bit = 0, ///< [in] [optional] The beginning (least-significant) bit index needed for key comparison + int end_bit = sizeof(KeyT) * 8) ///< [in] [optional] The past-the-end (most-significant) bit index needed for key comparison + { + NullType values[ITEMS_PER_THREAD]; + + SortBlocked(keys, values, begin_bit, end_bit, Int2Type(), Int2Type()); + } + + + /** + * \brief Performs a descending block-wide radix sort across a [blocked arrangement](index.html#sec5sec3) of keys and values. + * + * \par + * - BlockRadixSort can only accommodate one associated tile of values. To "truck along" + * more than one tile of values, simply perform a key-value sort of the keys paired + * with a temporary value array that enumerates the key indices. The reordered indices + * can then be used as a gather-vector for exchanging other associated tile data through + * shared memory. + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a sort of 512 integer keys and values that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive pairs. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys and values each + * typedef cub::BlockRadixSort BlockRadixSort; + * + * // Allocate shared memory for BlockRadixSort + * __shared__ typename BlockRadixSort::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_keys[4]; + * int thread_values[4]; + * ... + * + * // Collectively sort the keys and values among block threads + * BlockRadixSort(temp_storage).Sort(thread_keys, thread_values); + * + * \endcode + * \par + * Suppose the set of input \p thread_keys across the block of threads is + * { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. The + * corresponding output \p thread_keys in those threads will be + * { [511,510,509,508], [11,10,9,8], [7,6,5,4], ..., [3,2,1,0] }. 
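Editor's note: the advice above about trucking along more than one tile of values suggests pairing the keys with an enumeration of their indices and then reusing the reordered indices as a gather vector. The following is a hypothetical sketch of that pattern with a descending key-value sort; the <int, 128, 4, int> specialization, kernel name, and indexing are assumptions.

```cuda
// Hypothetical sketch: sort keys descending, carrying their original indices as values.
#include <cub/cub.cuh>   // or equivalently <cub/block/block_radix_sort.cuh>

__global__ void BlockSortPairsDescending(const int *d_keys_in, int *d_keys_out, int *d_gather_idx)
{
    typedef cub::BlockRadixSort<int, 128, 4, int> BlockRadixSort;
    __shared__ typename BlockRadixSort::TempStorage temp_storage;

    int thread_keys[4];
    int thread_values[4];
    for (int i = 0; i < 4; ++i)
    {
        int idx = threadIdx.x * 4 + i;
        thread_keys[i]   = d_keys_in[blockIdx.x * 512 + idx];
        thread_values[i] = idx;                        // enumerate the key indices
    }

    // Collectively sort the keys (and the index payload) in descending key order
    BlockRadixSort(temp_storage).SortDescending(thread_keys, thread_values);

    for (int i = 0; i < 4; ++i)
    {
        int idx = threadIdx.x * 4 + i;
        d_keys_out[blockIdx.x * 512 + idx]   = thread_keys[i];
        d_gather_idx[blockIdx.x * 512 + idx] = thread_values[i];   // gather vector for other tile data
    }
}
```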
+ * + */ + __device__ __forceinline__ void SortDescending( + KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + ValueT (&values)[ITEMS_PER_THREAD], ///< [in-out] Values to sort + int begin_bit = 0, ///< [in] [optional] The beginning (least-significant) bit index needed for key comparison + int end_bit = sizeof(KeyT) * 8) ///< [in] [optional] The past-the-end (most-significant) bit index needed for key comparison + { + SortBlocked(keys, values, begin_bit, end_bit, Int2Type(), Int2Type()); + } + + + //@} end member group + /******************************************************************//** + * \name Sorting (blocked arrangement -> striped arrangement) + *********************************************************************/ + //@{ + + + /** + * \brief Performs an ascending radix sort across a [blocked arrangement](index.html#sec5sec3) of keys, leaving them in a [striped arrangement](index.html#sec5sec3). + * + * \par + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a sort of 512 integer keys that + * are initially partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive keys. The final partitioning is striped. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys each + * typedef cub::BlockRadixSort BlockRadixSort; + * + * // Allocate shared memory for BlockRadixSort + * __shared__ typename BlockRadixSort::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_keys[4]; + * ... + * + * // Collectively sort the keys + * BlockRadixSort(temp_storage).SortBlockedToStriped(thread_keys); + * + * \endcode + * \par + * Suppose the set of input \p thread_keys across the block of threads is + * { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. The + * corresponding output \p thread_keys in those threads will be + * { [0,128,256,384], [1,129,257,385], [2,130,258,386], ..., [127,255,383,511] }. + * + */ + __device__ __forceinline__ void SortBlockedToStriped( + KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + int begin_bit = 0, ///< [in] [optional] The beginning (least-significant) bit index needed for key comparison + int end_bit = sizeof(KeyT) * 8) ///< [in] [optional] The past-the-end (most-significant) bit index needed for key comparison + { + NullType values[ITEMS_PER_THREAD]; + + SortBlockedToStriped(keys, values, begin_bit, end_bit, Int2Type(), Int2Type()); + } + + + /** + * \brief Performs an ascending radix sort across a [blocked arrangement](index.html#sec5sec3) of keys and values, leaving them in a [striped arrangement](index.html#sec5sec3). + * + * \par + * - BlockRadixSort can only accommodate one associated tile of values. To "truck along" + * more than one tile of values, simply perform a key-value sort of the keys paired + * with a temporary value array that enumerates the key indices. The reordered indices + * can then be used as a gather-vector for exchanging other associated tile data through + * shared memory. + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a sort of 512 integer keys and values that + * are initially partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive pairs. The final partitioning is striped. 
+ * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys and values each + * typedef cub::BlockRadixSort BlockRadixSort; + * + * // Allocate shared memory for BlockRadixSort + * __shared__ typename BlockRadixSort::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_keys[4]; + * int thread_values[4]; + * ... + * + * // Collectively sort the keys and values among block threads + * BlockRadixSort(temp_storage).SortBlockedToStriped(thread_keys, thread_values); + * + * \endcode + * \par + * Suppose the set of input \p thread_keys across the block of threads is + * { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. The + * corresponding output \p thread_keys in those threads will be + * { [0,128,256,384], [1,129,257,385], [2,130,258,386], ..., [127,255,383,511] }. + * + */ + __device__ __forceinline__ void SortBlockedToStriped( + KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + ValueT (&values)[ITEMS_PER_THREAD], ///< [in-out] Values to sort + int begin_bit = 0, ///< [in] [optional] The beginning (least-significant) bit index needed for key comparison + int end_bit = sizeof(KeyT) * 8) ///< [in] [optional] The past-the-end (most-significant) bit index needed for key comparison + { + SortBlockedToStriped(keys, values, begin_bit, end_bit, Int2Type(), Int2Type()); + } + + + /** + * \brief Performs a descending radix sort across a [blocked arrangement](index.html#sec5sec3) of keys, leaving them in a [striped arrangement](index.html#sec5sec3). + * + * \par + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a sort of 512 integer keys that + * are initially partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive keys. The final partitioning is striped. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys each + * typedef cub::BlockRadixSort BlockRadixSort; + * + * // Allocate shared memory for BlockRadixSort + * __shared__ typename BlockRadixSort::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_keys[4]; + * ... + * + * // Collectively sort the keys + * BlockRadixSort(temp_storage).SortBlockedToStriped(thread_keys); + * + * \endcode + * \par + * Suppose the set of input \p thread_keys across the block of threads is + * { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. The + * corresponding output \p thread_keys in those threads will be + * { [511,383,255,127], [386,258,130,2], [385,257,128,1], ..., [384,256,128,0] }. 
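Editor's note: a striped result is typically chosen so that the subsequent global-memory stores are coalesced; after SortBlockedToStriped(), item i of each thread belongs at offset i * BLOCK_THREADS + threadIdx.x of the sorted tile. A sketch under the same assumed <int, 128, 4> configuration.

```cuda
// Sketch: blocked load, collective sort to a striped arrangement, coalesced striped stores.
#include <cub/cub.cuh>   // or equivalently <cub/block/block_radix_sort.cuh>

__global__ void BlockSortToStripedKernel(const int *d_in, int *d_out)
{
    typedef cub::BlockRadixSort<int, 128, 4> BlockRadixSort;
    __shared__ typename BlockRadixSort::TempStorage temp_storage;

    // Load a blocked arrangement (4 consecutive keys per thread)
    int thread_keys[4];
    for (int i = 0; i < 4; ++i)
        thread_keys[i] = d_in[blockIdx.x * 512 + threadIdx.x * 4 + i];

    // Collectively sort, leaving the result in a striped arrangement
    BlockRadixSort(temp_storage).SortBlockedToStriped(thread_keys);

    // Striped stores: consecutive threads write consecutive addresses
    for (int i = 0; i < 4; ++i)
        d_out[blockIdx.x * 512 + i * 128 + threadIdx.x] = thread_keys[i];
}
```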
+ * + */ + __device__ __forceinline__ void SortDescendingBlockedToStriped( + KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + int begin_bit = 0, ///< [in] [optional] The beginning (least-significant) bit index needed for key comparison + int end_bit = sizeof(KeyT) * 8) ///< [in] [optional] The past-the-end (most-significant) bit index needed for key comparison + { + NullType values[ITEMS_PER_THREAD]; + + SortBlockedToStriped(keys, values, begin_bit, end_bit, Int2Type(), Int2Type()); + } + + + /** + * \brief Performs a descending radix sort across a [blocked arrangement](index.html#sec5sec3) of keys and values, leaving them in a [striped arrangement](index.html#sec5sec3). + * + * \par + * - BlockRadixSort can only accommodate one associated tile of values. To "truck along" + * more than one tile of values, simply perform a key-value sort of the keys paired + * with a temporary value array that enumerates the key indices. The reordered indices + * can then be used as a gather-vector for exchanging other associated tile data through + * shared memory. + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a sort of 512 integer keys and values that + * are initially partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive pairs. The final partitioning is striped. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys and values each + * typedef cub::BlockRadixSort BlockRadixSort; + * + * // Allocate shared memory for BlockRadixSort + * __shared__ typename BlockRadixSort::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_keys[4]; + * int thread_values[4]; + * ... + * + * // Collectively sort the keys and values among block threads + * BlockRadixSort(temp_storage).SortBlockedToStriped(thread_keys, thread_values); + * + * \endcode + * \par + * Suppose the set of input \p thread_keys across the block of threads is + * { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. The + * corresponding output \p thread_keys in those threads will be + * { [511,383,255,127], [386,258,130,2], [385,257,128,1], ..., [384,256,128,0] }. + * + */ + __device__ __forceinline__ void SortDescendingBlockedToStriped( + KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + ValueT (&values)[ITEMS_PER_THREAD], ///< [in-out] Values to sort + int begin_bit = 0, ///< [in] [optional] The beginning (least-significant) bit index needed for key comparison + int end_bit = sizeof(KeyT) * 8) ///< [in] [optional] The past-the-end (most-significant) bit index needed for key comparison + { + SortBlockedToStriped(keys, values, begin_bit, end_bit, Int2Type(), Int2Type()); + } + + + //@} end member group + +}; + +/** + * \example example_block_radix_sort.cu + */ + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/block_raking_layout.cuh b/fastertransformer/cuda/cub/block/block_raking_layout.cuh new file mode 100644 index 000000000..350061686 --- /dev/null +++ b/fastertransformer/cuda/cub/block/block_raking_layout.cuh @@ -0,0 +1,152 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. 
All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::BlockRakingLayout provides a conflict-free shared memory layout abstraction for warp-raking across thread block data. + */ + + +#pragma once + +#include "../util_macro.cuh" +#include "../util_arch.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief BlockRakingLayout provides a conflict-free shared memory layout abstraction for 1D raking across thread block data. ![](raking.png) + * \ingroup BlockModule + * + * \par Overview + * This type facilitates a shared memory usage pattern where a block of CUDA + * threads places elements into shared memory and then reduces the active + * parallelism to one "raking" warp of threads for serially aggregating consecutive + * sequences of shared items. Padding is inserted to eliminate bank conflicts + * (for most data types). + * + * \tparam T The data type to be exchanged. + * \tparam BLOCK_THREADS The thread block size in threads. 
+ * \tparam PTX_ARCH [optional] \ptxversion + */ +template < + typename T, + int BLOCK_THREADS, + int PTX_ARCH = CUB_PTX_ARCH> +struct BlockRakingLayout +{ + //--------------------------------------------------------------------- + // Constants and type definitions + //--------------------------------------------------------------------- + + enum + { + /// The total number of elements that need to be cooperatively reduced + SHARED_ELEMENTS = BLOCK_THREADS, + + /// Maximum number of warp-synchronous raking threads + MAX_RAKING_THREADS = CUB_MIN(BLOCK_THREADS, CUB_WARP_THREADS(PTX_ARCH)), + + /// Number of raking elements per warp-synchronous raking thread (rounded up) + SEGMENT_LENGTH = (SHARED_ELEMENTS + MAX_RAKING_THREADS - 1) / MAX_RAKING_THREADS, + + /// Never use a raking thread that will have no valid data (e.g., when BLOCK_THREADS is 62 and SEGMENT_LENGTH is 2, we should only use 31 raking threads) + RAKING_THREADS = (SHARED_ELEMENTS + SEGMENT_LENGTH - 1) / SEGMENT_LENGTH, + + /// Whether we will have bank conflicts (technically we should find out if the GCD is > 1) + HAS_CONFLICTS = (CUB_SMEM_BANKS(PTX_ARCH) % SEGMENT_LENGTH == 0), + + /// Degree of bank conflicts (e.g., 4-way) + CONFLICT_DEGREE = (HAS_CONFLICTS) ? + (MAX_RAKING_THREADS * SEGMENT_LENGTH) / CUB_SMEM_BANKS(PTX_ARCH) : + 1, + + /// Pad each segment length with one element if segment length is not relatively prime to warp size and can't be optimized as a vector load + USE_SEGMENT_PADDING = ((SEGMENT_LENGTH & 1) == 0) && (SEGMENT_LENGTH > 2), + + /// Total number of elements in the raking grid + GRID_ELEMENTS = RAKING_THREADS * (SEGMENT_LENGTH + USE_SEGMENT_PADDING), + + /// Whether or not we need bounds checking during raking (the number of reduction elements is not a multiple of the number of raking threads) + UNGUARDED = (SHARED_ELEMENTS % RAKING_THREADS == 0), + }; + + + /** + * \brief Shared memory storage type + */ + struct __align__(16) _TempStorage + { + T buff[BlockRakingLayout::GRID_ELEMENTS]; + }; + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /** + * \brief Returns the location for the calling thread to place data into the grid + */ + static __device__ __forceinline__ T* PlacementPtr( + TempStorage &temp_storage, + unsigned int linear_tid) + { + // Offset for partial + unsigned int offset = linear_tid; + + // Add in one padding element for every segment + if (USE_SEGMENT_PADDING > 0) + { + offset += offset / SEGMENT_LENGTH; + } + + // Incorporating a block of padding partials every shared memory segment + return temp_storage.Alias().buff + offset; + } + + + /** + * \brief Returns the location for the calling thread to begin sequential raking + */ + static __device__ __forceinline__ T* RakingPtr( + TempStorage &temp_storage, + unsigned int linear_tid) + { + return temp_storage.Alias().buff + (linear_tid * (SEGMENT_LENGTH + USE_SEGMENT_PADDING)); + } +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/block_reduce.cuh b/fastertransformer/cuda/cub/block/block_reduce.cuh new file mode 100644 index 000000000..261f2ea6f --- /dev/null +++ b/fastertransformer/cuda/cub/block/block_reduce.cuh @@ -0,0 +1,607 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * The cub::BlockReduce class provides [collective](index.html#sec0) methods for computing a parallel reduction of items partitioned across a CUDA thread block. + */ + +#pragma once + +#include "specializations/block_reduce_raking.cuh" +#include "specializations/block_reduce_raking_commutative_only.cuh" +#include "specializations/block_reduce_warp_reductions.cuh" +#include "../util_ptx.cuh" +#include "../util_type.cuh" +#include "../thread/thread_operators.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + + +/****************************************************************************** + * Algorithmic variants + ******************************************************************************/ + +/** + * BlockReduceAlgorithm enumerates alternative algorithms for parallel + * reduction across a CUDA thread block. + */ +enum BlockReduceAlgorithm +{ + + /** + * \par Overview + * An efficient "raking" reduction algorithm that only supports commutative + * reduction operators (true for most operations, e.g., addition). + * + * \par + * Execution is comprised of three phases: + * -# Upsweep sequential reduction in registers (if threads contribute more + * than one input each). Threads in warps other than the first warp place + * their partial reductions into shared memory. + * -# Upsweep sequential reduction in shared memory. Threads within the first + * warp continue to accumulate by raking across segments of shared partial reductions + * -# A warp-synchronous Kogge-Stone style reduction within the raking warp. + * + * \par + * \image html block_reduce.png + *
\p BLOCK_REDUCE_RAKING data flow for a hypothetical 16-thread thread block and 4-thread raking warp.
+ * + * \par Performance Considerations + * - This variant performs less communication than BLOCK_REDUCE_RAKING_NON_COMMUTATIVE + * and is preferable when the reduction operator is commutative. This variant + * applies fewer reduction operators than BLOCK_REDUCE_WARP_REDUCTIONS, and can provide higher overall + * throughput across the GPU when suitably occupied. However, turn-around latency may be + * higher than to BLOCK_REDUCE_WARP_REDUCTIONS and thus less-desirable + * when the GPU is under-occupied. + */ + BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY, + + + /** + * \par Overview + * An efficient "raking" reduction algorithm that supports commutative + * (e.g., addition) and non-commutative (e.g., string concatenation) reduction + * operators. \blocked. + * + * \par + * Execution is comprised of three phases: + * -# Upsweep sequential reduction in registers (if threads contribute more + * than one input each). Each thread then places the partial reduction + * of its item(s) into shared memory. + * -# Upsweep sequential reduction in shared memory. Threads within a + * single warp rake across segments of shared partial reductions. + * -# A warp-synchronous Kogge-Stone style reduction within the raking warp. + * + * \par + * \image html block_reduce.png + *
\p BLOCK_REDUCE_RAKING data flow for a hypothetical 16-thread thread block and 4-thread raking warp.
+ * + * \par Performance Considerations + * - This variant performs more communication than BLOCK_REDUCE_RAKING + * and is only preferable when the reduction operator is non-commutative. This variant + * applies fewer reduction operators than BLOCK_REDUCE_WARP_REDUCTIONS, and can provide higher overall + * throughput across the GPU when suitably occupied. However, turn-around latency may be + * higher than to BLOCK_REDUCE_WARP_REDUCTIONS and thus less-desirable + * when the GPU is under-occupied. + */ + BLOCK_REDUCE_RAKING, + + + /** + * \par Overview + * A quick "tiled warp-reductions" reduction algorithm that supports commutative + * (e.g., addition) and non-commutative (e.g., string concatenation) reduction + * operators. + * + * \par + * Execution is comprised of four phases: + * -# Upsweep sequential reduction in registers (if threads contribute more + * than one input each). Each thread then places the partial reduction + * of its item(s) into shared memory. + * -# Compute a shallow, but inefficient warp-synchronous Kogge-Stone style + * reduction within each warp. + * -# A propagation phase where the warp reduction outputs in each warp are + * updated with the aggregate from each preceding warp. + * + * \par + * \image html block_scan_warpscans.png + *
\p BLOCK_REDUCE_WARP_REDUCTIONS data flow for a hypothetical 16-thread thread block and 4-thread raking warp.
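Editor's note: the enumerators described above are selected through BlockReduce's third template parameter (cub::BLOCK_REDUCE_WARP_REDUCTIONS is the default); all three variants expose the same Reduce()/Sum() interface. A minimal illustration, with the <int, 128> parameters assumed:

```cuda
// Selecting the reduction algorithm via BlockReduce's third template parameter.
#include <cub/cub.cuh>   // or equivalently <cub/block/block_reduce.cuh>

typedef cub::BlockReduce<int, 128>                                            DefaultReduce128;      // warp reductions (default)
typedef cub::BlockReduce<int, 128, cub::BLOCK_REDUCE_RAKING>                  RakingReduce128;       // commutative or not
typedef cub::BlockReduce<int, 128, cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY> CommutativeReduce128;  // e.g., sums
```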
+ * + * \par Performance Considerations + * - This variant applies more reduction operators than BLOCK_REDUCE_RAKING + * or BLOCK_REDUCE_RAKING_NON_COMMUTATIVE, which may result in lower overall + * throughput across the GPU. However turn-around latency may be lower and + * thus useful when the GPU is under-occupied. + */ + BLOCK_REDUCE_WARP_REDUCTIONS, +}; + + +/****************************************************************************** + * Block reduce + ******************************************************************************/ + +/** + * \brief The BlockReduce class provides [collective](index.html#sec0) methods for computing a parallel reduction of items partitioned across a CUDA thread block. ![](reduce_logo.png) + * \ingroup BlockModule + * + * \tparam T Data type being reduced + * \tparam BLOCK_DIM_X The thread block length in threads along the X dimension + * \tparam ALGORITHM [optional] cub::BlockReduceAlgorithm enumerator specifying the underlying algorithm to use (default: cub::BLOCK_REDUCE_WARP_REDUCTIONS) + * \tparam BLOCK_DIM_Y [optional] The thread block length in threads along the Y dimension (default: 1) + * \tparam BLOCK_DIM_Z [optional] The thread block length in threads along the Z dimension (default: 1) + * \tparam PTX_ARCH [optional] \ptxversion + * + * \par Overview + * - A reduction (or fold) + * uses a binary combining operator to compute a single aggregate from a list of input elements. + * - \rowmajor + * - BlockReduce can be optionally specialized by algorithm to accommodate different latency/throughput workload profiles: + * -# cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY. An efficient "raking" reduction algorithm that only supports commutative reduction operators. [More...](\ref cub::BlockReduceAlgorithm) + * -# cub::BLOCK_REDUCE_RAKING. An efficient "raking" reduction algorithm that supports commutative and non-commutative reduction operators. [More...](\ref cub::BlockReduceAlgorithm) + * -# cub::BLOCK_REDUCE_WARP_REDUCTIONS. A quick "tiled warp-reductions" reduction algorithm that supports commutative and non-commutative reduction operators. [More...](\ref cub::BlockReduceAlgorithm) + * + * \par Performance Considerations + * - \granularity + * - Very efficient (only one synchronization barrier). + * - Incurs zero bank conflicts for most types + * - Computation is slightly more efficient (i.e., having lower instruction overhead) for: + * - Summation (vs. generic reduction) + * - \p BLOCK_THREADS is a multiple of the architecture's warp size + * - Every thread has a valid input (i.e., full vs. partial-tiles) + * - See cub::BlockReduceAlgorithm for performance details regarding algorithmic alternatives + * + * \par A Simple Example + * \blockcollective{BlockReduce} + * \par + * The code snippet below illustrates a sum reduction of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockReduce for a 1D block of 128 threads on type int + * typedef cub::BlockReduce BlockReduce; + * + * // Allocate shared memory for BlockReduce + * __shared__ typename BlockReduce::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... 
+ * + * // Compute the block-wide sum for thread0 + * int aggregate = BlockReduce(temp_storage).Sum(thread_data); + * + * \endcode + * + */ +template < + typename T, + int BLOCK_DIM_X, + BlockReduceAlgorithm ALGORITHM = BLOCK_REDUCE_WARP_REDUCTIONS, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = CUB_PTX_ARCH> +class BlockReduce +{ +private: + + /****************************************************************************** + * Constants and type definitions + ******************************************************************************/ + + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + }; + + typedef BlockReduceWarpReductions WarpReductions; + typedef BlockReduceRakingCommutativeOnly RakingCommutativeOnly; + typedef BlockReduceRaking Raking; + + /// Internal specialization type + typedef typename If<(ALGORITHM == BLOCK_REDUCE_WARP_REDUCTIONS), + WarpReductions, + typename If<(ALGORITHM == BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY), + RakingCommutativeOnly, + Raking>::Type>::Type InternalBlockReduce; // BlockReduceRaking + + /// Shared memory storage layout type for BlockReduce + typedef typename InternalBlockReduce::TempStorage _TempStorage; + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /// Internal storage allocator + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + + +public: + + /// \smemstorage{BlockReduce} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using a private static allocation of shared memory as temporary storage. + */ + __device__ __forceinline__ BlockReduce() + : + temp_storage(PrivateStorage()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. + */ + __device__ __forceinline__ BlockReduce( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + //@} end member group + /******************************************************************//** + * \name Generic reductions + *********************************************************************/ + //@{ + + + /** + * \brief Computes a block-wide reduction for thread0 using the specified binary reduction functor. Each thread contributes one input element. + * + * \par + * - The return value is undefined in threads other than thread0. + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a max reduction of 128 integer items that + * are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) 
+ * { + * // Specialize BlockReduce for a 1D block of 128 threads on type int + * typedef cub::BlockReduce BlockReduce; + * + * // Allocate shared memory for BlockReduce + * __shared__ typename BlockReduce::TempStorage temp_storage; + * + * // Each thread obtains an input item + * int thread_data; + * ... + * + * // Compute the block-wide max for thread0 + * int aggregate = BlockReduce(temp_storage).Reduce(thread_data, cub::Max()); + * + * \endcode + * + * \tparam ReductionOp [inferred] Binary reduction functor type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ T Reduce( + T input, ///< [in] Calling thread's input + ReductionOp reduction_op) ///< [in] Binary reduction functor + { + return InternalBlockReduce(temp_storage).template Reduce(input, BLOCK_THREADS, reduction_op); + } + + + /** + * \brief Computes a block-wide reduction for thread0 using the specified binary reduction functor. Each thread contributes an array of consecutive input elements. + * + * \par + * - The return value is undefined in threads other than thread0. + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a max reduction of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockReduce for a 1D block of 128 threads on type int + * typedef cub::BlockReduce BlockReduce; + * + * // Allocate shared memory for BlockReduce + * __shared__ typename BlockReduce::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Compute the block-wide max for thread0 + * int aggregate = BlockReduce(temp_storage).Reduce(thread_data, cub::Max()); + * + * \endcode + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam ReductionOp [inferred] Binary reduction functor type having member T operator()(const T &a, const T &b) + */ + template < + int ITEMS_PER_THREAD, + typename ReductionOp> + __device__ __forceinline__ T Reduce( + T (&inputs)[ITEMS_PER_THREAD], ///< [in] Calling thread's input segment + ReductionOp reduction_op) ///< [in] Binary reduction functor + { + // Reduce partials + T partial = internal::ThreadReduce(inputs, reduction_op); + return Reduce(partial, reduction_op); + } + + + /** + * \brief Computes a block-wide reduction for thread0 using the specified binary reduction functor. The first \p num_valid threads each contribute one input element. + * + * \par + * - The return value is undefined in threads other than thread0. + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a max reduction of a partially-full tile of integer items that + * are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int num_valid, ...) + * { + * // Specialize BlockReduce for a 1D block of 128 threads on type int + * typedef cub::BlockReduce BlockReduce; + * + * // Allocate shared memory for BlockReduce + * __shared__ typename BlockReduce::TempStorage temp_storage; + * + * // Each thread obtains an input item + * int thread_data; + * if (threadIdx.x < num_valid) thread_data = ... 
+ * + * // Compute the block-wide max for thread0 + * int aggregate = BlockReduce(temp_storage).Reduce(thread_data, cub::Max(), num_valid); + * + * \endcode + * + * \tparam ReductionOp [inferred] Binary reduction functor type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ T Reduce( + T input, ///< [in] Calling thread's input + ReductionOp reduction_op, ///< [in] Binary reduction functor + int num_valid) ///< [in] Number of threads containing valid elements (may be less than BLOCK_THREADS) + { + // Determine if we scan skip bounds checking + if (num_valid >= BLOCK_THREADS) + { + return InternalBlockReduce(temp_storage).template Reduce(input, num_valid, reduction_op); + } + else + { + return InternalBlockReduce(temp_storage).template Reduce(input, num_valid, reduction_op); + } + } + + + //@} end member group + /******************************************************************//** + * \name Summation reductions + *********************************************************************/ + //@{ + + + /** + * \brief Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. Each thread contributes one input element. + * + * \par + * - The return value is undefined in threads other than thread0. + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a sum reduction of 128 integer items that + * are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockReduce for a 1D block of 128 threads on type int + * typedef cub::BlockReduce BlockReduce; + * + * // Allocate shared memory for BlockReduce + * __shared__ typename BlockReduce::TempStorage temp_storage; + * + * // Each thread obtains an input item + * int thread_data; + * ... + * + * // Compute the block-wide sum for thread0 + * int aggregate = BlockReduce(temp_storage).Sum(thread_data); + * + * \endcode + * + */ + __device__ __forceinline__ T Sum( + T input) ///< [in] Calling thread's input + { + return InternalBlockReduce(temp_storage).template Sum(input, BLOCK_THREADS); + } + + /** + * \brief Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. Each thread contributes an array of consecutive input elements. + * + * \par + * - The return value is undefined in threads other than thread0. + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a sum reduction of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockReduce for a 1D block of 128 threads on type int + * typedef cub::BlockReduce BlockReduce; + * + * // Allocate shared memory for BlockReduce + * __shared__ typename BlockReduce::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Compute the block-wide sum for thread0 + * int aggregate = BlockReduce(temp_storage).Sum(thread_data); + * + * \endcode + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. 
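Editor's note: the Reduce() and Sum() snippets above also lost their template arguments in this patch. Below is a reconstructed sketch assuming a 128-thread block with 4 items per thread and the default algorithm; only thread0 receives a defined return value, so it alone writes the result.

```cuda
// Reconstructed sketch of a block-wide sum (assumed <int, 128> specialization, 4 items/thread).
#include <cub/cub.cuh>   // or equivalently <cub/block/block_reduce.cuh>

__global__ void BlockSumKernel(const int *d_in, int *d_block_sums)
{
    // Specialize BlockReduce for a 1D block of 128 threads on type int
    typedef cub::BlockReduce<int, 128> BlockReduce;

    // Allocate shared memory for BlockReduce
    __shared__ typename BlockReduce::TempStorage temp_storage;

    // Obtain a segment of 4 consecutive items blocked across threads
    int thread_data[4];
    for (int i = 0; i < 4; ++i)
        thread_data[i] = d_in[blockIdx.x * 512 + threadIdx.x * 4 + i];

    // Compute the block-wide sum for thread0 (undefined in other threads)
    int aggregate = BlockReduce(temp_storage).Sum(thread_data);

    // A max reduction would use the generic form instead (after a __syncthreads()
    // if temp_storage is being reused):
    //   int block_max = BlockReduce(temp_storage).Reduce(thread_data, cub::Max());

    if (threadIdx.x == 0)
        d_block_sums[blockIdx.x] = aggregate;
}
```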
+ */ + template + __device__ __forceinline__ T Sum( + T (&inputs)[ITEMS_PER_THREAD]) ///< [in] Calling thread's input segment + { + // Reduce partials + T partial = internal::ThreadReduce(inputs, cub::Sum()); + return Sum(partial); + } + + + /** + * \brief Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. The first \p num_valid threads each contribute one input element. + * + * \par + * - The return value is undefined in threads other than thread0. + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a sum reduction of a partially-full tile of integer items that + * are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int num_valid, ...) + * { + * // Specialize BlockReduce for a 1D block of 128 threads on type int + * typedef cub::BlockReduce BlockReduce; + * + * // Allocate shared memory for BlockReduce + * __shared__ typename BlockReduce::TempStorage temp_storage; + * + * // Each thread obtains an input item (up to num_items) + * int thread_data; + * if (threadIdx.x < num_valid) + * thread_data = ... + * + * // Compute the block-wide sum for thread0 + * int aggregate = BlockReduce(temp_storage).Sum(thread_data, num_valid); + * + * \endcode + * + */ + __device__ __forceinline__ T Sum( + T input, ///< [in] Calling thread's input + int num_valid) ///< [in] Number of threads containing valid elements (may be less than BLOCK_THREADS) + { + // Determine if we scan skip bounds checking + if (num_valid >= BLOCK_THREADS) + { + return InternalBlockReduce(temp_storage).template Sum(input, num_valid); + } + else + { + return InternalBlockReduce(temp_storage).template Sum(input, num_valid); + } + } + + + //@} end member group +}; + +/** + * \example example_block_reduce.cu + */ + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/block_scan.cuh b/fastertransformer/cuda/cub/block/block_scan.cuh new file mode 100644 index 000000000..27ea7ed40 --- /dev/null +++ b/fastertransformer/cuda/cub/block/block_scan.cuh @@ -0,0 +1,2126 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * The cub::BlockScan class provides [collective](index.html#sec0) methods for computing a parallel prefix sum/scan of items partitioned across a CUDA thread block. + */ + +#pragma once + +#include "specializations/block_scan_raking.cuh" +#include "specializations/block_scan_warp_scans.cuh" +#include "../util_arch.cuh" +#include "../util_type.cuh" +#include "../util_ptx.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Algorithmic variants + ******************************************************************************/ + +/** + * \brief BlockScanAlgorithm enumerates alternative algorithms for cub::BlockScan to compute a parallel prefix scan across a CUDA thread block. + */ +enum BlockScanAlgorithm +{ + + /** + * \par Overview + * An efficient "raking reduce-then-scan" prefix scan algorithm. Execution is comprised of five phases: + * -# Upsweep sequential reduction in registers (if threads contribute more than one input each). Each thread then places the partial reduction of its item(s) into shared memory. + * -# Upsweep sequential reduction in shared memory. Threads within a single warp rake across segments of shared partial reductions. + * -# A warp-synchronous Kogge-Stone style exclusive scan within the raking warp. + * -# Downsweep sequential exclusive scan in shared memory. Threads within a single warp rake across segments of shared partial reductions, seeded with the warp-scan output. + * -# Downsweep sequential scan in registers (if threads contribute more than one input), seeded with the raking scan output. + * + * \par + * \image html block_scan_raking.png + *
\p BLOCK_SCAN_RAKING data flow for a hypothetical 16-thread thread block and 4-thread raking warp.
+ * + * \par Performance Considerations + * - Although this variant may suffer longer turnaround latencies when the + * GPU is under-occupied, it can often provide higher overall throughput + * across the GPU when suitably occupied. + */ + BLOCK_SCAN_RAKING, + + + /** + * \par Overview + * Similar to cub::BLOCK_SCAN_RAKING, but with fewer shared memory reads at + * the expense of higher register pressure. Raking threads preserve their + * "upsweep" segment of values in registers while performing warp-synchronous + * scan, allowing the "downsweep" not to re-read them from shared memory. + */ + BLOCK_SCAN_RAKING_MEMOIZE, + + + /** + * \par Overview + * A quick "tiled warpscans" prefix scan algorithm. Execution is comprised of four phases: + * -# Upsweep sequential reduction in registers (if threads contribute more than one input each). Each thread then places the partial reduction of its item(s) into shared memory. + * -# Compute a shallow, but inefficient warp-synchronous Kogge-Stone style scan within each warp. + * -# A propagation phase where the warp scan outputs in each warp are updated with the aggregate from each preceding warp. + * -# Downsweep sequential scan in registers (if threads contribute more than one input), seeded with the raking scan output. + * + * \par + * \image html block_scan_warpscans.png + *
\p BLOCK_SCAN_WARP_SCANS data flow for a hypothetical 16-thread thread block and 4-thread raking warp.
+ * + * \par Performance Considerations + * - Although this variant may suffer lower overall throughput across the + * GPU because due to a heavy reliance on inefficient warpscans, it can + * often provide lower turnaround latencies when the GPU is under-occupied. + */ + BLOCK_SCAN_WARP_SCANS, +}; + + +/****************************************************************************** + * Block scan + ******************************************************************************/ + +/** + * \brief The BlockScan class provides [collective](index.html#sec0) methods for computing a parallel prefix sum/scan of items partitioned across a CUDA thread block. ![](block_scan_logo.png) + * \ingroup BlockModule + * + * \tparam T Data type being scanned + * \tparam BLOCK_DIM_X The thread block length in threads along the X dimension + * \tparam ALGORITHM [optional] cub::BlockScanAlgorithm enumerator specifying the underlying algorithm to use (default: cub::BLOCK_SCAN_RAKING) + * \tparam BLOCK_DIM_Y [optional] The thread block length in threads along the Y dimension (default: 1) + * \tparam BLOCK_DIM_Z [optional] The thread block length in threads along the Z dimension (default: 1) + * \tparam PTX_ARCH [optional] \ptxversion + * + * \par Overview + * - Given a list of input elements and a binary reduction operator, a [prefix scan](http://en.wikipedia.org/wiki/Prefix_sum) + * produces an output list where each element is computed to be the reduction + * of the elements occurring earlier in the input list. Prefix sum + * connotes a prefix scan with the addition operator. The term \em inclusive indicates + * that the ith output reduction incorporates the ith input. + * The term \em exclusive indicates the ith input is not incorporated into + * the ith output reduction. + * - \rowmajor + * - BlockScan can be optionally specialized by algorithm to accommodate different workload profiles: + * -# cub::BLOCK_SCAN_RAKING. An efficient (high throughput) "raking reduce-then-scan" prefix scan algorithm. [More...](\ref cub::BlockScanAlgorithm) + * -# cub::BLOCK_SCAN_RAKING_MEMOIZE. Similar to cub::BLOCK_SCAN_RAKING, but having higher throughput at the expense of additional register pressure for intermediate storage. [More...](\ref cub::BlockScanAlgorithm) + * -# cub::BLOCK_SCAN_WARP_SCANS. A quick (low latency) "tiled warpscans" prefix scan algorithm. [More...](\ref cub::BlockScanAlgorithm) + * + * \par Performance Considerations + * - \granularity + * - Uses special instructions when applicable (e.g., warp \p SHFL) + * - Uses synchronization-free communication between warp lanes when applicable + * - Invokes a minimal number of minimal block-wide synchronization barriers (only + * one or two depending on algorithm selection) + * - Incurs zero bank conflicts for most types + * - Computation is slightly more efficient (i.e., having lower instruction overhead) for: + * - Prefix sum variants (vs. generic scan) + * - \blocksize + * - See cub::BlockScanAlgorithm for performance details regarding algorithmic alternatives + * + * \par A Simple Example + * \blockcollective{BlockScan} + * \par + * The code snippet below illustrates an exclusive prefix sum of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) 
+ * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively compute the block-wide exclusive prefix sum + * BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * {[1,1,1,1], [1,1,1,1], ..., [1,1,1,1]}. + * The corresponding output \p thread_data in those threads will be + * {[0,1,2,3], [4,5,6,7], ..., [508,509,510,511]}. + * + */ +template < + typename T, + int BLOCK_DIM_X, + BlockScanAlgorithm ALGORITHM = BLOCK_SCAN_RAKING, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = CUB_PTX_ARCH> +class BlockScan +{ +private: + + /****************************************************************************** + * Constants and type definitions + ******************************************************************************/ + + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + }; + + /** + * Ensure the template parameterization meets the requirements of the + * specified algorithm. Currently, the BLOCK_SCAN_WARP_SCANS policy + * cannot be used with thread block sizes not a multiple of the + * architectural warp size. + */ + static const BlockScanAlgorithm SAFE_ALGORITHM = + ((ALGORITHM == BLOCK_SCAN_WARP_SCANS) && (BLOCK_THREADS % CUB_WARP_THREADS(PTX_ARCH) != 0)) ? + BLOCK_SCAN_RAKING : + ALGORITHM; + + typedef BlockScanWarpScans WarpScans; + typedef BlockScanRaking Raking; + + /// Define the delegate type for the desired algorithm + typedef typename If<(SAFE_ALGORITHM == BLOCK_SCAN_WARP_SCANS), + WarpScans, + Raking>::Type InternalBlockScan; + + /// Shared memory storage layout type for BlockScan + typedef typename InternalBlockScan::TempStorage _TempStorage; + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /// Internal storage allocator + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + + /****************************************************************************** + * Public types + ******************************************************************************/ +public: + + /// \smemstorage{BlockScan} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using a private static allocation of shared memory as temporary storage. 
+ */ + __device__ __forceinline__ BlockScan() + : + temp_storage(PrivateStorage()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. + */ + __device__ __forceinline__ BlockScan( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + + //@} end member group + /******************************************************************//** + * \name Exclusive prefix sum operations + *********************************************************************/ + //@{ + + + /** + * \brief Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. The value of 0 is applied as the initial value, and is assigned to \p output in thread0. + * + * \par + * - \identityzero + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an exclusive prefix sum of 128 integer items that + * are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain input item for each thread + * int thread_data; + * ... + * + * // Collectively compute the block-wide exclusive prefix sum + * BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is 1, 1, ..., 1. The + * corresponding output \p thread_data in those threads will be 0, 1, ..., 127. + * + */ + __device__ __forceinline__ void ExclusiveSum( + T input, ///< [in] Calling thread's input item + T &output) ///< [out] Calling thread's output item (may be aliased to \p input) + { + T initial_value = 0; + ExclusiveScan(input, output, initial_value, cub::Sum()); + } + + + /** + * \brief Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. The value of 0 is applied as the initial value, and is assigned to \p output in thread0. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - \identityzero + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an exclusive prefix sum of 128 integer items that + * are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain input item for each thread + * int thread_data; + * ... + * + * // Collectively compute the block-wide exclusive prefix sum + * int block_aggregate; + * BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data, block_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is 1, 1, ..., 1. The + * corresponding output \p thread_data in those threads will be 0, 1, ..., 127. 
+ * Furthermore the value \p 128 will be stored in \p block_aggregate for all threads. + * + */ + __device__ __forceinline__ void ExclusiveSum( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + T &block_aggregate) ///< [out] block-wide aggregate reduction of input items + { + T initial_value = 0; + ExclusiveScan(input, output, initial_value, cub::Sum(), block_aggregate); + } + + + /** + * \brief Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. Instead of using 0 as the block-wide prefix, the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - \identityzero + * - The \p block_prefix_callback_op functor must implement a member function T operator()(T block_aggregate). + * The functor's input parameter \p block_aggregate is the same value also returned by the scan operation. + * The functor will be invoked by the first warp of threads in the block, however only the return value from + * lane0 is applied as the block-wide prefix. Can be stateful. + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a single thread block that progressively + * computes an exclusive prefix sum over multiple "tiles" of input using a + * prefix functor to maintain a running total between block-wide scans. Each tile consists + * of 128 integer items that are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * // A stateful callback functor that maintains a running prefix to be applied + * // during consecutive scan operations. + * struct BlockPrefixCallbackOp + * { + * // Running prefix + * int running_total; + * + * // Constructor + * __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + * + * // Callback operator to be entered by the first warp of threads in the block. + * // Thread-0 is responsible for returning a value for seeding the block-wide scan. + * __device__ int operator()(int block_aggregate) + * { + * int old_prefix = running_total; + * running_total += block_aggregate; + * return old_prefix; + * } + * }; + * + * __global__ void ExampleKernel(int *d_data, int num_items, ...) + * { + * // Specialize BlockScan for a 1D block of 128 threads + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Initialize running total + * BlockPrefixCallbackOp prefix_op(0); + * + * // Have the block iterate over segments of items + * for (int block_offset = 0; block_offset < num_items; block_offset += 128) + * { + * // Load a segment of consecutive items that are blocked across threads + * int thread_data = d_data[block_offset]; + * + * // Collectively compute the block-wide exclusive prefix sum + * BlockScan(temp_storage).ExclusiveSum( + * thread_data, thread_data, prefix_op); + * CTA_SYNC(); + * + * // Store scanned items to output segment + * d_data[block_offset] = thread_data; + * } + * \endcode + * \par + * Suppose the input \p d_data is 1, 1, 1, 1, 1, 1, 1, 1, .... + * The corresponding output for the first segment will be 0, 1, ..., 127. 
+ * The output for the second segment will be 128, 129, ..., 255. + * + * \tparam BlockPrefixCallbackOp [inferred] Call-back functor type having member T operator()(T block_aggregate) + */ + template + __device__ __forceinline__ void ExclusiveSum( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide prefix to be applied to the logical input sequence. + { + ExclusiveScan(input, output, cub::Sum(), block_prefix_callback_op); + } + + + //@} end member group + /******************************************************************//** + * \name Exclusive prefix sum operations (multiple data per thread) + *********************************************************************/ + //@{ + + + /** + * \brief Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. The value of 0 is applied as the initial value, and is assigned to \p output[0] in thread0. + * + * \par + * - \identityzero + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an exclusive prefix sum of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively compute the block-wide exclusive prefix sum + * BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is { [1,1,1,1], [1,1,1,1], ..., [1,1,1,1] }. The + * corresponding output \p thread_data in those threads will be { [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + */ + template + __device__ __forceinline__ void ExclusiveSum( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD]) ///< [out] Calling thread's output items (may be aliased to \p input) + { + T initial_value = 0; + ExclusiveScan(input, output, initial_value, cub::Sum()); + } + + + /** + * \brief Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. The value of 0 is applied as the initial value, and is assigned to \p output[0] in thread0. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - \identityzero + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an exclusive prefix sum of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) 
+ * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively compute the block-wide exclusive prefix sum + * int block_aggregate; + * BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data, block_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is { [1,1,1,1], [1,1,1,1], ..., [1,1,1,1] }. The + * corresponding output \p thread_data in those threads will be { [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }. + * Furthermore the value \p 512 will be stored in \p block_aggregate for all threads. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + */ + template + __device__ __forceinline__ void ExclusiveSum( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + T &block_aggregate) ///< [out] block-wide aggregate reduction of input items + { + // Reduce consecutive thread items in registers + T initial_value = 0; + ExclusiveScan(input, output, initial_value, cub::Sum(), block_aggregate); + } + + + /** + * \brief Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. Instead of using 0 as the block-wide prefix, the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - \identityzero + * - The \p block_prefix_callback_op functor must implement a member function T operator()(T block_aggregate). + * The functor's input parameter \p block_aggregate is the same value also returned by the scan operation. + * The functor will be invoked by the first warp of threads in the block, however only the return value from + * lane0 is applied as the block-wide prefix. Can be stateful. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a single thread block that progressively + * computes an exclusive prefix sum over multiple "tiles" of input using a + * prefix functor to maintain a running total between block-wide scans. Each tile consists + * of 512 integer items that are partitioned in a [blocked arrangement](index.html#sec5sec3) + * across 128 threads where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * // A stateful callback functor that maintains a running prefix to be applied + * // during consecutive scan operations. + * struct BlockPrefixCallbackOp + * { + * // Running prefix + * int running_total; + * + * // Constructor + * __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + * + * // Callback operator to be entered by the first warp of threads in the block. + * // Thread-0 is responsible for returning a value for seeding the block-wide scan. 
+ * __device__ int operator()(int block_aggregate) + * { + * int old_prefix = running_total; + * running_total += block_aggregate; + * return old_prefix; + * } + * }; + * + * __global__ void ExampleKernel(int *d_data, int num_items, ...) + * { + * // Specialize BlockLoad, BlockStore, and BlockScan for a 1D block of 128 threads, 4 ints per thread + * typedef cub::BlockLoad BlockLoad; + * typedef cub::BlockStore BlockStore; + * typedef cub::BlockScan BlockScan; + * + * // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan + * __shared__ union { + * typename BlockLoad::TempStorage load; + * typename BlockScan::TempStorage scan; + * typename BlockStore::TempStorage store; + * } temp_storage; + * + * // Initialize running total + * BlockPrefixCallbackOp prefix_op(0); + * + * // Have the block iterate over segments of items + * for (int block_offset = 0; block_offset < num_items; block_offset += 128 * 4) + * { + * // Load a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * BlockLoad(temp_storage.load).Load(d_data + block_offset, thread_data); + * CTA_SYNC(); + * + * // Collectively compute the block-wide exclusive prefix sum + * int block_aggregate; + * BlockScan(temp_storage.scan).ExclusiveSum( + * thread_data, thread_data, prefix_op); + * CTA_SYNC(); + * + * // Store scanned items to output segment + * BlockStore(temp_storage.store).Store(d_data + block_offset, thread_data); + * CTA_SYNC(); + * } + * \endcode + * \par + * Suppose the input \p d_data is 1, 1, 1, 1, 1, 1, 1, 1, .... + * The corresponding output for the first segment will be 0, 1, 2, 3, ..., 510, 511. + * The output for the second segment will be 512, 513, 514, 515, ..., 1022, 1023. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam BlockPrefixCallbackOp [inferred] Call-back functor type having member T operator()(T block_aggregate) + */ + template < + int ITEMS_PER_THREAD, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void ExclusiveSum( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide prefix to be applied to the logical input sequence. + { + ExclusiveScan(input, output, cub::Sum(), block_prefix_callback_op); + } + + + + //@} end member group // Exclusive prefix sums + /******************************************************************//** + * \name Exclusive prefix scan operations + *********************************************************************/ + //@{ + + + /** + * \brief Computes an exclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. + * + * \par + * - Supports non-commutative scan operators. + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an exclusive prefix max scan of 128 integer items that + * are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain input item for each thread + * int thread_data; + * ... 
+ * + * // Collectively compute the block-wide exclusive prefix max scan + * BlockScan(temp_storage).ExclusiveScan(thread_data, thread_data, INT_MIN, cub::Max()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is 0, -1, 2, -3, ..., 126, -127. The + * corresponding output \p thread_data in those threads will be INT_MIN, 0, 0, 2, ..., 124, 126. + * + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + T initial_value, ///< [in] Initial value to seed the exclusive scan (and is assigned to \p output[0] in thread0) + ScanOp scan_op) ///< [in] Binary scan functor + { + InternalBlockScan(temp_storage).ExclusiveScan(input, output, initial_value, scan_op); + } + + + /** + * \brief Computes an exclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - Supports non-commutative scan operators. + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an exclusive prefix max scan of 128 integer items that + * are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain input item for each thread + * int thread_data; + * ... + * + * // Collectively compute the block-wide exclusive prefix max scan + * int block_aggregate; + * BlockScan(temp_storage).ExclusiveScan(thread_data, thread_data, INT_MIN, cub::Max(), block_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is 0, -1, 2, -3, ..., 126, -127. The + * corresponding output \p thread_data in those threads will be INT_MIN, 0, 0, 2, ..., 124, 126. + * Furthermore the value \p 126 will be stored in \p block_aggregate for all threads. + * + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input items + T &output, ///< [out] Calling thread's output items (may be aliased to \p input) + T initial_value, ///< [in] Initial value to seed the exclusive scan (and is assigned to \p output[0] in thread0) + ScanOp scan_op, ///< [in] Binary scan functor + T &block_aggregate) ///< [out] block-wide aggregate reduction of input items + { + InternalBlockScan(temp_storage).ExclusiveScan(input, output, initial_value, scan_op, block_aggregate); + } + + + /** + * \brief Computes an exclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. 
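As a compilable counterpart to the initial-value overload documented above, the sketch below (kernel and buffer names are illustrative, not part of this header; a 1D block of 128 threads is assumed) runs the exclusive max scan end to end.

    #include <climits>
    #include <cub/cub.cuh>

    // Illustrative kernel (launch with 128 threads per block): each thread owns one
    // int, and the block computes an exclusive prefix max seeded with INT_MIN.
    __global__ void ExclusiveMaxScanKernel(const int *d_in, int *d_out)
    {
        typedef cub::BlockScan<int, 128> BlockScan;
        __shared__ typename BlockScan::TempStorage temp_storage;

        int thread_data = d_in[threadIdx.x];

        // Exclusive max scan: thread0 receives INT_MIN, thread i receives the
        // maximum over inputs 0..i-1.
        BlockScan(temp_storage).ExclusiveScan(thread_data, thread_data, INT_MIN, cub::Max());

        d_out[threadIdx.x] = thread_data;
    }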
+ * + * \par + * - The \p block_prefix_callback_op functor must implement a member function T operator()(T block_aggregate). + * The functor's input parameter \p block_aggregate is the same value also returned by the scan operation. + * The functor will be invoked by the first warp of threads in the block, however only the return value from + * lane0 is applied as the block-wide prefix. Can be stateful. + * - Supports non-commutative scan operators. + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a single thread block that progressively + * computes an exclusive prefix max scan over multiple "tiles" of input using a + * prefix functor to maintain a running total between block-wide scans. Each tile consists + * of 128 integer items that are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * // A stateful callback functor that maintains a running prefix to be applied + * // during consecutive scan operations. + * struct BlockPrefixCallbackOp + * { + * // Running prefix + * int running_total; + * + * // Constructor + * __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + * + * // Callback operator to be entered by the first warp of threads in the block. + * // Thread-0 is responsible for returning a value for seeding the block-wide scan. + * __device__ int operator()(int block_aggregate) + * { + * int old_prefix = running_total; + * running_total = (block_aggregate > old_prefix) ? block_aggregate : old_prefix; + * return old_prefix; + * } + * }; + * + * __global__ void ExampleKernel(int *d_data, int num_items, ...) + * { + * // Specialize BlockScan for a 1D block of 128 threads + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Initialize running total + * BlockPrefixCallbackOp prefix_op(INT_MIN); + * + * // Have the block iterate over segments of items + * for (int block_offset = 0; block_offset < num_items; block_offset += 128) + * { + * // Load a segment of consecutive items that are blocked across threads + * int thread_data = d_data[block_offset]; + * + * // Collectively compute the block-wide exclusive prefix max scan + * BlockScan(temp_storage).ExclusiveScan( + * thread_data, thread_data, INT_MIN, cub::Max(), prefix_op); + * CTA_SYNC(); + * + * // Store scanned items to output segment + * d_data[block_offset] = thread_data; + * } + * \endcode + * \par + * Suppose the input \p d_data is 0, -1, 2, -3, 4, -5, .... + * The corresponding output for the first segment will be INT_MIN, 0, 0, 2, ..., 124, 126. + * The output for the second segment will be 126, 128, 128, 130, ..., 252, 254. + * + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + * \tparam BlockPrefixCallbackOp [inferred] Call-back functor type having member T operator()(T block_aggregate) + */ + template < + typename ScanOp, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide prefix to be applied to the logical input sequence. 
+ { + InternalBlockScan(temp_storage).ExclusiveScan(input, output, scan_op, block_prefix_callback_op); + } + + + //@} end member group // Inclusive prefix sums + /******************************************************************//** + * \name Exclusive prefix scan operations (multiple data per thread) + *********************************************************************/ + //@{ + + + /** + * \brief Computes an exclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. + * + * \par + * - Supports non-commutative scan operators. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an exclusive prefix max scan of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively compute the block-wide exclusive prefix max scan + * BlockScan(temp_storage).ExclusiveScan(thread_data, thread_data, INT_MIN, cub::Max()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * { [0,-1,2,-3], [4,-5,6,-7], ..., [508,-509,510,-511] }. + * The corresponding output \p thread_data in those threads will be + * { [INT_MIN,0,0,2], [2,4,4,6], ..., [506,508,508,510] }. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template < + int ITEMS_PER_THREAD, + typename ScanOp> + __device__ __forceinline__ void ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + T initial_value, ///< [in] Initial value to seed the exclusive scan (and is assigned to \p output[0] in thread0) + ScanOp scan_op) ///< [in] Binary scan functor + { + // Reduce consecutive thread items in registers + T thread_prefix = internal::ThreadReduce(input, scan_op); + + // Exclusive thread block-scan + ExclusiveScan(thread_prefix, thread_prefix, initial_value, scan_op); + + // Exclusive scan in registers with prefix as seed + internal::ThreadScanExclusive(input, output, scan_op, thread_prefix); + } + + + /** + * \brief Computes an exclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - Supports non-commutative scan operators. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an exclusive prefix max scan of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) 
+ * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively compute the block-wide exclusive prefix max scan + * int block_aggregate; + * BlockScan(temp_storage).ExclusiveScan(thread_data, thread_data, INT_MIN, cub::Max(), block_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is { [0,-1,2,-3], [4,-5,6,-7], ..., [508,-509,510,-511] }. The + * corresponding output \p thread_data in those threads will be { [INT_MIN,0,0,2], [2,4,4,6], ..., [506,508,508,510] }. + * Furthermore the value \p 510 will be stored in \p block_aggregate for all threads. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template < + int ITEMS_PER_THREAD, + typename ScanOp> + __device__ __forceinline__ void ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + T initial_value, ///< [in] Initial value to seed the exclusive scan (and is assigned to \p output[0] in thread0) + ScanOp scan_op, ///< [in] Binary scan functor + T &block_aggregate) ///< [out] block-wide aggregate reduction of input items + { + // Reduce consecutive thread items in registers + T thread_prefix = internal::ThreadReduce(input, scan_op); + + // Exclusive thread block-scan + ExclusiveScan(thread_prefix, thread_prefix, initial_value, scan_op, block_aggregate); + + // Exclusive scan in registers with prefix as seed + internal::ThreadScanExclusive(input, output, scan_op, thread_prefix); + } + + + /** + * \brief Computes an exclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - The \p block_prefix_callback_op functor must implement a member function T operator()(T block_aggregate). + * The functor's input parameter \p block_aggregate is the same value also returned by the scan operation. + * The functor will be invoked by the first warp of threads in the block, however only the return value from + * lane0 is applied as the block-wide prefix. Can be stateful. + * - Supports non-commutative scan operators. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a single thread block that progressively + * computes an exclusive prefix max scan over multiple "tiles" of input using a + * prefix functor to maintain a running total between block-wide scans. Each tile consists + * of 128 integer items that are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * // A stateful callback functor that maintains a running prefix to be applied + * // during consecutive scan operations. 
+ * struct BlockPrefixCallbackOp + * { + * // Running prefix + * int running_total; + * + * // Constructor + * __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + * + * // Callback operator to be entered by the first warp of threads in the block. + * // Thread-0 is responsible for returning a value for seeding the block-wide scan. + * __device__ int operator()(int block_aggregate) + * { + * int old_prefix = running_total; + * running_total = (block_aggregate > old_prefix) ? block_aggregate : old_prefix; + * return old_prefix; + * } + * }; + * + * __global__ void ExampleKernel(int *d_data, int num_items, ...) + * { + * // Specialize BlockLoad, BlockStore, and BlockScan for a 1D block of 128 threads, 4 ints per thread + * typedef cub::BlockLoad BlockLoad; + * typedef cub::BlockStore BlockStore; + * typedef cub::BlockScan BlockScan; + * + * // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan + * __shared__ union { + * typename BlockLoad::TempStorage load; + * typename BlockScan::TempStorage scan; + * typename BlockStore::TempStorage store; + * } temp_storage; + * + * // Initialize running total + * BlockPrefixCallbackOp prefix_op(0); + * + * // Have the block iterate over segments of items + * for (int block_offset = 0; block_offset < num_items; block_offset += 128 * 4) + * { + * // Load a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * BlockLoad(temp_storage.load).Load(d_data + block_offset, thread_data); + * CTA_SYNC(); + * + * // Collectively compute the block-wide exclusive prefix max scan + * BlockScan(temp_storage.scan).ExclusiveScan( + * thread_data, thread_data, INT_MIN, cub::Max(), prefix_op); + * CTA_SYNC(); + * + * // Store scanned items to output segment + * BlockStore(temp_storage.store).Store(d_data + block_offset, thread_data); + * CTA_SYNC(); + * } + * \endcode + * \par + * Suppose the input \p d_data is 0, -1, 2, -3, 4, -5, .... + * The corresponding output for the first segment will be INT_MIN, 0, 0, 2, 2, 4, ..., 508, 510. + * The output for the second segment will be 510, 512, 512, 514, 514, 516, ..., 1020, 1022. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + * \tparam BlockPrefixCallbackOp [inferred] Call-back functor type having member T operator()(T block_aggregate) + */ + template < + int ITEMS_PER_THREAD, + typename ScanOp, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide prefix to be applied to the logical input sequence. 
+ { + // Reduce consecutive thread items in registers + T thread_prefix = internal::ThreadReduce(input, scan_op); + + // Exclusive thread block-scan + ExclusiveScan(thread_prefix, thread_prefix, scan_op, block_prefix_callback_op); + + // Exclusive scan in registers with prefix as seed + internal::ThreadScanExclusive(input, output, scan_op, thread_prefix); + } + + + //@} end member group +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document no-initial-value scans + + /******************************************************************//** + * \name Exclusive prefix scan operations (no initial value, single datum per thread) + *********************************************************************/ + //@{ + + + /** + * \brief Computes an exclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. With no initial value, the output computed for thread0 is undefined. + * + * \par + * - Supports non-commutative scan operators. + * - \rowmajor + * - \smemreuse + * + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan functor + { + InternalBlockScan(temp_storage).ExclusiveScan(input, output, scan_op); + } + + + /** + * \brief Computes an exclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. With no initial value, the output computed for thread0 is undefined. + * + * \par + * - Supports non-commutative scan operators. + * - \rowmajor + * - \smemreuse + * + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + T &block_aggregate) ///< [out] block-wide aggregate reduction of input items + { + InternalBlockScan(temp_storage).ExclusiveScan(input, output, scan_op, block_aggregate); + } + + //@} end member group + /******************************************************************//** + * \name Exclusive prefix scan operations (no initial value, multiple data per thread) + *********************************************************************/ + //@{ + + + /** + * \brief Computes an exclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. With no initial value, the output computed for thread0 is undefined. + * + * \par + * - Supports non-commutative scan operators. + * - \blocked + * - \granularity + * - \smemreuse + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. 
+ * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template < + int ITEMS_PER_THREAD, + typename ScanOp> + __device__ __forceinline__ void ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan functor + { + // Reduce consecutive thread items in registers + T thread_partial = internal::ThreadReduce(input, scan_op); + + // Exclusive thread block-scan + ExclusiveScan(thread_partial, thread_partial, scan_op); + + // Exclusive scan in registers with prefix + internal::ThreadScanExclusive(input, output, scan_op, thread_partial, (linear_tid != 0)); + } + + + /** + * \brief Computes an exclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide \p block_aggregate of all inputs. With no initial value, the output computed for thread0 is undefined. + * + * \par + * - Supports non-commutative scan operators. + * - \blocked + * - \granularity + * - \smemreuse + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template < + int ITEMS_PER_THREAD, + typename ScanOp> + __device__ __forceinline__ void ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + T &block_aggregate) ///< [out] block-wide aggregate reduction of input items + { + // Reduce consecutive thread items in registers + T thread_partial = internal::ThreadReduce(input, scan_op); + + // Exclusive thread block-scan + ExclusiveScan(thread_partial, thread_partial, scan_op, block_aggregate); + + // Exclusive scan in registers with prefix + internal::ThreadScanExclusive(input, output, scan_op, thread_partial, (linear_tid != 0)); + } + + + //@} end member group +#endif // DOXYGEN_SHOULD_SKIP_THIS // Do not document no-initial-value scans + + /******************************************************************//** + * \name Inclusive prefix sum operations + *********************************************************************/ + //@{ + + + /** + * \brief Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. + * + * \par + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an inclusive prefix sum of 128 integer items that + * are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain input item for each thread + * int thread_data; + * ... + * + * // Collectively compute the block-wide inclusive prefix sum + * BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is 1, 1, ..., 1. 
The + * corresponding output \p thread_data in those threads will be 1, 2, ..., 128. + * + */ + __device__ __forceinline__ void InclusiveSum( + T input, ///< [in] Calling thread's input item + T &output) ///< [out] Calling thread's output item (may be aliased to \p input) + { + InclusiveScan(input, output, cub::Sum()); + } + + + /** + * \brief Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an inclusive prefix sum of 128 integer items that + * are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain input item for each thread + * int thread_data; + * ... + * + * // Collectively compute the block-wide inclusive prefix sum + * int block_aggregate; + * BlockScan(temp_storage).InclusiveSum(thread_data, thread_data, block_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is 1, 1, ..., 1. The + * corresponding output \p thread_data in those threads will be 1, 2, ..., 128. + * Furthermore the value \p 128 will be stored in \p block_aggregate for all threads. + * + */ + __device__ __forceinline__ void InclusiveSum( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + T &block_aggregate) ///< [out] block-wide aggregate reduction of input items + { + InclusiveScan(input, output, cub::Sum(), block_aggregate); + } + + + + /** + * \brief Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. Instead of using 0 as the block-wide prefix, the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - The \p block_prefix_callback_op functor must implement a member function T operator()(T block_aggregate). + * The functor's input parameter \p block_aggregate is the same value also returned by the scan operation. + * The functor will be invoked by the first warp of threads in the block, however only the return value from + * lane0 is applied as the block-wide prefix. Can be stateful. + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a single thread block that progressively + * computes an inclusive prefix sum over multiple "tiles" of input using a + * prefix functor to maintain a running total between block-wide scans. Each tile consists + * of 128 integer items that are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * // A stateful callback functor that maintains a running prefix to be applied + * // during consecutive scan operations. 
+ * struct BlockPrefixCallbackOp + * { + * // Running prefix + * int running_total; + * + * // Constructor + * __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + * + * // Callback operator to be entered by the first warp of threads in the block. + * // Thread-0 is responsible for returning a value for seeding the block-wide scan. + * __device__ int operator()(int block_aggregate) + * { + * int old_prefix = running_total; + * running_total += block_aggregate; + * return old_prefix; + * } + * }; + * + * __global__ void ExampleKernel(int *d_data, int num_items, ...) + * { + * // Specialize BlockScan for a 1D block of 128 threads + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Initialize running total + * BlockPrefixCallbackOp prefix_op(0); + * + * // Have the block iterate over segments of items + * for (int block_offset = 0; block_offset < num_items; block_offset += 128) + * { + * // Load a segment of consecutive items that are blocked across threads + * int thread_data = d_data[block_offset]; + * + * // Collectively compute the block-wide inclusive prefix sum + * BlockScan(temp_storage).InclusiveSum( + * thread_data, thread_data, prefix_op); + * CTA_SYNC(); + * + * // Store scanned items to output segment + * d_data[block_offset] = thread_data; + * } + * \endcode + * \par + * Suppose the input \p d_data is 1, 1, 1, 1, 1, 1, 1, 1, .... + * The corresponding output for the first segment will be 1, 2, ..., 128. + * The output for the second segment will be 129, 130, ..., 256. + * + * \tparam BlockPrefixCallbackOp [inferred] Call-back functor type having member T operator()(T block_aggregate) + */ + template + __device__ __forceinline__ void InclusiveSum( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide prefix to be applied to the logical input sequence. + { + InclusiveScan(input, output, cub::Sum(), block_prefix_callback_op); + } + + + //@} end member group + /******************************************************************//** + * \name Inclusive prefix sum operations (multiple data per thread) + *********************************************************************/ + //@{ + + + /** + * \brief Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. + * + * \par + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an inclusive prefix sum of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... 
+ * + * // Collectively compute the block-wide inclusive prefix sum + * BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is { [1,1,1,1], [1,1,1,1], ..., [1,1,1,1] }. The + * corresponding output \p thread_data in those threads will be { [1,2,3,4], [5,6,7,8], ..., [509,510,511,512] }. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + */ + template + __device__ __forceinline__ void InclusiveSum( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD]) ///< [out] Calling thread's output items (may be aliased to \p input) + { + if (ITEMS_PER_THREAD == 1) + { + InclusiveSum(input[0], output[0]); + } + else + { + // Reduce consecutive thread items in registers + Sum scan_op; + T thread_prefix = internal::ThreadReduce(input, scan_op); + + // Exclusive thread block-scan + ExclusiveSum(thread_prefix, thread_prefix); + + // Inclusive scan in registers with prefix as seed + internal::ThreadScanInclusive(input, output, scan_op, thread_prefix, (linear_tid != 0)); + } + } + + + /** + * \brief Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an inclusive prefix sum of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively compute the block-wide inclusive prefix sum + * int block_aggregate; + * BlockScan(temp_storage).InclusiveSum(thread_data, thread_data, block_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * { [1,1,1,1], [1,1,1,1], ..., [1,1,1,1] }. The + * corresponding output \p thread_data in those threads will be + * { [1,2,3,4], [5,6,7,8], ..., [509,510,511,512] }. + * Furthermore the value \p 512 will be stored in \p block_aggregate for all threads. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. 
+ * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void InclusiveSum( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + T &block_aggregate) ///< [out] block-wide aggregate reduction of input items + { + if (ITEMS_PER_THREAD == 1) + { + InclusiveSum(input[0], output[0], block_aggregate); + } + else + { + // Reduce consecutive thread items in registers + Sum scan_op; + T thread_prefix = internal::ThreadReduce(input, scan_op); + + // Exclusive thread block-scan + ExclusiveSum(thread_prefix, thread_prefix, block_aggregate); + + // Inclusive scan in registers with prefix as seed + internal::ThreadScanInclusive(input, output, scan_op, thread_prefix, (linear_tid != 0)); + } + } + + + /** + * \brief Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. Instead of using 0 as the block-wide prefix, the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - The \p block_prefix_callback_op functor must implement a member function T operator()(T block_aggregate). + * The functor's input parameter \p block_aggregate is the same value also returned by the scan operation. + * The functor will be invoked by the first warp of threads in the block, however only the return value from + * lane0 is applied as the block-wide prefix. Can be stateful. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a single thread block that progressively + * computes an inclusive prefix sum over multiple "tiles" of input using a + * prefix functor to maintain a running total between block-wide scans. Each tile consists + * of 512 integer items that are partitioned in a [blocked arrangement](index.html#sec5sec3) + * across 128 threads where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * // A stateful callback functor that maintains a running prefix to be applied + * // during consecutive scan operations. + * struct BlockPrefixCallbackOp + * { + * // Running prefix + * int running_total; + * + * // Constructor + * __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + * + * // Callback operator to be entered by the first warp of threads in the block. + * // Thread-0 is responsible for returning a value for seeding the block-wide scan. + * __device__ int operator()(int block_aggregate) + * { + * int old_prefix = running_total; + * running_total += block_aggregate; + * return old_prefix; + * } + * }; + * + * __global__ void ExampleKernel(int *d_data, int num_items, ...) 
+ * {
+ *     // Specialize BlockLoad, BlockStore, and BlockScan for a 1D block of 128 threads, 4 ints per thread
+ *     typedef cub::BlockLoad BlockLoad;
+ *     typedef cub::BlockStore BlockStore;
+ *     typedef cub::BlockScan BlockScan;
+ *
+ *     // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan
+ *     __shared__ union {
+ *         typename BlockLoad::TempStorage load;
+ *         typename BlockScan::TempStorage scan;
+ *         typename BlockStore::TempStorage store;
+ *     } temp_storage;
+ *
+ *     // Initialize running total
+ *     BlockPrefixCallbackOp prefix_op(0);
+ *
+ *     // Have the block iterate over segments of items
+ *     for (int block_offset = 0; block_offset < num_items; block_offset += 128 * 4)
+ *     {
+ *         // Load a segment of consecutive items that are blocked across threads
+ *         int thread_data[4];
+ *         BlockLoad(temp_storage.load).Load(d_data + block_offset, thread_data);
+ *         CTA_SYNC();
+ *
+ *         // Collectively compute the block-wide inclusive prefix sum
+ *         BlockScan(temp_storage.scan).InclusiveSum(
+ *             thread_data, thread_data, prefix_op);
+ *         CTA_SYNC();
+ *
+ *         // Store scanned items to output segment
+ *         BlockStore(temp_storage.store).Store(d_data + block_offset, thread_data);
+ *         CTA_SYNC();
+ *     }
+ * \endcode
+ * \par
+ * Suppose the input \p d_data is 1, 1, 1, 1, 1, 1, 1, 1, ....
+ * The corresponding output for the first segment will be 1, 2, 3, 4, ..., 511, 512.
+ * The output for the second segment will be 513, 514, 515, 516, ..., 1023, 1024.
+ *
+ * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread.
+ * \tparam BlockPrefixCallbackOp [inferred] Call-back functor type having member T operator()(T block_aggregate)
+ */
+ template <
+     int ITEMS_PER_THREAD,
+     typename BlockPrefixCallbackOp>
+ __device__ __forceinline__ void InclusiveSum(
+     T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
+     T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input)
+     BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide prefix to be applied to the logical input sequence.
+ {
+     if (ITEMS_PER_THREAD == 1)
+     {
+         InclusiveSum(input[0], output[0], block_prefix_callback_op);
+     }
+     else
+     {
+         // Reduce consecutive thread items in registers
+         Sum scan_op;
+         T thread_prefix = internal::ThreadReduce(input, scan_op);
+
+         // Exclusive thread block-scan
+         ExclusiveSum(thread_prefix, thread_prefix, block_prefix_callback_op);
+
+         // Inclusive scan in registers with prefix as seed
+         internal::ThreadScanInclusive(input, output, scan_op, thread_prefix);
+     }
+ }
+
+
+ //@} end member group
+ /******************************************************************//**
+ * \name Inclusive prefix scan operations
+ *********************************************************************/
+ //@{
+
+
+ /**
+ * \brief Computes an inclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element.
+ *
+ * \par
+ * - Supports non-commutative scan operators.
+ * - \rowmajor
+ * - \smemreuse
+ *
+ * \par Snippet
+ * The code snippet below illustrates an inclusive prefix max scan of 128 integer items that
+ * are partitioned across 128 threads.
+ * \par
+ * \code
+ * #include // or equivalently
+ *
+ * __global__ void ExampleKernel(...)
+ * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain input item for each thread + * int thread_data; + * ... + * + * // Collectively compute the block-wide inclusive prefix max scan + * BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cub::Max()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is 0, -1, 2, -3, ..., 126, -127. The + * corresponding output \p thread_data in those threads will be 0, 0, 2, 2, ..., 126, 126. + * + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan functor + { + InternalBlockScan(temp_storage).InclusiveScan(input, output, scan_op); + } + + + /** + * \brief Computes an inclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - Supports non-commutative scan operators. + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an inclusive prefix max scan of 128 integer items that + * are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain input item for each thread + * int thread_data; + * ... + * + * // Collectively compute the block-wide inclusive prefix max scan + * int block_aggregate; + * BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cub::Max(), block_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is 0, -1, 2, -3, ..., 126, -127. The + * corresponding output \p thread_data in those threads will be 0, 0, 2, 2, ..., 126, 126. + * Furthermore the value \p 126 will be stored in \p block_aggregate for all threads. + * + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + T &block_aggregate) ///< [out] block-wide aggregate reduction of input items + { + InternalBlockScan(temp_storage).InclusiveScan(input, output, scan_op, block_aggregate); + } + + + /** + * \brief Computes an inclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. 
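The single-item InclusiveScan overloads shown here also accept user-defined functors. Below is a minimal sketch using a hypothetical MinOp functor and an assumed 128-thread block; only the InclusiveScan call itself comes from this header.

    #include <cub/cub.cuh>

    // Hypothetical binary functor: returns the smaller of two values
    struct MinOp
    {
        __device__ __forceinline__ int operator()(const int &a, const int &b) const
        {
            return (b < a) ? b : a;
        }
    };

    __global__ void InclusiveMinScanExample(const int *d_in, int *d_out, int *d_block_min)
    {
        typedef cub::BlockScan<int, 128> BlockScan;
        __shared__ typename BlockScan::TempStorage temp_storage;

        int thread_data = d_in[threadIdx.x];

        // Running minimum up to and including each thread, plus the block-wide minimum
        int block_aggregate;
        BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, MinOp(), block_aggregate);

        d_out[threadIdx.x] = thread_data;
        if (threadIdx.x == 0)
            *d_block_min = block_aggregate;
    }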
+ * + * \par + * - The \p block_prefix_callback_op functor must implement a member function T operator()(T block_aggregate). + * The functor's input parameter \p block_aggregate is the same value also returned by the scan operation. + * The functor will be invoked by the first warp of threads in the block, however only the return value from + * lane0 is applied as the block-wide prefix. Can be stateful. + * - Supports non-commutative scan operators. + * - \rowmajor + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a single thread block that progressively + * computes an inclusive prefix max scan over multiple "tiles" of input using a + * prefix functor to maintain a running total between block-wide scans. Each tile consists + * of 128 integer items that are partitioned across 128 threads. + * \par + * \code + * #include // or equivalently + * + * // A stateful callback functor that maintains a running prefix to be applied + * // during consecutive scan operations. + * struct BlockPrefixCallbackOp + * { + * // Running prefix + * int running_total; + * + * // Constructor + * __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + * + * // Callback operator to be entered by the first warp of threads in the block. + * // Thread-0 is responsible for returning a value for seeding the block-wide scan. + * __device__ int operator()(int block_aggregate) + * { + * int old_prefix = running_total; + * running_total = (block_aggregate > old_prefix) ? block_aggregate : old_prefix; + * return old_prefix; + * } + * }; + * + * __global__ void ExampleKernel(int *d_data, int num_items, ...) + * { + * // Specialize BlockScan for a 1D block of 128 threads + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Initialize running total + * BlockPrefixCallbackOp prefix_op(INT_MIN); + * + * // Have the block iterate over segments of items + * for (int block_offset = 0; block_offset < num_items; block_offset += 128) + * { + * // Load a segment of consecutive items that are blocked across threads + * int thread_data = d_data[block_offset]; + * + * // Collectively compute the block-wide inclusive prefix max scan + * BlockScan(temp_storage).InclusiveScan( + * thread_data, thread_data, cub::Max(), prefix_op); + * CTA_SYNC(); + * + * // Store scanned items to output segment + * d_data[block_offset] = thread_data; + * } + * \endcode + * \par + * Suppose the input \p d_data is 0, -1, 2, -3, 4, -5, .... + * The corresponding output for the first segment will be 0, 0, 2, 2, ..., 126, 126. + * The output for the second segment will be 128, 128, 130, 130, ..., 254, 254. + * + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + * \tparam BlockPrefixCallbackOp [inferred] Call-back functor type having member T operator()(T block_aggregate) + */ + template < + typename ScanOp, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide prefix to be applied to the logical input sequence. 
+ { + InternalBlockScan(temp_storage).InclusiveScan(input, output, scan_op, block_prefix_callback_op); + } + + + //@} end member group + /******************************************************************//** + * \name Inclusive prefix scan operations (multiple data per thread) + *********************************************************************/ + //@{ + + + /** + * \brief Computes an inclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. + * + * \par + * - Supports non-commutative scan operators. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an inclusive prefix max scan of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively compute the block-wide inclusive prefix max scan + * BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cub::Max()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is { [0,-1,2,-3], [4,-5,6,-7], ..., [508,-509,510,-511] }. The + * corresponding output \p thread_data in those threads will be { [0,0,2,2], [4,4,6,6], ..., [508,508,510,510] }. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template < + int ITEMS_PER_THREAD, + typename ScanOp> + __device__ __forceinline__ void InclusiveScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan functor + { + if (ITEMS_PER_THREAD == 1) + { + InclusiveScan(input[0], output[0], scan_op); + } + else + { + // Reduce consecutive thread items in registers + T thread_prefix = internal::ThreadReduce(input, scan_op); + + // Exclusive thread block-scan + ExclusiveScan(thread_prefix, thread_prefix, scan_op); + + // Inclusive scan in registers with prefix as seed (first thread does not seed) + internal::ThreadScanInclusive(input, output, scan_op, thread_prefix, (linear_tid != 0)); + } + } + + + /** + * \brief Computes an inclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - Supports non-commutative scan operators. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates an inclusive prefix max scan of 512 integer items that + * are partitioned in a [blocked arrangement](index.html#sec5sec3) across 128 threads + * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) 
+ * { + * // Specialize BlockScan for a 1D block of 128 threads on type int + * typedef cub::BlockScan BlockScan; + * + * // Allocate shared memory for BlockScan + * __shared__ typename BlockScan::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Collectively compute the block-wide inclusive prefix max scan + * int block_aggregate; + * BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cub::Max(), block_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is + * { [0,-1,2,-3], [4,-5,6,-7], ..., [508,-509,510,-511] }. + * The corresponding output \p thread_data in those threads will be + * { [0,0,2,2], [4,4,6,6], ..., [508,508,510,510] }. + * Furthermore the value \p 510 will be stored in \p block_aggregate for all threads. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template < + int ITEMS_PER_THREAD, + typename ScanOp> + __device__ __forceinline__ void InclusiveScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + T &block_aggregate) ///< [out] block-wide aggregate reduction of input items + { + if (ITEMS_PER_THREAD == 1) + { + InclusiveScan(input[0], output[0], scan_op, block_aggregate); + } + else + { + // Reduce consecutive thread items in registers + T thread_prefix = internal::ThreadReduce(input, scan_op); + + // Exclusive thread block-scan (with no initial value) + ExclusiveScan(thread_prefix, thread_prefix, scan_op, block_aggregate); + + // Inclusive scan in registers with prefix as seed (first thread does not seed) + internal::ThreadScanInclusive(input, output, scan_op, thread_prefix, (linear_tid != 0)); + } + } + + + /** + * \brief Computes an inclusive block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + * + * \par + * - The \p block_prefix_callback_op functor must implement a member function T operator()(T block_aggregate). + * The functor's input parameter \p block_aggregate is the same value also returned by the scan operation. + * The functor will be invoked by the first warp of threads in the block, however only the return value from + * lane0 is applied as the block-wide prefix. Can be stateful. + * - Supports non-commutative scan operators. + * - \blocked + * - \granularity + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates a single thread block that progressively + * computes an inclusive prefix max scan over multiple "tiles" of input using a + * prefix functor to maintain a running total between block-wide scans. Each tile consists + * of 128 integer items that are partitioned across 128 threads. 
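The tile-iteration pattern described here can be distilled into a compact, self-contained kernel. The functor and kernel names, the 128-thread / 4-items-per-thread geometry, and the assumption that num_items is a multiple of 512 are illustrative choices that mirror the snippet that follows.

    #include <climits>
    #include <cub/cub.cuh>

    // Running-prefix functor (same role as the BlockPrefixCallbackOp in the snippet
    // that follows): carries the largest value seen in earlier tiles into the next scan.
    struct RunningMax
    {
        int running_max;
        __device__ RunningMax(int init) : running_max(init) {}
        __device__ int operator()(int block_aggregate)
        {
            int old_prefix = running_max;
            running_max = (block_aggregate > old_prefix) ? block_aggregate : old_prefix;
            return old_prefix;
        }
    };

    __global__ void TiledInclusiveMaxScan(int *d_data, int num_items)   // assumes num_items is a multiple of 512
    {
        typedef cub::BlockScan<int, 128> BlockScan;
        __shared__ typename BlockScan::TempStorage temp_storage;

        RunningMax prefix_op(INT_MIN);

        // A single block walks the array tile by tile (128 threads x 4 items per tile)
        for (int tile = 0; tile < num_items; tile += 128 * 4)
        {
            int items[4];
            #pragma unroll
            for (int i = 0; i < 4; ++i)
                items[i] = d_data[tile + threadIdx.x * 4 + i];

            BlockScan(temp_storage).InclusiveScan(items, items, cub::Max(), prefix_op);
            __syncthreads();   // temp_storage is reused on the next iteration

            #pragma unroll
            for (int i = 0; i < 4; ++i)
                d_data[tile + threadIdx.x * 4 + i] = items[i];
        }
    }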
+ * \par + * \code + * #include // or equivalently + * + * // A stateful callback functor that maintains a running prefix to be applied + * // during consecutive scan operations. + * struct BlockPrefixCallbackOp + * { + * // Running prefix + * int running_total; + * + * // Constructor + * __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + * + * // Callback operator to be entered by the first warp of threads in the block. + * // Thread-0 is responsible for returning a value for seeding the block-wide scan. + * __device__ int operator()(int block_aggregate) + * { + * int old_prefix = running_total; + * running_total = (block_aggregate > old_prefix) ? block_aggregate : old_prefix; + * return old_prefix; + * } + * }; + * + * __global__ void ExampleKernel(int *d_data, int num_items, ...) + * { + * // Specialize BlockLoad, BlockStore, and BlockScan for a 1D block of 128 threads, 4 ints per thread + * typedef cub::BlockLoad BlockLoad; + * typedef cub::BlockStore BlockStore; + * typedef cub::BlockScan BlockScan; + * + * // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan + * __shared__ union { + * typename BlockLoad::TempStorage load; + * typename BlockScan::TempStorage scan; + * typename BlockStore::TempStorage store; + * } temp_storage; + * + * // Initialize running total + * BlockPrefixCallbackOp prefix_op(0); + * + * // Have the block iterate over segments of items + * for (int block_offset = 0; block_offset < num_items; block_offset += 128 * 4) + * { + * // Load a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * BlockLoad(temp_storage.load).Load(d_data + block_offset, thread_data); + * CTA_SYNC(); + * + * // Collectively compute the block-wide inclusive prefix max scan + * BlockScan(temp_storage.scan).InclusiveScan( + * thread_data, thread_data, cub::Max(), prefix_op); + * CTA_SYNC(); + * + * // Store scanned items to output segment + * BlockStore(temp_storage.store).Store(d_data + block_offset, thread_data); + * CTA_SYNC(); + * } + * \endcode + * \par + * Suppose the input \p d_data is 0, -1, 2, -3, 4, -5, .... + * The corresponding output for the first segment will be 0, 0, 2, 2, 4, 4, ..., 510, 510. + * The output for the second segment will be 512, 512, 514, 514, 516, 516, ..., 1022, 1022. + * + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + * \tparam BlockPrefixCallbackOp [inferred] Call-back functor type having member T operator()(T block_aggregate) + */ + template < + int ITEMS_PER_THREAD, + typename ScanOp, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void InclusiveScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide prefix to be applied to the logical input sequence. 
+ { + if (ITEMS_PER_THREAD == 1) + { + InclusiveScan(input[0], output[0], scan_op, block_prefix_callback_op); + } + else + { + // Reduce consecutive thread items in registers + T thread_prefix = internal::ThreadReduce(input, scan_op); + + // Exclusive thread block-scan + ExclusiveScan(thread_prefix, thread_prefix, scan_op, block_prefix_callback_op); + + // Inclusive scan in registers with prefix as seed + internal::ThreadScanInclusive(input, output, scan_op, thread_prefix); + } + } + + //@} end member group + + +}; + +/** + * \example example_block_scan.cu + */ + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/block_shuffle.cuh b/fastertransformer/cuda/cub/block/block_shuffle.cuh new file mode 100644 index 000000000..a0cc71d22 --- /dev/null +++ b/fastertransformer/cuda/cub/block/block_shuffle.cuh @@ -0,0 +1,305 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * The cub::BlockShuffle class provides [collective](index.html#sec0) methods for shuffling data partitioned across a CUDA thread block. + */ + +#pragma once + +#include "../util_arch.cuh" +#include "../util_ptx.cuh" +#include "../util_macro.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief The BlockShuffle class provides [collective](index.html#sec0) methods for shuffling data partitioned across a CUDA thread block. + * \ingroup BlockModule + * + * \tparam T The data type to be exchanged. 
+ * \tparam BLOCK_DIM_X The thread block length in threads along the X dimension + * \tparam BLOCK_DIM_Y [optional] The thread block length in threads along the Y dimension (default: 1) + * \tparam BLOCK_DIM_Z [optional] The thread block length in threads along the Z dimension (default: 1) + * \tparam PTX_ARCH [optional] \ptxversion + * + * \par Overview + * It is commonplace for blocks of threads to rearrange data items between + * threads. The BlockShuffle abstraction allows threads to efficiently shift items + * either (a) up to their successor or (b) down to their predecessor. + * + */ +template < + typename T, + int BLOCK_DIM_X, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = CUB_PTX_ARCH> +class BlockShuffle +{ +private: + + /****************************************************************************** + * Constants + ******************************************************************************/ + + enum + { + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + + LOG_WARP_THREADS = CUB_LOG_WARP_THREADS(PTX_ARCH), + WARP_THREADS = 1 << LOG_WARP_THREADS, + WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS, + }; + + /****************************************************************************** + * Type definitions + ******************************************************************************/ + + /// Shared memory storage layout type (last element from each thread's input) + struct _TempStorage + { + T prev[BLOCK_THREADS]; + T next[BLOCK_THREADS]; + }; + + +public: + + /// \smemstorage{BlockShuffle} + struct TempStorage : Uninitialized<_TempStorage> {}; + +private: + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /// Internal storage allocator + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + +public: + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using a private static allocation of shared memory as temporary storage. + */ + __device__ __forceinline__ BlockShuffle() + : + temp_storage(PrivateStorage()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. + */ + __device__ __forceinline__ BlockShuffle( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + //@} end member group + /******************************************************************//** + * \name Shuffle movement + *********************************************************************/ + //@{ + + + /** + * \brief Each threadi obtains the \p input provided by threadi+distance. The offset \p distance may be negative. 
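As a usage illustration of the Offset interface documented here, the sketch below passes one value per thread to its lower-indexed neighbor. The kernel and buffer names and the 128-thread block are assumptions; only the BlockShuffle specialization and the Offset call are taken from this header.

    #include <cub/cub.cuh>

    __global__ void ShuffleOffsetExample(const int *d_in, int *d_out)
    {
        // Specialize BlockShuffle for a 1D block of 128 threads on type int
        typedef cub::BlockShuffle<int, 128> BlockShuffle;
        __shared__ typename BlockShuffle::TempStorage temp_storage;

        int value = d_in[threadIdx.x];

        // Each thread_i receives the item contributed by thread_(i+1); the last
        // thread's output is left unchanged because i+1 falls outside the block.
        BlockShuffle(temp_storage).Offset(value, value, 1);

        d_out[threadIdx.x] = value;
    }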
+ * + * \par + * - \smemreuse + */ + __device__ __forceinline__ void Offset( + T input, ///< [in] The input item from the calling thread (threadi) + T& output, ///< [out] The \p input item from the successor (or predecessor) thread threadi+distance (may be aliased to \p input). This value is only updated for for threadi when 0 <= (i + \p distance) < BLOCK_THREADS-1 + int distance = 1) ///< [in] Offset distance (may be negative) + { + temp_storage[linear_tid].prev = input; + + CTA_SYNC(); + + if ((linear_tid + distance >= 0) && (linear_tid + distance < BLOCK_THREADS)) + output = temp_storage[linear_tid + distance].prev; + } + + + /** + * \brief Each threadi obtains the \p input provided by threadi+distance. + * + * \par + * - \smemreuse + */ + __device__ __forceinline__ void Rotate( + T input, ///< [in] The calling thread's input item + T& output, ///< [out] The \p input item from thread thread(i+distance>)% (may be aliased to \p input). This value is not updated for threadBLOCK_THREADS-1 + unsigned int distance = 1) ///< [in] Offset distance (0 < \p distance < BLOCK_THREADS) + { + temp_storage[linear_tid].prev = input; + + CTA_SYNC(); + + unsigned int offset = threadIdx.x + distance; + if (offset >= BLOCK_THREADS) + offset -= BLOCK_THREADS; + + output = temp_storage[offset].prev; + } + + + /** + * \brief The thread block rotates its [blocked arrangement](index.html#sec5sec3) of \p input items, shifting it up by one item + * + * \par + * - \blocked + * - \granularity + * - \smemreuse + */ + template + __device__ __forceinline__ void Up( + T (&input)[ITEMS_PER_THREAD], ///< [in] The calling thread's input items + T (&prev)[ITEMS_PER_THREAD]) ///< [out] The corresponding predecessor items (may be aliased to \p input). The item \p prev[0] is not updated for thread0. + { + temp_storage[linear_tid].prev = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = ITEMS_PER_THREAD - 1; ITEM > 0; --ITEM) + prev[ITEM] = input[ITEM - 1]; + + + if (linear_tid > 0) + prev[0] = temp_storage[linear_tid - 1].prev; + } + + + /** + * \brief The thread block rotates its [blocked arrangement](index.html#sec5sec3) of \p input items, shifting it up by one item. All threads receive the \p input provided by threadBLOCK_THREADS-1. + * + * \par + * - \blocked + * - \granularity + * - \smemreuse + */ + template + __device__ __forceinline__ void Up( + T (&input)[ITEMS_PER_THREAD], ///< [in] The calling thread's input items + T (&prev)[ITEMS_PER_THREAD], ///< [out] The corresponding predecessor items (may be aliased to \p input). The item \p prev[0] is not updated for thread0. + T &block_suffix) ///< [out] The item \p input[ITEMS_PER_THREAD-1] from threadBLOCK_THREADS-1, provided to all threads + { + Up(input, prev); + block_suffix = temp_storage[BLOCK_THREADS - 1].prev; + } + + + /** + * \brief The thread block rotates its [blocked arrangement](index.html#sec5sec3) of \p input items, shifting it down by one item + * + * \par + * - \blocked + * - \granularity + * - \smemreuse + */ + template + __device__ __forceinline__ void Down( + T (&input)[ITEMS_PER_THREAD], ///< [in] The calling thread's input items + T (&prev)[ITEMS_PER_THREAD]) ///< [out] The corresponding predecessor items (may be aliased to \p input). The value \p prev[0] is not updated for threadBLOCK_THREADS-1. 
+ { + temp_storage[linear_tid].prev = input[ITEMS_PER_THREAD - 1]; + + CTA_SYNC(); + + #pragma unroll + for (int ITEM = ITEMS_PER_THREAD - 1; ITEM > 0; --ITEM) + prev[ITEM] = input[ITEM - 1]; + + if (linear_tid > 0) + prev[0] = temp_storage[linear_tid - 1].prev; + } + + + /** + * \brief The thread block rotates its [blocked arrangement](index.html#sec5sec3) of input items, shifting it down by one item. All threads receive \p input[0] provided by thread0. + * + * \par + * - \blocked + * - \granularity + * - \smemreuse + */ + template + __device__ __forceinline__ void Down( + T (&input)[ITEMS_PER_THREAD], ///< [in] The calling thread's input items + T (&prev)[ITEMS_PER_THREAD], ///< [out] The corresponding predecessor items (may be aliased to \p input). The value \p prev[0] is not updated for threadBLOCK_THREADS-1. + T &block_prefix) ///< [out] The item \p input[0] from thread0, provided to all threads + { + Up(input, prev); + block_prefix = temp_storage[BLOCK_THREADS - 1].prev; + } + + //@} end member group + + +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/block_store.cuh b/fastertransformer/cuda/cub/block/block_store.cuh new file mode 100644 index 000000000..648bf9ff4 --- /dev/null +++ b/fastertransformer/cuda/cub/block/block_store.cuh @@ -0,0 +1,1000 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * \file + * Operations for writing linear segments of data from the CUDA thread block + */ + +#pragma once + +#include + +#include "block_exchange.cuh" +#include "../util_ptx.cuh" +#include "../util_macro.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \addtogroup UtilIo + * @{ + */ + + +/******************************************************************//** + * \name Blocked arrangement I/O (direct) + *********************************************************************/ +//@{ + +/** + * \brief Store a blocked arrangement of items across a thread block into a linear segment of items. + * + * \blocked + * + * \tparam T [inferred] The data type to store. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam OutputIteratorT [inferred] The random-access iterator type for output \iterator. + */ +template < + typename T, + int ITEMS_PER_THREAD, + typename OutputIteratorT> +__device__ __forceinline__ void StoreDirectBlocked( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD]) ///< [in] Data to store +{ + OutputIteratorT thread_itr = block_itr + (linear_tid * ITEMS_PER_THREAD); + + // Store directly in thread-blocked order + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + thread_itr[ITEM] = items[ITEM]; + } +} + + +/** + * \brief Store a blocked arrangement of items across a thread block into a linear segment of items, guarded by range + * + * \blocked + * + * \tparam T [inferred] The data type to store. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam OutputIteratorT [inferred] The random-access iterator type for output \iterator. + */ +template < + typename T, + int ITEMS_PER_THREAD, + typename OutputIteratorT> +__device__ __forceinline__ void StoreDirectBlocked( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD], ///< [in] Data to store + int valid_items) ///< [in] Number of valid items to write +{ + OutputIteratorT thread_itr = block_itr + (linear_tid * ITEMS_PER_THREAD); + + // Store directly in thread-blocked order + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + if (ITEM + (linear_tid * ITEMS_PER_THREAD) < valid_items) + { + thread_itr[ITEM] = items[ITEM]; + } + } +} + + +/** + * \brief Store a blocked arrangement of items across a thread block into a linear segment of items. + * + * \blocked + * + * The output offset (\p block_ptr + \p block_offset) must be quad-item aligned, + * which is the default starting offset returned by \p cudaMalloc() + * + * \par + * The following conditions will prevent vectorization and storing will fall back to cub::BLOCK_STORE_DIRECT: + * - \p ITEMS_PER_THREAD is odd + * - The data type \p T is not a built-in primitive or CUDA vector type (e.g., \p short, \p int2, \p double, \p float2, etc.) 
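To make the guarded StoreDirectBlocked overload concrete, here is a small sketch. The kernel name, the values written, and the 4-items-per-thread figure are assumptions; the guard simply keeps a partially full final tile from writing past valid_items.

    #include <cub/cub.cuh>

    __global__ void GuardedBlockedStoreExample(int *d_out, int valid_items)
    {
        const int ITEMS_PER_THREAD = 4;

        // Blocked arrangement: thread i owns items [4*i, 4*i + 3]
        int items[ITEMS_PER_THREAD];
        #pragma unroll
        for (int i = 0; i < ITEMS_PER_THREAD; ++i)
            items[i] = threadIdx.x * ITEMS_PER_THREAD + i;

        // Only elements whose flattened index is below valid_items are written,
        // so the store cannot overrun d_out.
        cub::StoreDirectBlocked(threadIdx.x, d_out, items, valid_items);
    }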
+ * + * \tparam T [inferred] The data type to store. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * + */ +template < + typename T, + int ITEMS_PER_THREAD> +__device__ __forceinline__ void StoreDirectBlockedVectorized( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + T *block_ptr, ///< [in] Input pointer for storing from + T (&items)[ITEMS_PER_THREAD]) ///< [in] Data to store +{ + enum + { + // Maximum CUDA vector size is 4 elements + MAX_VEC_SIZE = CUB_MIN(4, ITEMS_PER_THREAD), + + // Vector size must be a power of two and an even divisor of the items per thread + VEC_SIZE = ((((MAX_VEC_SIZE - 1) & MAX_VEC_SIZE) == 0) && ((ITEMS_PER_THREAD % MAX_VEC_SIZE) == 0)) ? + MAX_VEC_SIZE : + 1, + + VECTORS_PER_THREAD = ITEMS_PER_THREAD / VEC_SIZE, + }; + + // Vector type + typedef typename CubVector::Type Vector; + + // Alias global pointer + Vector *block_ptr_vectors = reinterpret_cast(const_cast(block_ptr)); + + // Alias pointers (use "raw" array here which should get optimized away to prevent conservative PTXAS lmem spilling) + Vector raw_vector[VECTORS_PER_THREAD]; + T *raw_items = reinterpret_cast(raw_vector); + + // Copy + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + raw_items[ITEM] = items[ITEM]; + } + + // Direct-store using vector types + StoreDirectBlocked(linear_tid, block_ptr_vectors, raw_vector); +} + + + +//@} end member group +/******************************************************************//** + * \name Striped arrangement I/O (direct) + *********************************************************************/ +//@{ + + +/** + * \brief Store a striped arrangement of data across the thread block into a linear segment of items. + * + * \striped + * + * \tparam BLOCK_THREADS The thread block size in threads + * \tparam T [inferred] The data type to store. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam OutputIteratorT [inferred] The random-access iterator type for output \iterator. + */ +template < + int BLOCK_THREADS, + typename T, + int ITEMS_PER_THREAD, + typename OutputIteratorT> +__device__ __forceinline__ void StoreDirectStriped( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD]) ///< [in] Data to store +{ + OutputIteratorT thread_itr = block_itr + linear_tid; + + // Store directly in striped order + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + thread_itr[(ITEM * BLOCK_THREADS)] = items[ITEM]; + } +} + + +/** + * \brief Store a striped arrangement of data across the thread block into a linear segment of items, guarded by range + * + * \striped + * + * \tparam BLOCK_THREADS The thread block size in threads + * \tparam T [inferred] The data type to store. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam OutputIteratorT [inferred] The random-access iterator type for output \iterator. 
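A short sketch of StoreDirectBlockedVectorized may also help: with 4 ints per thread and a raw, suitably aligned pointer, each thread's store can be issued as a single 16-byte vector store. The kernel name and geometry are assumptions.

    #include <cub/cub.cuh>

    __global__ void VectorizedStoreExample(int *d_out)
    {
        // 4 ints per thread and a native, quad-item-aligned pointer (cudaMalloc
        // allocations qualify) allow the call below to emit vectorized stores;
        // otherwise storing falls back to ordinary blocked stores.
        int items[4];
        #pragma unroll
        for (int i = 0; i < 4; ++i)
            items[i] = threadIdx.x * 4 + i;

        cub::StoreDirectBlockedVectorized(threadIdx.x, d_out, items);
    }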
+ */ +template < + int BLOCK_THREADS, + typename T, + int ITEMS_PER_THREAD, + typename OutputIteratorT> +__device__ __forceinline__ void StoreDirectStriped( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD], ///< [in] Data to store + int valid_items) ///< [in] Number of valid items to write +{ + OutputIteratorT thread_itr = block_itr + linear_tid; + + // Store directly in striped order + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + if ((ITEM * BLOCK_THREADS) + linear_tid < valid_items) + { + thread_itr[(ITEM * BLOCK_THREADS)] = items[ITEM]; + } + } +} + + + +//@} end member group +/******************************************************************//** + * \name Warp-striped arrangement I/O (direct) + *********************************************************************/ +//@{ + + +/** + * \brief Store a warp-striped arrangement of data across the thread block into a linear segment of items. + * + * \warpstriped + * + * \par Usage Considerations + * The number of threads in the thread block must be a multiple of the architecture's warp size. + * + * \tparam T [inferred] The data type to store. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam OutputIteratorT [inferred] The random-access iterator type for output \iterator. + */ +template < + typename T, + int ITEMS_PER_THREAD, + typename OutputIteratorT> +__device__ __forceinline__ void StoreDirectWarpStriped( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD]) ///< [out] Data to load +{ + int tid = linear_tid & (CUB_PTX_WARP_THREADS - 1); + int wid = linear_tid >> CUB_PTX_LOG_WARP_THREADS; + int warp_offset = wid * CUB_PTX_WARP_THREADS * ITEMS_PER_THREAD; + + OutputIteratorT thread_itr = block_itr + warp_offset + tid; + + // Store directly in warp-striped order + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + thread_itr[(ITEM * CUB_PTX_WARP_THREADS)] = items[ITEM]; + } +} + + +/** + * \brief Store a warp-striped arrangement of data across the thread block into a linear segment of items, guarded by range + * + * \warpstriped + * + * \par Usage Considerations + * The number of threads in the thread block must be a multiple of the architecture's warp size. + * + * \tparam T [inferred] The data type to store. + * \tparam ITEMS_PER_THREAD [inferred] The number of consecutive items partitioned onto each thread. + * \tparam OutputIteratorT [inferred] The random-access iterator type for output \iterator. 
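For contrast with the blocked variants, the sketch below writes a striped arrangement using StoreDirectStriped, with BLOCK_THREADS passed explicitly as 128. Names, the values written, and the geometry are illustrative assumptions.

    #include <cub/cub.cuh>

    __global__ void StripedStoreExample(int *d_out)
    {
        // Striped arrangement: thread i owns items {i, i + 128, i + 256, i + 384},
        // so at each step consecutive threads write consecutive addresses.
        int items[4];
        #pragma unroll
        for (int i = 0; i < 4; ++i)
            items[i] = threadIdx.x + i * 128;

        cub::StoreDirectStriped<128>(threadIdx.x, d_out, items);
    }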
+ */ +template < + typename T, + int ITEMS_PER_THREAD, + typename OutputIteratorT> +__device__ __forceinline__ void StoreDirectWarpStriped( + int linear_tid, ///< [in] A suitable 1D thread-identifier for the calling thread (e.g., (threadIdx.y * blockDim.x) + linear_tid for 2D thread blocks) + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD], ///< [in] Data to store + int valid_items) ///< [in] Number of valid items to write +{ + int tid = linear_tid & (CUB_PTX_WARP_THREADS - 1); + int wid = linear_tid >> CUB_PTX_LOG_WARP_THREADS; + int warp_offset = wid * CUB_PTX_WARP_THREADS * ITEMS_PER_THREAD; + + OutputIteratorT thread_itr = block_itr + warp_offset + tid; + + // Store directly in warp-striped order + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++) + { + if (warp_offset + tid + (ITEM * CUB_PTX_WARP_THREADS) < valid_items) + { + thread_itr[(ITEM * CUB_PTX_WARP_THREADS)] = items[ITEM]; + } + } +} + + +//@} end member group + + +/** @} */ // end group UtilIo + + +//----------------------------------------------------------------------------- +// Generic BlockStore abstraction +//----------------------------------------------------------------------------- + +/** + * \brief cub::BlockStoreAlgorithm enumerates alternative algorithms for cub::BlockStore to write a blocked arrangement of items across a CUDA thread block to a linear segment of memory. + */ +enum BlockStoreAlgorithm +{ + /** + * \par Overview + * + * A [blocked arrangement](index.html#sec5sec3) of data is written + * directly to memory. + * + * \par Performance Considerations + * - The utilization of memory transactions (coalescing) decreases as the + * access stride between threads increases (i.e., the number items per thread). + */ + BLOCK_STORE_DIRECT, + + /** + * \par Overview + * + * A [blocked arrangement](index.html#sec5sec3) of data is written directly + * to memory using CUDA's built-in vectorized stores as a coalescing optimization. + * For example, st.global.v4.s32 instructions will be generated + * when \p T = \p int and \p ITEMS_PER_THREAD % 4 == 0. + * + * \par Performance Considerations + * - The utilization of memory transactions (coalescing) remains high until the the + * access stride between threads (i.e., the number items per thread) exceeds the + * maximum vector store width (typically 4 items or 64B, whichever is lower). + * - The following conditions will prevent vectorization and writing will fall back to cub::BLOCK_STORE_DIRECT: + * - \p ITEMS_PER_THREAD is odd + * - The \p OutputIteratorT is not a simple pointer type + * - The block output offset is not quadword-aligned + * - The data type \p T is not a built-in primitive or CUDA vector type (e.g., \p short, \p int2, \p double, \p float2, etc.) + */ + BLOCK_STORE_VECTORIZE, + + /** + * \par Overview + * A [blocked arrangement](index.html#sec5sec3) is locally + * transposed and then efficiently written to memory as a [striped arrangement](index.html#sec5sec3). + * + * \par Performance Considerations + * - The utilization of memory transactions (coalescing) remains high regardless + * of items written per thread. + * - The local reordering incurs slightly longer latencies and throughput than the + * direct cub::BLOCK_STORE_DIRECT and cub::BLOCK_STORE_VECTORIZE alternatives. 
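The guarded warp-striped store can be sketched as follows. The warp size of 32, the register fill pattern, and the kernel name are assumptions chosen so that, for the elements actually written, d_out[k] == k for k < valid_items.

    #include <cub/cub.cuh>

    __global__ void WarpStripedStoreExample(int *d_out, int valid_items)
    {
        // Warp-striped arrangement: within each 32-thread warp, lane l owns items
        // {warp_base + l, warp_base + l + 32, ...}; filling registers in that order
        // keeps the guarded store below coalesced within each warp.
        int lane      = threadIdx.x & 31;
        int warp      = threadIdx.x >> 5;
        int warp_base = warp * 32 * 4;

        int items[4];
        #pragma unroll
        for (int i = 0; i < 4; ++i)
            items[i] = warp_base + lane + i * 32;

        cub::StoreDirectWarpStriped(threadIdx.x, d_out, items, valid_items);
    }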
+ */ + BLOCK_STORE_TRANSPOSE, + + /** + * \par Overview + * A [blocked arrangement](index.html#sec5sec3) is locally + * transposed and then efficiently written to memory as a + * [warp-striped arrangement](index.html#sec5sec3) + * + * \par Usage Considerations + * - BLOCK_THREADS must be a multiple of WARP_THREADS + * + * \par Performance Considerations + * - The utilization of memory transactions (coalescing) remains high regardless + * of items written per thread. + * - The local reordering incurs slightly longer latencies and throughput than the + * direct cub::BLOCK_STORE_DIRECT and cub::BLOCK_STORE_VECTORIZE alternatives. + */ + BLOCK_STORE_WARP_TRANSPOSE, + + /** + * \par Overview + * A [blocked arrangement](index.html#sec5sec3) is locally + * transposed and then efficiently written to memory as a + * [warp-striped arrangement](index.html#sec5sec3) + * To reduce the shared memory requirement, only one warp's worth of shared + * memory is provisioned and is subsequently time-sliced among warps. + * + * \par Usage Considerations + * - BLOCK_THREADS must be a multiple of WARP_THREADS + * + * \par Performance Considerations + * - The utilization of memory transactions (coalescing) remains high regardless + * of items written per thread. + * - Provisions less shared memory temporary storage, but incurs larger + * latencies than the BLOCK_STORE_WARP_TRANSPOSE alternative. + */ + BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED, + +}; + + +/** + * \brief The BlockStore class provides [collective](index.html#sec0) data movement methods for writing a [blocked arrangement](index.html#sec5sec3) of items partitioned across a CUDA thread block to a linear segment of memory. ![](block_store_logo.png) + * \ingroup BlockModule + * \ingroup UtilIo + * + * \tparam T The type of data to be written. + * \tparam BLOCK_DIM_X The thread block length in threads along the X dimension + * \tparam ITEMS_PER_THREAD The number of consecutive items partitioned onto each thread. + * \tparam ALGORITHM [optional] cub::BlockStoreAlgorithm tuning policy enumeration. default: cub::BLOCK_STORE_DIRECT. + * \tparam WARP_TIME_SLICING [optional] Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any load-related data transpositions (versus each warp having its own storage). (default: false) + * \tparam BLOCK_DIM_Y [optional] The thread block length in threads along the Y dimension (default: 1) + * \tparam BLOCK_DIM_Z [optional] The thread block length in threads along the Z dimension (default: 1) + * \tparam PTX_ARCH [optional] \ptxversion + * + * \par Overview + * - The BlockStore class provides a single data movement abstraction that can be specialized + * to implement different cub::BlockStoreAlgorithm strategies. This facilitates different + * performance policies for different architectures, data types, granularity sizes, etc. + * - BlockStore can be optionally specialized by different data movement strategies: + * -# cub::BLOCK_STORE_DIRECT. A [blocked arrangement](index.html#sec5sec3) of data is written + * directly to memory. [More...](\ref cub::BlockStoreAlgorithm) + * -# cub::BLOCK_STORE_VECTORIZE. A [blocked arrangement](index.html#sec5sec3) + * of data is written directly to memory using CUDA's built-in vectorized stores as a + * coalescing optimization. [More...](\ref cub::BlockStoreAlgorithm) + * -# cub::BLOCK_STORE_TRANSPOSE. 
A [blocked arrangement](index.html#sec5sec3) + * is locally transposed into a [striped arrangement](index.html#sec5sec3) which is + * then written to memory. [More...](\ref cub::BlockStoreAlgorithm) + * -# cub::BLOCK_STORE_WARP_TRANSPOSE. A [blocked arrangement](index.html#sec5sec3) + * is locally transposed into a [warp-striped arrangement](index.html#sec5sec3) which is + * then written to memory. [More...](\ref cub::BlockStoreAlgorithm) + * - \rowmajor + * + * \par A Simple Example + * \blockcollective{BlockStore} + * \par + * The code snippet below illustrates the storing of a "blocked" arrangement + * of 512 integers across 128 threads (where each thread owns 4 consecutive items) + * into a linear segment of memory. The store is specialized for \p BLOCK_STORE_WARP_TRANSPOSE, + * meaning items are locally reordered among threads so that memory references will be + * efficiently coalesced using a warp-striped access pattern. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int *d_data, ...) + * { + * // Specialize BlockStore for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockStore BlockStore; + * + * // Allocate shared memory for BlockStore + * __shared__ typename BlockStore::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Store items to linear memory + * int thread_data[4]; + * BlockStore(temp_storage).Store(d_data, thread_data); + * + * \endcode + * \par + * Suppose the set of \p thread_data across the block of threads is + * { [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }. + * The output \p d_data will be 0, 1, 2, 3, 4, 5, .... + * + */ +template < + typename T, + int BLOCK_DIM_X, + int ITEMS_PER_THREAD, + BlockStoreAlgorithm ALGORITHM = BLOCK_STORE_DIRECT, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = CUB_PTX_ARCH> +class BlockStore +{ +private: + /****************************************************************************** + * Constants and typed definitions + ******************************************************************************/ + + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + }; + + + /****************************************************************************** + * Algorithmic variants + ******************************************************************************/ + + /// Store helper + template + struct StoreInternal; + + + /** + * BLOCK_STORE_DIRECT specialization of store helper + */ + template + struct StoreInternal + { + /// Shared memory storage layout type + typedef NullType TempStorage; + + /// Linear thread-id + int linear_tid; + + /// Constructor + __device__ __forceinline__ StoreInternal( + TempStorage &/*temp_storage*/, + int linear_tid) + : + linear_tid(linear_tid) + {} + + /// Store items into a linear segment of memory + template + __device__ __forceinline__ void Store( + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD]) ///< [in] Data to store + { + StoreDirectBlocked(linear_tid, block_itr, items); + } + + /// Store items into a linear segment of memory, guarded by range + template + __device__ __forceinline__ void Store( + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD], ///< [in] Data to store + int valid_items) ///< [in] Number of valid 
items to write + { + StoreDirectBlocked(linear_tid, block_itr, items, valid_items); + } + }; + + + /** + * BLOCK_STORE_VECTORIZE specialization of store helper + */ + template + struct StoreInternal + { + /// Shared memory storage layout type + typedef NullType TempStorage; + + /// Linear thread-id + int linear_tid; + + /// Constructor + __device__ __forceinline__ StoreInternal( + TempStorage &/*temp_storage*/, + int linear_tid) + : + linear_tid(linear_tid) + {} + + /// Store items into a linear segment of memory, specialized for native pointer types (attempts vectorization) + __device__ __forceinline__ void Store( + T *block_ptr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD]) ///< [in] Data to store + { + StoreDirectBlockedVectorized(linear_tid, block_ptr, items); + } + + /// Store items into a linear segment of memory, specialized for opaque input iterators (skips vectorization) + template + __device__ __forceinline__ void Store( + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD]) ///< [in] Data to store + { + StoreDirectBlocked(linear_tid, block_itr, items); + } + + /// Store items into a linear segment of memory, guarded by range + template + __device__ __forceinline__ void Store( + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD], ///< [in] Data to store + int valid_items) ///< [in] Number of valid items to write + { + StoreDirectBlocked(linear_tid, block_itr, items, valid_items); + } + }; + + + /** + * BLOCK_STORE_TRANSPOSE specialization of store helper + */ + template + struct StoreInternal + { + // BlockExchange utility type for keys + typedef BlockExchange BlockExchange; + + /// Shared memory storage layout type + struct _TempStorage : BlockExchange::TempStorage + { + /// Temporary storage for partially-full block guard + volatile int valid_items; + }; + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + /// Thread reference to shared storage + _TempStorage &temp_storage; + + /// Linear thread-id + int linear_tid; + + /// Constructor + __device__ __forceinline__ StoreInternal( + TempStorage &temp_storage, + int linear_tid) + : + temp_storage(temp_storage.Alias()), + linear_tid(linear_tid) + {} + + /// Store items into a linear segment of memory + template + __device__ __forceinline__ void Store( + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD]) ///< [in] Data to store + { + BlockExchange(temp_storage).BlockedToStriped(items); + StoreDirectStriped(linear_tid, block_itr, items); + } + + /// Store items into a linear segment of memory, guarded by range + template + __device__ __forceinline__ void Store( + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD], ///< [in] Data to store + int valid_items) ///< [in] Number of valid items to write + { + BlockExchange(temp_storage).BlockedToStriped(items); + if (linear_tid == 0) + temp_storage.valid_items = valid_items; // Move through volatile smem as a workaround to prevent RF spilling on subsequent loads + CTA_SYNC(); + StoreDirectStriped(linear_tid, block_itr, items, temp_storage.valid_items); + } + }; + + + /** + * BLOCK_STORE_WARP_TRANSPOSE specialization of store helper + */ + template + struct StoreInternal + { + enum + { + WARP_THREADS 
= CUB_WARP_THREADS(PTX_ARCH) + }; + + // Assert BLOCK_THREADS must be a multiple of WARP_THREADS + CUB_STATIC_ASSERT((BLOCK_THREADS % WARP_THREADS == 0), "BLOCK_THREADS must be a multiple of WARP_THREADS"); + + // BlockExchange utility type for keys + typedef BlockExchange BlockExchange; + + /// Shared memory storage layout type + struct _TempStorage : BlockExchange::TempStorage + { + /// Temporary storage for partially-full block guard + volatile int valid_items; + }; + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + /// Thread reference to shared storage + _TempStorage &temp_storage; + + /// Linear thread-id + int linear_tid; + + /// Constructor + __device__ __forceinline__ StoreInternal( + TempStorage &temp_storage, + int linear_tid) + : + temp_storage(temp_storage.Alias()), + linear_tid(linear_tid) + {} + + /// Store items into a linear segment of memory + template + __device__ __forceinline__ void Store( + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD]) ///< [in] Data to store + { + BlockExchange(temp_storage).BlockedToWarpStriped(items); + StoreDirectWarpStriped(linear_tid, block_itr, items); + } + + /// Store items into a linear segment of memory, guarded by range + template + __device__ __forceinline__ void Store( + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD], ///< [in] Data to store + int valid_items) ///< [in] Number of valid items to write + { + BlockExchange(temp_storage).BlockedToWarpStriped(items); + if (linear_tid == 0) + temp_storage.valid_items = valid_items; // Move through volatile smem as a workaround to prevent RF spilling on subsequent loads + CTA_SYNC(); + StoreDirectWarpStriped(linear_tid, block_itr, items, temp_storage.valid_items); + } + }; + + + /** + * BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED specialization of store helper + */ + template + struct StoreInternal + { + enum + { + WARP_THREADS = CUB_WARP_THREADS(PTX_ARCH) + }; + + // Assert BLOCK_THREADS must be a multiple of WARP_THREADS + CUB_STATIC_ASSERT((BLOCK_THREADS % WARP_THREADS == 0), "BLOCK_THREADS must be a multiple of WARP_THREADS"); + + // BlockExchange utility type for keys + typedef BlockExchange BlockExchange; + + /// Shared memory storage layout type + struct _TempStorage : BlockExchange::TempStorage + { + /// Temporary storage for partially-full block guard + volatile int valid_items; + }; + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + /// Thread reference to shared storage + _TempStorage &temp_storage; + + /// Linear thread-id + int linear_tid; + + /// Constructor + __device__ __forceinline__ StoreInternal( + TempStorage &temp_storage, + int linear_tid) + : + temp_storage(temp_storage.Alias()), + linear_tid(linear_tid) + {} + + /// Store items into a linear segment of memory + template + __device__ __forceinline__ void Store( + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD]) ///< [in] Data to store + { + BlockExchange(temp_storage).BlockedToWarpStriped(items); + StoreDirectWarpStriped(linear_tid, block_itr, items); + } + + /// Store items into a linear segment of memory, guarded by range + template + __device__ __forceinline__ void Store( + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T 
(&items)[ITEMS_PER_THREAD], ///< [in] Data to store + int valid_items) ///< [in] Number of valid items to write + { + BlockExchange(temp_storage).BlockedToWarpStriped(items); + if (linear_tid == 0) + temp_storage.valid_items = valid_items; // Move through volatile smem as a workaround to prevent RF spilling on subsequent loads + CTA_SYNC(); + StoreDirectWarpStriped(linear_tid, block_itr, items, temp_storage.valid_items); + } + }; + + /****************************************************************************** + * Type definitions + ******************************************************************************/ + + /// Internal load implementation to use + typedef StoreInternal InternalStore; + + + /// Shared memory storage layout type + typedef typename InternalStore::TempStorage _TempStorage; + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /// Internal storage allocator + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Thread reference to shared storage + _TempStorage &temp_storage; + + /// Linear thread-id + int linear_tid; + +public: + + + /// \smemstorage{BlockStore} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using a private static allocation of shared memory as temporary storage. + */ + __device__ __forceinline__ BlockStore() + : + temp_storage(PrivateStorage()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. + */ + __device__ __forceinline__ BlockStore( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + //@} end member group + /******************************************************************//** + * \name Data movement + *********************************************************************/ + //@{ + + + /** + * \brief Store items into a linear segment of memory. + * + * \par + * - \blocked + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the storing of a "blocked" arrangement + * of 512 integers across 128 threads (where each thread owns 4 consecutive items) + * into a linear segment of memory. The store is specialized for \p BLOCK_STORE_WARP_TRANSPOSE, + * meaning items are locally reordered among threads so that memory references will be + * efficiently coalesced using a warp-striped access pattern. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int *d_data, ...) 
+ * { + * // Specialize BlockStore for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockStore BlockStore; + * + * // Allocate shared memory for BlockStore + * __shared__ typename BlockStore::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Store items to linear memory + * int thread_data[4]; + * BlockStore(temp_storage).Store(d_data, thread_data); + * + * \endcode + * \par + * Suppose the set of \p thread_data across the block of threads is + * { [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }. + * The output \p d_data will be 0, 1, 2, 3, 4, 5, .... + * + */ + template + __device__ __forceinline__ void Store( + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD]) ///< [in] Data to store + { + InternalStore(temp_storage, linear_tid).Store(block_itr, items); + } + + /** + * \brief Store items into a linear segment of memory, guarded by range. + * + * \par + * - \blocked + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the guarded storing of a "blocked" arrangement + * of 512 integers across 128 threads (where each thread owns 4 consecutive items) + * into a linear segment of memory. The store is specialized for \p BLOCK_STORE_WARP_TRANSPOSE, + * meaning items are locally reordered among threads so that memory references will be + * efficiently coalesced using a warp-striped access pattern. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(int *d_data, int valid_items, ...) + * { + * // Specialize BlockStore for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockStore BlockStore; + * + * // Allocate shared memory for BlockStore + * __shared__ typename BlockStore::TempStorage temp_storage; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_data[4]; + * ... + * + * // Store items to linear memory + * int thread_data[4]; + * BlockStore(temp_storage).Store(d_data, thread_data, valid_items); + * + * \endcode + * \par + * Suppose the set of \p thread_data across the block of threads is + * { [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] } and \p valid_items is \p 5. + * The output \p d_data will be 0, 1, 2, 3, 4, ?, ?, ?, ..., with + * only the first two threads being unmasked to store portions of valid data. + * + */ + template + __device__ __forceinline__ void Store( + OutputIteratorT block_itr, ///< [in] The thread block's base output iterator for storing to + T (&items)[ITEMS_PER_THREAD], ///< [in] Data to store + int valid_items) ///< [in] Number of valid items to write + { + InternalStore(temp_storage, linear_tid).Store(block_itr, items, valid_items); + } +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/specializations/block_histogram_atomic.cuh b/fastertransformer/cuda/cub/block/specializations/block_histogram_atomic.cuh new file mode 100644 index 000000000..29db0df71 --- /dev/null +++ b/fastertransformer/cuda/cub/block/specializations/block_histogram_atomic.cuh @@ -0,0 +1,82 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 
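For reference, a minimal, self-contained usage sketch of the cub::BlockStore interface added above, assuming the vendored cub/ directory is on the include path. The 128-thread by 4-item tile shape, the kernel name, and the pairing with cub::BlockLoad are illustrative assumptions, not part of the diff.

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>

__global__ void StoreTileKernel(const int *d_in, int *d_out)
{
    // 128 threads x 4 items = one 512-item tile per thread block
    typedef cub::BlockLoad<int, 128, 4, cub::BLOCK_LOAD_WARP_TRANSPOSE>   BlockLoadT;
    typedef cub::BlockStore<int, 128, 4, cub::BLOCK_STORE_WARP_TRANSPOSE> BlockStoreT;

    // Load and store storage can share the same shared-memory allocation
    __shared__ union {
        typename BlockLoadT::TempStorage  load;
        typename BlockStoreT::TempStorage store;
    } temp_storage;

    int items[4];
    BlockLoadT(temp_storage.load).Load(d_in + blockIdx.x * 512, items);
    __syncthreads();  // required before re-using the shared-memory union
    BlockStoreT(temp_storage.store).Store(d_out + blockIdx.x * 512, items);
}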
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * The cub::BlockHistogramAtomic class provides atomic-based methods for constructing block-wide histograms from data samples partitioned across a CUDA thread block. + */ + +#pragma once + +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief The BlockHistogramAtomic class provides atomic-based methods for constructing block-wide histograms from data samples partitioned across a CUDA thread block. + */ +template +struct BlockHistogramAtomic +{ + /// Shared memory storage layout type + struct TempStorage {}; + + + /// Constructor + __device__ __forceinline__ BlockHistogramAtomic( + TempStorage &temp_storage) + {} + + + /// Composite data onto an existing histogram + template < + typename T, + typename CounterT, + int ITEMS_PER_THREAD> + __device__ __forceinline__ void Composite( + T (&items)[ITEMS_PER_THREAD], ///< [in] Calling thread's input values to histogram + CounterT histogram[BINS]) ///< [out] Reference to shared/device-accessible memory histogram + { + // Update histogram + #pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) + { + atomicAdd(histogram + items[i], 1); + } + } + +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/specializations/block_histogram_sort.cuh b/fastertransformer/cuda/cub/block/specializations/block_histogram_sort.cuh new file mode 100644 index 000000000..9ef417adc --- /dev/null +++ b/fastertransformer/cuda/cub/block/specializations/block_histogram_sort.cuh @@ -0,0 +1,226 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 
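A hedged sketch of how the atomic specialization above is normally reached through the public cub::BlockHistogram wrapper. The 128-thread by 4-sample tile of unsigned char data, the 256 bins, and the final atomicAdd flush of the per-block counts to global memory are illustrative assumptions.

#include <cub/block/block_histogram.cuh>

__global__ void BlockHistoKernel(const unsigned char *d_samples, unsigned int *d_histogram)
{
    typedef cub::BlockHistogram<unsigned char, 128, 4, 256, cub::BLOCK_HISTO_ATOMIC> BlockHistogramT;
    __shared__ typename BlockHistogramT::TempStorage temp_storage;
    __shared__ unsigned int smem_histogram[256];

    // Blocked arrangement: each thread owns 4 consecutive samples of the tile
    unsigned char samples[4];
    for (int i = 0; i < 4; ++i)
        samples[i] = d_samples[blockIdx.x * 512 + threadIdx.x * 4 + i];

    // Zero the shared-memory bins, then composite the tile into them
    BlockHistogramT(temp_storage).Histogram(samples, smem_histogram);
    __syncthreads();

    // Publish the per-block counts
    for (int bin = threadIdx.x; bin < 256; bin += 128)
        atomicAdd(d_histogram + bin, smem_histogram[bin]);
}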
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * The cub::BlockHistogramSort class provides sorting-based methods for constructing block-wide histograms from data samples partitioned across a CUDA thread block. + */ + +#pragma once + +#include "../../block/block_radix_sort.cuh" +#include "../../block/block_discontinuity.cuh" +#include "../../util_ptx.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + + +/** + * \brief The BlockHistogramSort class provides sorting-based methods for constructing block-wide histograms from data samples partitioned across a CUDA thread block. + */ +template < + typename T, ///< Sample type + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + int ITEMS_PER_THREAD, ///< The number of samples per thread + int BINS, ///< The number of bins into which histogram samples may fall + int BLOCK_DIM_Y, ///< The thread block length in threads along the Y dimension + int BLOCK_DIM_Z, ///< The thread block length in threads along the Z dimension + int PTX_ARCH> ///< The PTX compute capability for which to to specialize this collective +struct BlockHistogramSort +{ + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + }; + + // Parameterize BlockRadixSort type for our thread block + typedef BlockRadixSort< + T, + BLOCK_DIM_X, + ITEMS_PER_THREAD, + NullType, + 4, + (PTX_ARCH >= 350) ? 
true : false, + BLOCK_SCAN_WARP_SCANS, + cudaSharedMemBankSizeFourByte, + BLOCK_DIM_Y, + BLOCK_DIM_Z, + PTX_ARCH> + BlockRadixSortT; + + // Parameterize BlockDiscontinuity type for our thread block + typedef BlockDiscontinuity< + T, + BLOCK_DIM_X, + BLOCK_DIM_Y, + BLOCK_DIM_Z, + PTX_ARCH> + BlockDiscontinuityT; + + /// Shared memory + union _TempStorage + { + // Storage for sorting bin values + typename BlockRadixSortT::TempStorage sort; + + struct + { + // Storage for detecting discontinuities in the tile of sorted bin values + typename BlockDiscontinuityT::TempStorage flag; + + // Storage for noting begin/end offsets of bin runs in the tile of sorted bin values + unsigned int run_begin[BINS]; + unsigned int run_end[BINS]; + }; + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + // Thread fields + _TempStorage &temp_storage; + unsigned int linear_tid; + + + /// Constructor + __device__ __forceinline__ BlockHistogramSort( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + // Discontinuity functor + struct DiscontinuityOp + { + // Reference to temp_storage + _TempStorage &temp_storage; + + // Constructor + __device__ __forceinline__ DiscontinuityOp(_TempStorage &temp_storage) : + temp_storage(temp_storage) + {} + + // Discontinuity predicate + __device__ __forceinline__ bool operator()(const T &a, const T &b, int b_index) + { + if (a != b) + { + // Note the begin/end offsets in shared storage + temp_storage.run_begin[b] = b_index; + temp_storage.run_end[a] = b_index; + + return true; + } + else + { + return false; + } + } + }; + + + // Composite data onto an existing histogram + template < + typename CounterT > + __device__ __forceinline__ void Composite( + T (&items)[ITEMS_PER_THREAD], ///< [in] Calling thread's input values to histogram + CounterT histogram[BINS]) ///< [out] Reference to shared/device-accessible memory histogram + { + enum { TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD }; + + // Sort bytes in blocked arrangement + BlockRadixSortT(temp_storage.sort).Sort(items); + + CTA_SYNC(); + + // Initialize the shared memory's run_begin and run_end for each bin + int histo_offset = 0; + + #pragma unroll + for(; histo_offset + BLOCK_THREADS <= BINS; histo_offset += BLOCK_THREADS) + { + temp_storage.run_begin[histo_offset + linear_tid] = TILE_SIZE; + temp_storage.run_end[histo_offset + linear_tid] = TILE_SIZE; + } + // Finish up with guarded initialization if necessary + if ((BINS % BLOCK_THREADS != 0) && (histo_offset + linear_tid < BINS)) + { + temp_storage.run_begin[histo_offset + linear_tid] = TILE_SIZE; + temp_storage.run_end[histo_offset + linear_tid] = TILE_SIZE; + } + + CTA_SYNC(); + + int flags[ITEMS_PER_THREAD]; // unused + + // Compute head flags to demarcate contiguous runs of the same bin in the sorted tile + DiscontinuityOp flag_op(temp_storage); + BlockDiscontinuityT(temp_storage.flag).FlagHeads(flags, items, flag_op); + + // Update begin for first item + if (linear_tid == 0) temp_storage.run_begin[items[0]] = 0; + + CTA_SYNC(); + + // Composite into histogram + histo_offset = 0; + + #pragma unroll + for(; histo_offset + BLOCK_THREADS <= BINS; histo_offset += BLOCK_THREADS) + { + int thread_offset = histo_offset + linear_tid; + CounterT count = temp_storage.run_end[thread_offset] - temp_storage.run_begin[thread_offset]; + histogram[thread_offset] += count; + } + + // Finish up with guarded composition if 
necessary + if ((BINS % BLOCK_THREADS != 0) && (histo_offset + linear_tid < BINS)) + { + int thread_offset = histo_offset + linear_tid; + CounterT count = temp_storage.run_end[thread_offset] - temp_storage.run_begin[thread_offset]; + histogram[thread_offset] += count; + } + } + +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/specializations/block_reduce_raking.cuh b/fastertransformer/cuda/cub/block/specializations/block_reduce_raking.cuh new file mode 100644 index 000000000..aff97fc9b --- /dev/null +++ b/fastertransformer/cuda/cub/block/specializations/block_reduce_raking.cuh @@ -0,0 +1,226 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::BlockReduceRaking provides raking-based methods of parallel reduction across a CUDA thread block. Supports non-commutative reduction operators. + */ + +#pragma once + +#include "../../block/block_raking_layout.cuh" +#include "../../warp/warp_reduce.cuh" +#include "../../thread/thread_reduce.cuh" +#include "../../util_ptx.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief BlockReduceRaking provides raking-based methods of parallel reduction across a CUDA thread block. Supports non-commutative reduction operators. + * + * Supports non-commutative binary reduction operators. Unlike commutative + * reduction operators (e.g., addition), the application of a non-commutative + * reduction operator (e.g, string concatenation) across a sequence of inputs must + * honor the relative ordering of items and partial reductions when applying the + * reduction operator. 
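The sorting-based histogram specialization added above is reached through the same public wrapper by swapping the algorithm tag, which can help when many samples land in the same bin and atomics would contend. A minimal sketch, assuming the same kernel shape as the atomic example earlier:

#include <cub/block/block_histogram.cuh>

// Drop-in replacement for the typedef in the atomic sketch above
typedef cub::BlockHistogram<unsigned char, 128, 4, 256, cub::BLOCK_HISTO_SORT> SortingBlockHistogramT;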
+ * + * Compared to the implementation of BlockReduceRaking (which does not support + * non-commutative operators), this implementation requires a few extra + * rounds of inter-thread communication. + */ +template < + typename T, ///< Data type being reduced + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + int BLOCK_DIM_Y, ///< The thread block length in threads along the Y dimension + int BLOCK_DIM_Z, ///< The thread block length in threads along the Z dimension + int PTX_ARCH> ///< The PTX compute capability for which to to specialize this collective +struct BlockReduceRaking +{ + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + }; + + /// Layout type for padded thread block raking grid + typedef BlockRakingLayout BlockRakingLayout; + + /// WarpReduce utility type + typedef typename WarpReduce::InternalWarpReduce WarpReduce; + + /// Constants + enum + { + /// Number of raking threads + RAKING_THREADS = BlockRakingLayout::RAKING_THREADS, + + /// Number of raking elements per warp synchronous raking thread + SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH, + + /// Cooperative work can be entirely warp synchronous + WARP_SYNCHRONOUS = (RAKING_THREADS == BLOCK_THREADS), + + /// Whether or not warp-synchronous reduction should be unguarded (i.e., the warp-reduction elements is a power of two + WARP_SYNCHRONOUS_UNGUARDED = PowerOfTwo::VALUE, + + /// Whether or not accesses into smem are unguarded + RAKING_UNGUARDED = BlockRakingLayout::UNGUARDED, + + }; + + + /// Shared memory storage layout type + union _TempStorage + { + typename WarpReduce::TempStorage warp_storage; ///< Storage for warp-synchronous reduction + typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + // Thread fields + _TempStorage &temp_storage; + unsigned int linear_tid; + + + /// Constructor + __device__ __forceinline__ BlockReduceRaking( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + template + __device__ __forceinline__ T RakingReduction( + ReductionOp reduction_op, ///< [in] Binary scan operator + T *raking_segment, + T partial, ///< [in] [lane0 only] Warp-wide aggregate reduction of input items + int num_valid, ///< [in] Number of valid elements (may be less than BLOCK_THREADS) + Int2Type /*iteration*/) + { + // Update partial if addend is in range + if ((IS_FULL_TILE && RAKING_UNGUARDED) || ((linear_tid * SEGMENT_LENGTH) + ITERATION < num_valid)) + { + T addend = raking_segment[ITERATION]; + partial = reduction_op(partial, addend); + } + return RakingReduction(reduction_op, raking_segment, partial, num_valid, Int2Type()); + } + + template + __device__ __forceinline__ T RakingReduction( + ReductionOp /*reduction_op*/, ///< [in] Binary scan operator + T * /*raking_segment*/, + T partial, ///< [in] [lane0 only] Warp-wide aggregate reduction of input items + int /*num_valid*/, ///< [in] Number of valid elements (may be less than BLOCK_THREADS) + Int2Type /*iteration*/) + { + return partial; + } + + + + /// Computes a thread block-wide reduction using the specified reduction operator. The first num_valid threads each contribute one reduction partial. The return value is only valid for thread0. 
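A small usage sketch of the raking reduction through the public cub::BlockReduce wrapper, illustrating the guarded num_valid path and the thread0-only result documented above. The 128-thread block, the kernel name, and the zero-padding of out-of-range items are assumptions.

#include <cub/block/block_reduce.cuh>

__global__ void BlockSumKernel(const int *d_in, int *d_block_sums, int num_items)
{
    typedef cub::BlockReduce<int, 128, cub::BLOCK_REDUCE_RAKING> BlockReduceT;
    __shared__ typename BlockReduceT::TempStorage temp_storage;

    int idx = blockIdx.x * 128 + threadIdx.x;
    int my_item = (idx < num_items) ? d_in[idx] : 0;

    // Guarded variant: only the first num_valid threads of the block contribute
    int num_valid = min(128, num_items - blockIdx.x * 128);
    int block_sum = BlockReduceT(temp_storage).Sum(my_item, num_valid);

    // The aggregate is only valid in thread0, as documented above
    if (threadIdx.x == 0)
        d_block_sums[blockIdx.x] = block_sum;
}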
+ template < + bool IS_FULL_TILE, + typename ReductionOp> + __device__ __forceinline__ T Reduce( + T partial, ///< [in] Calling thread's input partial reductions + int num_valid, ///< [in] Number of valid elements (may be less than BLOCK_THREADS) + ReductionOp reduction_op) ///< [in] Binary reduction operator + { + if (WARP_SYNCHRONOUS) + { + // Short-circuit directly to warp synchronous reduction (unguarded if active threads is a power-of-two) + partial = WarpReduce(temp_storage.warp_storage).template Reduce( + partial, + num_valid, + reduction_op); + } + else + { + // Place partial into shared memory grid. + *BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid) = partial; + + CTA_SYNC(); + + // Reduce parallelism to one warp + if (linear_tid < RAKING_THREADS) + { + // Raking reduction in grid + T *raking_segment = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + partial = raking_segment[0]; + + partial = RakingReduction(reduction_op, raking_segment, partial, num_valid, Int2Type<1>()); + + int valid_raking_threads = (IS_FULL_TILE) ? + RAKING_THREADS : + (num_valid + SEGMENT_LENGTH - 1) / SEGMENT_LENGTH; + + partial = WarpReduce(temp_storage.warp_storage).template Reduce( + partial, + valid_raking_threads, + reduction_op); + + } + } + + return partial; + } + + + /// Computes a thread block-wide reduction using addition (+) as the reduction operator. The first num_valid threads each contribute one reduction partial. The return value is only valid for thread0. + template + __device__ __forceinline__ T Sum( + T partial, ///< [in] Calling thread's input partial reductions + int num_valid) ///< [in] Number of valid elements (may be less than BLOCK_THREADS) + { + cub::Sum reduction_op; + + return Reduce(partial, num_valid, reduction_op); + } + + + +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/specializations/block_reduce_raking_commutative_only.cuh b/fastertransformer/cuda/cub/block/specializations/block_reduce_raking_commutative_only.cuh new file mode 100644 index 000000000..454fdafa5 --- /dev/null +++ b/fastertransformer/cuda/cub/block/specializations/block_reduce_raking_commutative_only.cuh @@ -0,0 +1,199 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::BlockReduceRakingCommutativeOnly provides raking-based methods of parallel reduction across a CUDA thread block. Does not support non-commutative reduction operators. + */ + +#pragma once + +#include "block_reduce_raking.cuh" +#include "../../warp/warp_reduce.cuh" +#include "../../thread/thread_reduce.cuh" +#include "../../util_ptx.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief BlockReduceRakingCommutativeOnly provides raking-based methods of parallel reduction across a CUDA thread block. Does not support non-commutative reduction operators. Does not support block sizes that are not a multiple of the warp size. + */ +template < + typename T, ///< Data type being reduced + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + int BLOCK_DIM_Y, ///< The thread block length in threads along the Y dimension + int BLOCK_DIM_Z, ///< The thread block length in threads along the Z dimension + int PTX_ARCH> ///< The PTX compute capability for which to to specialize this collective +struct BlockReduceRakingCommutativeOnly +{ + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + }; + + // The fall-back implementation to use when BLOCK_THREADS is not a multiple of the warp size or not all threads have valid values + typedef BlockReduceRaking FallBack; + + /// Constants + enum + { + /// Number of warp threads + WARP_THREADS = CUB_WARP_THREADS(PTX_ARCH), + + /// Whether or not to use fall-back + USE_FALLBACK = ((BLOCK_THREADS % WARP_THREADS != 0) || (BLOCK_THREADS <= WARP_THREADS)), + + /// Number of raking threads + RAKING_THREADS = WARP_THREADS, + + /// Number of threads actually sharing items with the raking threads + SHARING_THREADS = CUB_MAX(1, BLOCK_THREADS - RAKING_THREADS), + + /// Number of raking elements per warp synchronous raking thread + SEGMENT_LENGTH = SHARING_THREADS / WARP_THREADS, + }; + + /// WarpReduce utility type + typedef WarpReduce WarpReduce; + + /// Layout type for padded thread block raking grid + typedef BlockRakingLayout BlockRakingLayout; + + /// Shared memory storage layout type + union _TempStorage + { + struct + { + typename WarpReduce::TempStorage warp_storage; ///< Storage for warp-synchronous reduction + typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid + }; + typename FallBack::TempStorage fallback_storage; ///< Fall-back storage for non-commutative block scan + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + // Thread fields + _TempStorage &temp_storage; + unsigned int linear_tid; + + + /// Constructor + __device__ __forceinline__ BlockReduceRakingCommutativeOnly( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + 
linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + /// Computes a thread block-wide reduction using addition (+) as the reduction operator. The first num_valid threads each contribute one reduction partial. The return value is only valid for thread0. + template + __device__ __forceinline__ T Sum( + T partial, ///< [in] Calling thread's input partial reductions + int num_valid) ///< [in] Number of valid elements (may be less than BLOCK_THREADS) + { + if (USE_FALLBACK || !FULL_TILE) + { + return FallBack(temp_storage.fallback_storage).template Sum(partial, num_valid); + } + else + { + // Place partial into shared memory grid + if (linear_tid >= RAKING_THREADS) + *BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid - RAKING_THREADS) = partial; + + CTA_SYNC(); + + // Reduce parallelism to one warp + if (linear_tid < RAKING_THREADS) + { + // Raking reduction in grid + T *raking_segment = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + partial = internal::ThreadReduce(raking_segment, cub::Sum(), partial); + + // Warpscan + partial = WarpReduce(temp_storage.warp_storage).Sum(partial); + } + } + + return partial; + } + + + /// Computes a thread block-wide reduction using the specified reduction operator. The first num_valid threads each contribute one reduction partial. The return value is only valid for thread0. + template < + bool FULL_TILE, + typename ReductionOp> + __device__ __forceinline__ T Reduce( + T partial, ///< [in] Calling thread's input partial reductions + int num_valid, ///< [in] Number of valid elements (may be less than BLOCK_THREADS) + ReductionOp reduction_op) ///< [in] Binary reduction operator + { + if (USE_FALLBACK || !FULL_TILE) + { + return FallBack(temp_storage.fallback_storage).template Reduce(partial, num_valid, reduction_op); + } + else + { + // Place partial into shared memory grid + if (linear_tid >= RAKING_THREADS) + *BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid - RAKING_THREADS) = partial; + + CTA_SYNC(); + + // Reduce parallelism to one warp + if (linear_tid < RAKING_THREADS) + { + // Raking reduction in grid + T *raking_segment = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + partial = internal::ThreadReduce(raking_segment, reduction_op, partial); + + // Warpscan + partial = WarpReduce(temp_storage.warp_storage).Reduce(partial, reduction_op); + } + } + + return partial; + } + +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/specializations/block_reduce_warp_reductions.cuh b/fastertransformer/cuda/cub/block/specializations/block_reduce_warp_reductions.cuh new file mode 100644 index 000000000..10ba303b4 --- /dev/null +++ b/fastertransformer/cuda/cub/block/specializations/block_reduce_warp_reductions.cuh @@ -0,0 +1,218 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
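A hedged sketch selecting the commutative-only raking path through the public algorithm tag. The 256-thread block (a multiple of the warp size, so the fall-back above is not taken) and the use of cub::Max as the commutative operator are illustrative assumptions.

#include <cub/block/block_reduce.cuh>

__global__ void BlockMaxKernel(const float *d_in, float *d_block_max)
{
    typedef cub::BlockReduce<float, 256, cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY> BlockReduceT;
    __shared__ typename BlockReduceT::TempStorage temp_storage;

    float item = d_in[blockIdx.x * 256 + threadIdx.x];

    // Full-tile reduction with a commutative operator; result valid in thread0 only
    float block_max = BlockReduceT(temp_storage).Reduce(item, cub::Max());

    if (threadIdx.x == 0)
        d_block_max[blockIdx.x] = block_max;
}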
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::BlockReduceWarpReductions provides variants of warp-reduction-based parallel reduction across a CUDA thread block. Supports non-commutative reduction operators. + */ + +#pragma once + +#include "../../warp/warp_reduce.cuh" +#include "../../util_ptx.cuh" +#include "../../util_arch.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief BlockReduceWarpReductions provides variants of warp-reduction-based parallel reduction across a CUDA thread block. Supports non-commutative reduction operators. 
+ */ +template < + typename T, ///< Data type being reduced + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + int BLOCK_DIM_Y, ///< The thread block length in threads along the Y dimension + int BLOCK_DIM_Z, ///< The thread block length in threads along the Z dimension + int PTX_ARCH> ///< The PTX compute capability for which to to specialize this collective +struct BlockReduceWarpReductions +{ + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + + /// Number of warp threads + WARP_THREADS = CUB_WARP_THREADS(PTX_ARCH), + + /// Number of active warps + WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS, + + /// The logical warp size for warp reductions + LOGICAL_WARP_SIZE = CUB_MIN(BLOCK_THREADS, WARP_THREADS), + + /// Whether or not the logical warp size evenly divides the thread block size + EVEN_WARP_MULTIPLE = (BLOCK_THREADS % LOGICAL_WARP_SIZE == 0) + }; + + + /// WarpReduce utility type + typedef typename WarpReduce::InternalWarpReduce WarpReduce; + + + /// Shared memory storage layout type + struct _TempStorage + { + typename WarpReduce::TempStorage warp_reduce[WARPS]; ///< Buffer for warp-synchronous scan + T warp_aggregates[WARPS]; ///< Shared totals from each warp-synchronous scan + T block_prefix; ///< Shared prefix for the entire thread block + }; + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + // Thread fields + _TempStorage &temp_storage; + int linear_tid; + int warp_id; + int lane_id; + + + /// Constructor + __device__ __forceinline__ BlockReduceWarpReductions( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)), + warp_id((WARPS == 1) ? 0 : linear_tid / WARP_THREADS), + lane_id(LaneId()) + {} + + + template + __device__ __forceinline__ T ApplyWarpAggregates( + ReductionOp reduction_op, ///< [in] Binary scan operator + T warp_aggregate, ///< [in] [lane0 only] Warp-wide aggregate reduction of input items + int num_valid, ///< [in] Number of valid elements (may be less than BLOCK_THREADS) + Int2Type /*successor_warp*/) + { + if (FULL_TILE || (SUCCESSOR_WARP * LOGICAL_WARP_SIZE < num_valid)) + { + T addend = temp_storage.warp_aggregates[SUCCESSOR_WARP]; + warp_aggregate = reduction_op(warp_aggregate, addend); + } + return ApplyWarpAggregates(reduction_op, warp_aggregate, num_valid, Int2Type()); + } + + template + __device__ __forceinline__ T ApplyWarpAggregates( + ReductionOp /*reduction_op*/, ///< [in] Binary scan operator + T warp_aggregate, ///< [in] [lane0 only] Warp-wide aggregate reduction of input items + int /*num_valid*/, ///< [in] Number of valid elements (may be less than BLOCK_THREADS) + Int2Type /*successor_warp*/) + { + return warp_aggregate; + } + + + /// Returns block-wide aggregate in thread0. 
+ template < + bool FULL_TILE, + typename ReductionOp> + __device__ __forceinline__ T ApplyWarpAggregates( + ReductionOp reduction_op, ///< [in] Binary scan operator + T warp_aggregate, ///< [in] [lane0 only] Warp-wide aggregate reduction of input items + int num_valid) ///< [in] Number of valid elements (may be less than BLOCK_THREADS) + { + // Share lane aggregates + if (lane_id == 0) + { + temp_storage.warp_aggregates[warp_id] = warp_aggregate; + } + + CTA_SYNC(); + + // Update total aggregate in warp 0, lane 0 + if (linear_tid == 0) + { + warp_aggregate = ApplyWarpAggregates(reduction_op, warp_aggregate, num_valid, Int2Type<1>()); + } + + return warp_aggregate; + } + + + /// Computes a thread block-wide reduction using addition (+) as the reduction operator. The first num_valid threads each contribute one reduction partial. The return value is only valid for thread0. + template + __device__ __forceinline__ T Sum( + T input, ///< [in] Calling thread's input partial reductions + int num_valid) ///< [in] Number of valid elements (may be less than BLOCK_THREADS) + { + cub::Sum reduction_op; + int warp_offset = (warp_id * LOGICAL_WARP_SIZE); + int warp_num_valid = ((FULL_TILE && EVEN_WARP_MULTIPLE) || (warp_offset + LOGICAL_WARP_SIZE <= num_valid)) ? + LOGICAL_WARP_SIZE : + num_valid - warp_offset; + + // Warp reduction in every warp + T warp_aggregate = WarpReduce(temp_storage.warp_reduce[warp_id]).template Reduce<(FULL_TILE && EVEN_WARP_MULTIPLE)>( + input, + warp_num_valid, + cub::Sum()); + + // Update outputs and block_aggregate with warp-wide aggregates from lane-0s + return ApplyWarpAggregates(reduction_op, warp_aggregate, num_valid); + } + + + /// Computes a thread block-wide reduction using the specified reduction operator. The first num_valid threads each contribute one reduction partial. The return value is only valid for thread0. + template < + bool FULL_TILE, + typename ReductionOp> + __device__ __forceinline__ T Reduce( + T input, ///< [in] Calling thread's input partial reductions + int num_valid, ///< [in] Number of valid elements (may be less than BLOCK_THREADS) + ReductionOp reduction_op) ///< [in] Binary reduction operator + { + int warp_offset = warp_id * LOGICAL_WARP_SIZE; + int warp_num_valid = ((FULL_TILE && EVEN_WARP_MULTIPLE) || (warp_offset + LOGICAL_WARP_SIZE <= num_valid)) ? + LOGICAL_WARP_SIZE : + num_valid - warp_offset; + + // Warp reduction in every warp + T warp_aggregate = WarpReduce(temp_storage.warp_reduce[warp_id]).template Reduce<(FULL_TILE && EVEN_WARP_MULTIPLE)>( + input, + warp_num_valid, + reduction_op); + + // Update outputs and block_aggregate with warp-wide aggregates from lane-0s + return ApplyWarpAggregates(reduction_op, warp_aggregate, num_valid); + } + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/specializations/block_scan_raking.cuh b/fastertransformer/cuda/cub/block/specializations/block_scan_raking.cuh new file mode 100644 index 000000000..a855cda0b --- /dev/null +++ b/fastertransformer/cuda/cub/block/specializations/block_scan_raking.cuh @@ -0,0 +1,666 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 
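A minimal sketch of the warp-reductions variant through cub::BlockReduce; the block shape and kernel name are assumptions.

#include <cub/block/block_reduce.cuh>

__global__ void WarpReductionsSumKernel(const int *d_in, int *d_block_sums)
{
    // BLOCK_REDUCE_WARP_REDUCTIONS is also cub::BlockReduce's default algorithm
    typedef cub::BlockReduce<int, 128, cub::BLOCK_REDUCE_WARP_REDUCTIONS> BlockReduceT;
    __shared__ typename BlockReduceT::TempStorage temp_storage;

    int item = d_in[blockIdx.x * 128 + threadIdx.x];
    int block_sum = BlockReduceT(temp_storage).Sum(item);

    if (threadIdx.x == 0)
        d_block_sums[blockIdx.x] = block_sum;
}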
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + + +/** + * \file + * cub::BlockScanRaking provides variants of raking-based parallel prefix scan across a CUDA thread block. + */ + +#pragma once + +#include "../../util_ptx.cuh" +#include "../../util_arch.cuh" +#include "../../block/block_raking_layout.cuh" +#include "../../thread/thread_reduce.cuh" +#include "../../thread/thread_scan.cuh" +#include "../../warp/warp_scan.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief BlockScanRaking provides variants of raking-based parallel prefix scan across a CUDA thread block. 
+ */ +template < + typename T, ///< Data type being scanned + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + int BLOCK_DIM_Y, ///< The thread block length in threads along the Y dimension + int BLOCK_DIM_Z, ///< The thread block length in threads along the Z dimension + bool MEMOIZE, ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure + int PTX_ARCH> ///< The PTX compute capability for which to to specialize this collective +struct BlockScanRaking +{ + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + }; + + /// Layout type for padded thread block raking grid + typedef BlockRakingLayout BlockRakingLayout; + + /// Constants + enum + { + /// Number of raking threads + RAKING_THREADS = BlockRakingLayout::RAKING_THREADS, + + /// Number of raking elements per warp synchronous raking thread + SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH, + + /// Cooperative work can be entirely warp synchronous + WARP_SYNCHRONOUS = (BLOCK_THREADS == RAKING_THREADS), + }; + + /// WarpScan utility type + typedef WarpScan WarpScan; + + /// Shared memory storage layout type + struct _TempStorage + { + typename WarpScan::TempStorage warp_scan; ///< Buffer for warp-synchronous scan + typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid + T block_aggregate; ///< Block aggregate + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + // Thread fields + _TempStorage &temp_storage; + unsigned int linear_tid; + T cached_segment[SEGMENT_LENGTH]; + + + //--------------------------------------------------------------------- + // Utility methods + //--------------------------------------------------------------------- + + /// Templated reduction + template + __device__ __forceinline__ T GuardedReduce( + T* raking_ptr, ///< [in] Input array + ScanOp scan_op, ///< [in] Binary reduction operator + T raking_partial, ///< [in] Prefix to seed reduction with + Int2Type /*iteration*/) + { + if ((BlockRakingLayout::UNGUARDED) || (((linear_tid * SEGMENT_LENGTH) + ITERATION) < BLOCK_THREADS)) + { + T addend = raking_ptr[ITERATION]; + raking_partial = scan_op(raking_partial, addend); + } + + return GuardedReduce(raking_ptr, scan_op, raking_partial, Int2Type()); + } + + + /// Templated reduction (base case) + template + __device__ __forceinline__ T GuardedReduce( + T* /*raking_ptr*/, ///< [in] Input array + ScanOp /*scan_op*/, ///< [in] Binary reduction operator + T raking_partial, ///< [in] Prefix to seed reduction with + Int2Type /*iteration*/) + { + return raking_partial; + } + + + /// Templated copy + template + __device__ __forceinline__ void CopySegment( + T* out, ///< [out] Out array + T* in, ///< [in] Input array + Int2Type /*iteration*/) + { + out[ITERATION] = in[ITERATION]; + CopySegment(out, in, Int2Type()); + } + + + /// Templated copy (base case) + __device__ __forceinline__ void CopySegment( + T* /*out*/, ///< [out] Out array + T* /*in*/, ///< [in] Input array + Int2Type /*iteration*/) + {} 
+ + + /// Performs upsweep raking reduction, returning the aggregate + template + __device__ __forceinline__ T Upsweep( + ScanOp scan_op) + { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + + // Read data into registers + CopySegment(cached_segment, smem_raking_ptr, Int2Type<0>()); + + T raking_partial = cached_segment[0]; + + return GuardedReduce(cached_segment, scan_op, raking_partial, Int2Type<1>()); + } + + + /// Performs exclusive downsweep raking scan + template + __device__ __forceinline__ void ExclusiveDownsweep( + ScanOp scan_op, + T raking_partial, + bool apply_prefix = true) + { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + + // Read data back into registers + if (!MEMOIZE) + { + CopySegment(cached_segment, smem_raking_ptr, Int2Type<0>()); + } + + internal::ThreadScanExclusive(cached_segment, cached_segment, scan_op, raking_partial, apply_prefix); + + // Write data back to smem + CopySegment(smem_raking_ptr, cached_segment, Int2Type<0>()); + } + + + /// Performs inclusive downsweep raking scan + template + __device__ __forceinline__ void InclusiveDownsweep( + ScanOp scan_op, + T raking_partial, + bool apply_prefix = true) + { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + + // Read data back into registers + if (!MEMOIZE) + { + CopySegment(cached_segment, smem_raking_ptr, Int2Type<0>()); + } + + internal::ThreadScanInclusive(cached_segment, cached_segment, scan_op, raking_partial, apply_prefix); + + // Write data back to smem + CopySegment(smem_raking_ptr, cached_segment, Int2Type<0>()); + } + + + //--------------------------------------------------------------------- + // Constructors + //--------------------------------------------------------------------- + + /// Constructor + __device__ __forceinline__ BlockScanRaking( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + //--------------------------------------------------------------------- + // Exclusive scans + //--------------------------------------------------------------------- + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. With no initial value, the output computed for thread0 is undefined. 
+ template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator + { + if (WARP_SYNCHRONOUS) + { + // Short-circuit directly to warp-synchronous scan + WarpScan(temp_storage.warp_scan).ExclusiveScan(input, exclusive_output, scan_op); + } + else + { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + *placement_ptr = input; + + CTA_SYNC(); + + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) + { + // Raking upsweep reduction across shared partials + T upsweep_partial = Upsweep(scan_op); + + // Warp-synchronous scan + T exclusive_partial; + WarpScan(temp_storage.warp_scan).ExclusiveScan(upsweep_partial, exclusive_partial, scan_op); + + // Exclusive raking downsweep scan + ExclusiveDownsweep(scan_op, exclusive_partial, (linear_tid != 0)); + } + + CTA_SYNC(); + + // Grab thread prefix from shared memory + exclusive_output = *placement_ptr; + } + } + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input items + T &output, ///< [out] Calling thread's output items (may be aliased to \p input) + const T &initial_value, ///< [in] Initial value to seed the exclusive scan + ScanOp scan_op) ///< [in] Binary scan operator + { + if (WARP_SYNCHRONOUS) + { + // Short-circuit directly to warp-synchronous scan + WarpScan(temp_storage.warp_scan).ExclusiveScan(input, output, initial_value, scan_op); + } + else + { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + *placement_ptr = input; + + CTA_SYNC(); + + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) + { + // Raking upsweep reduction across shared partials + T upsweep_partial = Upsweep(scan_op); + + // Exclusive Warp-synchronous scan + T exclusive_partial; + WarpScan(temp_storage.warp_scan).ExclusiveScan(upsweep_partial, exclusive_partial, initial_value, scan_op); + + // Exclusive raking downsweep scan + ExclusiveDownsweep(scan_op, exclusive_partial); + } + + CTA_SYNC(); + + // Grab exclusive partial from shared memory + output = *placement_ptr; + } + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. With no initial value, the output computed for thread0 is undefined. 
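A short sketch of the raking scan through the public cub::BlockScan wrapper, showing the exclusive-sum overload that also returns the block-wide aggregate described above. The tile shape, kernel name, and output layout are assumptions; BLOCK_SCAN_RAKING_MEMOIZE would trade registers for fewer shared-memory reads, matching the MEMOIZE parameter above.

#include <cub/block/block_scan.cuh>

__global__ void BlockScanKernel(const int *d_in, int *d_out, int *d_block_aggregates)
{
    typedef cub::BlockScan<int, 128, cub::BLOCK_SCAN_RAKING> BlockScanT;
    __shared__ typename BlockScanT::TempStorage temp_storage;

    int idx = blockIdx.x * 128 + threadIdx.x;
    int input = d_in[idx];
    int output, block_aggregate;

    // Exclusive prefix sum plus the block-wide total
    BlockScanT(temp_storage).ExclusiveSum(input, output, block_aggregate);

    d_out[idx] = output;
    if (threadIdx.x == 0)
        d_block_aggregates[blockIdx.x] = block_aggregate;
}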
+ template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + if (WARP_SYNCHRONOUS) + { + // Short-circuit directly to warp-synchronous scan + WarpScan(temp_storage.warp_scan).ExclusiveScan(input, output, scan_op, block_aggregate); + } + else + { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + *placement_ptr = input; + + CTA_SYNC(); + + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) + { + // Raking upsweep reduction across shared partials + T upsweep_partial= Upsweep(scan_op); + + // Warp-synchronous scan + T inclusive_partial; + T exclusive_partial; + WarpScan(temp_storage.warp_scan).Scan(upsweep_partial, inclusive_partial, exclusive_partial, scan_op); + + // Exclusive raking downsweep scan + ExclusiveDownsweep(scan_op, exclusive_partial, (linear_tid != 0)); + + // Broadcast aggregate to all threads + if (linear_tid == RAKING_THREADS - 1) + temp_storage.block_aggregate = inclusive_partial; + } + + CTA_SYNC(); + + // Grab thread prefix from shared memory + output = *placement_ptr; + + // Retrieve block aggregate + block_aggregate = temp_storage.block_aggregate; + } + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input items + T &output, ///< [out] Calling thread's output items (may be aliased to \p input) + const T &initial_value, ///< [in] Initial value to seed the exclusive scan + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + if (WARP_SYNCHRONOUS) + { + // Short-circuit directly to warp-synchronous scan + WarpScan(temp_storage.warp_scan).ExclusiveScan(input, output, initial_value, scan_op, block_aggregate); + } + else + { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + *placement_ptr = input; + + CTA_SYNC(); + + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) + { + // Raking upsweep reduction across shared partials + T upsweep_partial = Upsweep(scan_op); + + // Warp-synchronous scan + T exclusive_partial; + WarpScan(temp_storage.warp_scan).ExclusiveScan(upsweep_partial, exclusive_partial, initial_value, scan_op, block_aggregate); + + // Exclusive raking downsweep scan + ExclusiveDownsweep(scan_op, exclusive_partial); + + // Broadcast aggregate to other threads + if (linear_tid == 0) + temp_storage.block_aggregate = block_aggregate; + } + + CTA_SYNC(); + + // Grab exclusive partial from shared memory + output = *placement_ptr; + + // Retrieve block aggregate + block_aggregate = temp_storage.block_aggregate; + } + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. 
the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template < + typename ScanOp, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide prefix to be applied to all inputs. + { + if (WARP_SYNCHRONOUS) + { + // Short-circuit directly to warp-synchronous scan + T block_aggregate; + WarpScan warp_scan(temp_storage.warp_scan); + warp_scan.ExclusiveScan(input, output, scan_op, block_aggregate); + + // Obtain warp-wide prefix in lane0, then broadcast to other lanes + T block_prefix = block_prefix_callback_op(block_aggregate); + block_prefix = warp_scan.Broadcast(block_prefix, 0); + + output = scan_op(block_prefix, output); + if (linear_tid == 0) + output = block_prefix; + } + else + { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + *placement_ptr = input; + + CTA_SYNC(); + + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) + { + WarpScan warp_scan(temp_storage.warp_scan); + + // Raking upsweep reduction across shared partials + T upsweep_partial = Upsweep(scan_op); + + // Warp-synchronous scan + T exclusive_partial, block_aggregate; + warp_scan.ExclusiveScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate); + + // Obtain block-wide prefix in lane0, then broadcast to other lanes + T block_prefix = block_prefix_callback_op(block_aggregate); + block_prefix = warp_scan.Broadcast(block_prefix, 0); + + // Update prefix with warpscan exclusive partial + T downsweep_prefix = scan_op(block_prefix, exclusive_partial); + if (linear_tid == 0) + downsweep_prefix = block_prefix; + + // Exclusive raking downsweep scan + ExclusiveDownsweep(scan_op, downsweep_prefix); + } + + CTA_SYNC(); + + // Grab thread prefix from shared memory + output = *placement_ptr; + } + } + + + //--------------------------------------------------------------------- + // Inclusive scans + //--------------------------------------------------------------------- + + /// Computes an inclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. 
+ template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator + { + if (WARP_SYNCHRONOUS) + { + // Short-circuit directly to warp-synchronous scan + WarpScan(temp_storage.warp_scan).InclusiveScan(input, output, scan_op); + } + else + { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + *placement_ptr = input; + + CTA_SYNC(); + + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) + { + // Raking upsweep reduction across shared partials + T upsweep_partial = Upsweep(scan_op); + + // Exclusive Warp-synchronous scan + T exclusive_partial; + WarpScan(temp_storage.warp_scan).ExclusiveScan(upsweep_partial, exclusive_partial, scan_op); + + // Inclusive raking downsweep scan + InclusiveDownsweep(scan_op, exclusive_partial, (linear_tid != 0)); + } + + CTA_SYNC(); + + // Grab thread prefix from shared memory + output = *placement_ptr; + } + } + + + /// Computes an inclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + if (WARP_SYNCHRONOUS) + { + // Short-circuit directly to warp-synchronous scan + WarpScan(temp_storage.warp_scan).InclusiveScan(input, output, scan_op, block_aggregate); + } + else + { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + *placement_ptr = input; + + CTA_SYNC(); + + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) + { + // Raking upsweep reduction across shared partials + T upsweep_partial = Upsweep(scan_op); + + // Warp-synchronous scan + T inclusive_partial; + T exclusive_partial; + WarpScan(temp_storage.warp_scan).Scan(upsweep_partial, inclusive_partial, exclusive_partial, scan_op); + + // Inclusive raking downsweep scan + InclusiveDownsweep(scan_op, exclusive_partial, (linear_tid != 0)); + + // Broadcast aggregate to all threads + if (linear_tid == RAKING_THREADS - 1) + temp_storage.block_aggregate = inclusive_partial; + } + + CTA_SYNC(); + + // Grab thread prefix from shared memory + output = *placement_ptr; + + // Retrieve block aggregate + block_aggregate = temp_storage.block_aggregate; + } + } + + + /// Computes an inclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. 
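The callback-based ExclusiveScan above and the InclusiveScan that follows both take a \p block_prefix_callback_op functor supplied by the caller. As an illustrative sketch only (this functor is not part of the diff; it mirrors the running-prefix pattern documented for cub::BlockScan), such a functor typically carries a running total across tiles and returns the old total as the seed for the current tile:

// Hypothetical stateful callback functor maintaining a running prefix
struct BlockPrefixCallbackOp
{
    int running_total;   // prefix accumulated over tiles processed so far

    __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {}

    // Entered by the first warp; the value returned by lane0 seeds the block-wide scan
    __device__ int operator()(int block_aggregate)
    {
        int old_prefix = running_total;
        running_total += block_aggregate;
        return old_prefix;
    }
};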
+ template < + typename ScanOp, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide prefix to be applied to all inputs. + { + if (WARP_SYNCHRONOUS) + { + // Short-circuit directly to warp-synchronous scan + T block_aggregate; + WarpScan warp_scan(temp_storage.warp_scan); + warp_scan.InclusiveScan(input, output, scan_op, block_aggregate); + + // Obtain warp-wide prefix in lane0, then broadcast to other lanes + T block_prefix = block_prefix_callback_op(block_aggregate); + block_prefix = warp_scan.Broadcast(block_prefix, 0); + + // Update prefix with exclusive warpscan partial + output = scan_op(block_prefix, output); + } + else + { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + *placement_ptr = input; + + CTA_SYNC(); + + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) + { + WarpScan warp_scan(temp_storage.warp_scan); + + // Raking upsweep reduction across shared partials + T upsweep_partial = Upsweep(scan_op); + + // Warp-synchronous scan + T exclusive_partial, block_aggregate; + warp_scan.ExclusiveScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate); + + // Obtain block-wide prefix in lane0, then broadcast to other lanes + T block_prefix = block_prefix_callback_op(block_aggregate); + block_prefix = warp_scan.Broadcast(block_prefix, 0); + + // Update prefix with warpscan exclusive partial + T downsweep_prefix = scan_op(block_prefix, exclusive_partial); + if (linear_tid == 0) + downsweep_prefix = block_prefix; + + // Inclusive raking downsweep scan + InclusiveDownsweep(scan_op, downsweep_prefix); + } + + CTA_SYNC(); + + // Grab thread prefix from shared memory + output = *placement_ptr; + } + } + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/specializations/block_scan_warp_scans.cuh b/fastertransformer/cuda/cub/block/specializations/block_scan_warp_scans.cuh new file mode 100644 index 000000000..85e4d6135 --- /dev/null +++ b/fastertransformer/cuda/cub/block/specializations/block_scan_warp_scans.cuh @@ -0,0 +1,392 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::BlockScanWarpscans provides warpscan-based variants of parallel prefix scan across a CUDA thread block. + */ + +#pragma once + +#include "../../util_arch.cuh" +#include "../../util_ptx.cuh" +#include "../../warp/warp_scan.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief BlockScanWarpScans provides warpscan-based variants of parallel prefix scan across a CUDA thread block. + */ +template < + typename T, + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + int BLOCK_DIM_Y, ///< The thread block length in threads along the Y dimension + int BLOCK_DIM_Z, ///< The thread block length in threads along the Z dimension + int PTX_ARCH> ///< The PTX compute capability for which to to specialize this collective +struct BlockScanWarpScans +{ + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// Constants + enum + { + /// Number of warp threads + WARP_THREADS = CUB_WARP_THREADS(PTX_ARCH), + + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + + /// Number of active warps + WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS, + }; + + /// WarpScan utility type + typedef WarpScan WarpScanT; + + /// WarpScan utility type + typedef WarpScan WarpAggregateScan; + + /// Shared memory storage layout type + + struct __align__(32) _TempStorage + { + T warp_aggregates[WARPS]; + typename WarpScanT::TempStorage warp_scan[WARPS]; ///< Buffer for warp-synchronous scans + T block_prefix; ///< Shared prefix for the entire thread block + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + // Thread fields + _TempStorage &temp_storage; + unsigned int linear_tid; + unsigned int warp_id; + unsigned int lane_id; + + + //--------------------------------------------------------------------- + // Constructors + //--------------------------------------------------------------------- + + /// Constructor + __device__ __forceinline__ BlockScanWarpScans( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)), + warp_id((WARPS == 1) ? 
0 : linear_tid / WARP_THREADS), + lane_id(LaneId()) + {} + + + //--------------------------------------------------------------------- + // Utility methods + //--------------------------------------------------------------------- + + template + __device__ __forceinline__ void ApplyWarpAggregates( + T &warp_prefix, ///< [out] The calling thread's partial reduction + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate, ///< [out] Threadblock-wide aggregate reduction of input items + Int2Type /*addend_warp*/) + { + if (warp_id == WARP) + warp_prefix = block_aggregate; + + T addend = temp_storage.warp_aggregates[WARP]; + block_aggregate = scan_op(block_aggregate, addend); + + ApplyWarpAggregates(warp_prefix, scan_op, block_aggregate, Int2Type()); + } + + template + __device__ __forceinline__ void ApplyWarpAggregates( + T &/*warp_prefix*/, ///< [out] The calling thread's partial reduction + ScanOp /*scan_op*/, ///< [in] Binary scan operator + T &/*block_aggregate*/, ///< [out] Threadblock-wide aggregate reduction of input items + Int2Type /*addend_warp*/) + {} + + + /// Use the warp-wide aggregates to compute the calling warp's prefix. Also returns block-wide aggregate in all threads. + template + __device__ __forceinline__ T ComputeWarpPrefix( + ScanOp scan_op, ///< [in] Binary scan operator + T warp_aggregate, ///< [in] [laneWARP_THREADS - 1 only] Warp-wide aggregate reduction of input items + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + // Last lane in each warp shares its warp-aggregate + if (lane_id == WARP_THREADS - 1) + temp_storage.warp_aggregates[warp_id] = warp_aggregate; + + CTA_SYNC(); + + // Accumulate block aggregates and save the one that is our warp's prefix + T warp_prefix; + block_aggregate = temp_storage.warp_aggregates[0]; + + // Use template unrolling (since the PTX backend can't handle unrolling it for SM1x) + ApplyWarpAggregates(warp_prefix, scan_op, block_aggregate, Int2Type<1>()); +/* + #pragma unroll + for (int WARP = 1; WARP < WARPS; ++WARP) + { + if (warp_id == WARP) + warp_prefix = block_aggregate; + + T addend = temp_storage.warp_aggregates[WARP]; + block_aggregate = scan_op(block_aggregate, addend); + } +*/ + + return warp_prefix; + } + + + /// Use the warp-wide aggregates and initial-value to compute the calling warp's prefix. Also returns block-wide aggregate in all threads. + template + __device__ __forceinline__ T ComputeWarpPrefix( + ScanOp scan_op, ///< [in] Binary scan operator + T warp_aggregate, ///< [in] [laneWARP_THREADS - 1 only] Warp-wide aggregate reduction of input items + T &block_aggregate, ///< [out] Threadblock-wide aggregate reduction of input items + const T &initial_value) ///< [in] Initial value to seed the exclusive scan + { + T warp_prefix = ComputeWarpPrefix(scan_op, warp_aggregate, block_aggregate); + + warp_prefix = scan_op(initial_value, warp_prefix); + + if (warp_id == 0) + warp_prefix = initial_value; + + return warp_prefix; + } + + //--------------------------------------------------------------------- + // Exclusive scans + //--------------------------------------------------------------------- + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. With no initial value, the output computed for thread0 is undefined. 
+ template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator + { + // Compute block-wide exclusive scan. The exclusive output from tid0 is invalid. + T block_aggregate; + ExclusiveScan(input, exclusive_output, scan_op, block_aggregate); + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input items + T &exclusive_output, ///< [out] Calling thread's output items (may be aliased to \p input) + const T &initial_value, ///< [in] Initial value to seed the exclusive scan + ScanOp scan_op) ///< [in] Binary scan operator + { + T block_aggregate; + ExclusiveScan(input, exclusive_output, initial_value, scan_op, block_aggregate); + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. With no initial value, the output computed for thread0 is undefined. + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + // Compute warp scan in each warp. The exclusive output from each lane0 is invalid. + T inclusive_output; + WarpScanT(temp_storage.warp_scan[warp_id]).Scan(input, inclusive_output, exclusive_output, scan_op); + + // Compute the warp-wide prefix and block-wide aggregate for each warp. Warp prefix for warp0 is invalid. + T warp_prefix = ComputeWarpPrefix(scan_op, inclusive_output, block_aggregate); + + // Apply warp prefix to our lane's partial + if (warp_id != 0) + { + exclusive_output = scan_op(warp_prefix, exclusive_output); + if (lane_id == 0) + exclusive_output = warp_prefix; + } + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input items + T &exclusive_output, ///< [out] Calling thread's output items (may be aliased to \p input) + const T &initial_value, ///< [in] Initial value to seed the exclusive scan + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + // Compute warp scan in each warp. The exclusive output from each lane0 is invalid. 
+ T inclusive_output; + WarpScanT(temp_storage.warp_scan[warp_id]).Scan(input, inclusive_output, exclusive_output, scan_op); + + // Compute the warp-wide prefix and block-wide aggregate for each warp + T warp_prefix = ComputeWarpPrefix(scan_op, inclusive_output, block_aggregate, initial_value); + + // Apply warp prefix to our lane's partial + exclusive_output = scan_op(warp_prefix, exclusive_output); + if (lane_id == 0) + exclusive_output = warp_prefix; + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template < + typename ScanOp, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide prefix to be applied to all inputs. + { + // Compute block-wide exclusive scan. The exclusive output from tid0 is invalid. + T block_aggregate; + ExclusiveScan(input, exclusive_output, scan_op, block_aggregate); + + // Use the first warp to determine the thread block prefix, returning the result in lane0 + if (warp_id == 0) + { + T block_prefix = block_prefix_callback_op(block_aggregate); + if (lane_id == 0) + { + // Share the prefix with all threads + temp_storage.block_prefix = block_prefix; + exclusive_output = block_prefix; // The block prefix is the exclusive output for tid0 + } + } + + CTA_SYNC(); + + // Incorporate thread block prefix into outputs + T block_prefix = temp_storage.block_prefix; + if (linear_tid > 0) + { + exclusive_output = scan_op(block_prefix, exclusive_output); + } + } + + + //--------------------------------------------------------------------- + // Inclusive scans + //--------------------------------------------------------------------- + + /// Computes an inclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &inclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator + { + T block_aggregate; + InclusiveScan(input, inclusive_output, scan_op, block_aggregate); + } + + + /// Computes an inclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. 
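These BlockScanWarpScans methods are internal specializations selected by the public cub::BlockScan collective rather than called directly. A minimal kernel sketch of that public interface (hypothetical names ExampleKernel and d_data; a 128-thread, one-dimensional block is assumed) might look like:

#include <cub/cub.cuh>

__global__ void ExampleKernel(int *d_data)
{
    // Specialize BlockScan for 128 threads, requesting the warp-scans specialization
    typedef cub::BlockScan<int, 128, cub::BLOCK_SCAN_WARP_SCANS> BlockScan;

    // Shared memory backing the collective
    __shared__ typename BlockScan::TempStorage temp_storage;

    int thread_data = d_data[blockIdx.x * 128 + threadIdx.x];
    int block_aggregate;

    // Exclusive prefix sum across the thread block, also returning the block-wide total
    BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data, block_aggregate);

    d_data[blockIdx.x * 128 + threadIdx.x] = thread_data;
}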
+ template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &inclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + WarpScanT(temp_storage.warp_scan[warp_id]).InclusiveScan(input, inclusive_output, scan_op); + + // Compute the warp-wide prefix and block-wide aggregate for each warp. Warp prefix for warp0 is invalid. + T warp_prefix = ComputeWarpPrefix(scan_op, inclusive_output, block_aggregate); + + // Apply warp prefix to our lane's partial + if (warp_id != 0) + { + inclusive_output = scan_op(warp_prefix, inclusive_output); + } + } + + + /// Computes an inclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template < + typename ScanOp, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide prefix to be applied to all inputs. + { + T block_aggregate; + InclusiveScan(input, exclusive_output, scan_op, block_aggregate); + + // Use the first warp to determine the thread block prefix, returning the result in lane0 + if (warp_id == 0) + { + T block_prefix = block_prefix_callback_op(block_aggregate); + if (lane_id == 0) + { + // Share the prefix with all threads + temp_storage.block_prefix = block_prefix; + } + } + + CTA_SYNC(); + + // Incorporate thread block prefix into outputs + T block_prefix = temp_storage.block_prefix; + exclusive_output = scan_op(block_prefix, exclusive_output); + } + + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/specializations/block_scan_warp_scans2.cuh b/fastertransformer/cuda/cub/block/specializations/block_scan_warp_scans2.cuh new file mode 100644 index 000000000..4de7c69b7 --- /dev/null +++ b/fastertransformer/cuda/cub/block/specializations/block_scan_warp_scans2.cuh @@ -0,0 +1,436 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::BlockScanWarpscans provides warpscan-based variants of parallel prefix scan across a CUDA thread block. + */ + +#pragma once + +#include "../../util_arch.cuh" +#include "../../util_ptx.cuh" +#include "../../warp/warp_scan.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief BlockScanWarpScans provides warpscan-based variants of parallel prefix scan across a CUDA thread block. + */ +template < + typename T, + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + int BLOCK_DIM_Y, ///< The thread block length in threads along the Y dimension + int BLOCK_DIM_Z, ///< The thread block length in threads along the Z dimension + int PTX_ARCH> ///< The PTX compute capability for which to to specialize this collective +struct BlockScanWarpScans +{ + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// Constants + enum + { + /// Number of warp threads + WARP_THREADS = CUB_WARP_THREADS(PTX_ARCH), + + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + + /// Number of active warps + WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS, + }; + + /// WarpScan utility type + typedef WarpScan WarpScanT; + + /// WarpScan utility type + typedef WarpScan WarpAggregateScanT; + + /// Shared memory storage layout type + struct _TempStorage + { + typename WarpAggregateScanT::TempStorage inner_scan[WARPS]; ///< Buffer for warp-synchronous scans + typename WarpScanT::TempStorage warp_scan[WARPS]; ///< Buffer for warp-synchronous scans + T warp_aggregates[WARPS]; + T block_prefix; ///< Shared prefix for the entire thread block + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + // Thread fields + _TempStorage &temp_storage; + unsigned int linear_tid; + unsigned int warp_id; + unsigned int lane_id; + + + //--------------------------------------------------------------------- + // Constructors + //--------------------------------------------------------------------- + + /// Constructor + __device__ 
__forceinline__ BlockScanWarpScans( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)), + warp_id((WARPS == 1) ? 0 : linear_tid / WARP_THREADS), + lane_id(LaneId()) + {} + + + //--------------------------------------------------------------------- + // Utility methods + //--------------------------------------------------------------------- + + template + __device__ __forceinline__ void ApplyWarpAggregates( + T &warp_prefix, ///< [out] The calling thread's partial reduction + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate, ///< [out] Threadblock-wide aggregate reduction of input items + Int2Type addend_warp) + { + if (warp_id == WARP) + warp_prefix = block_aggregate; + + T addend = temp_storage.warp_aggregates[WARP]; + block_aggregate = scan_op(block_aggregate, addend); + + ApplyWarpAggregates(warp_prefix, scan_op, block_aggregate, Int2Type()); + } + + template + __device__ __forceinline__ void ApplyWarpAggregates( + T &warp_prefix, ///< [out] The calling thread's partial reduction + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate, ///< [out] Threadblock-wide aggregate reduction of input items + Int2Type addend_warp) + {} + + + /// Use the warp-wide aggregates to compute the calling warp's prefix. Also returns block-wide aggregate in all threads. + template + __device__ __forceinline__ T ComputeWarpPrefix( + ScanOp scan_op, ///< [in] Binary scan operator + T warp_aggregate, ///< [in] [laneWARP_THREADS - 1 only] Warp-wide aggregate reduction of input items + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + // Last lane in each warp shares its warp-aggregate + if (lane_id == WARP_THREADS - 1) + temp_storage.warp_aggregates[warp_id] = warp_aggregate; + + CTA_SYNC(); + + // Accumulate block aggregates and save the one that is our warp's prefix + T warp_prefix; + block_aggregate = temp_storage.warp_aggregates[0]; + + // Use template unrolling (since the PTX backend can't handle unrolling it for SM1x) + ApplyWarpAggregates(warp_prefix, scan_op, block_aggregate, Int2Type<1>()); +/* + #pragma unroll + for (int WARP = 1; WARP < WARPS; ++WARP) + { + if (warp_id == WARP) + warp_prefix = block_aggregate; + + T addend = temp_storage.warp_aggregates[WARP]; + block_aggregate = scan_op(block_aggregate, addend); + } +*/ + + return warp_prefix; + } + + + /// Use the warp-wide aggregates and initial-value to compute the calling warp's prefix. Also returns block-wide aggregate in all threads. + template + __device__ __forceinline__ T ComputeWarpPrefix( + ScanOp scan_op, ///< [in] Binary scan operator + T warp_aggregate, ///< [in] [laneWARP_THREADS - 1 only] Warp-wide aggregate reduction of input items + T &block_aggregate, ///< [out] Threadblock-wide aggregate reduction of input items + const T &initial_value) ///< [in] Initial value to seed the exclusive scan + { + T warp_prefix = ComputeWarpPrefix(scan_op, warp_aggregate, block_aggregate); + + warp_prefix = scan_op(initial_value, warp_prefix); + + if (warp_id == 0) + warp_prefix = initial_value; + + return warp_prefix; + } + + //--------------------------------------------------------------------- + // Exclusive scans + //--------------------------------------------------------------------- + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. 
With no initial value, the output computed for thread0 is undefined. + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator + { + // Compute block-wide exclusive scan. The exclusive output from tid0 is invalid. + T block_aggregate; + ExclusiveScan(input, exclusive_output, scan_op, block_aggregate); + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input items + T &exclusive_output, ///< [out] Calling thread's output items (may be aliased to \p input) + const T &initial_value, ///< [in] Initial value to seed the exclusive scan + ScanOp scan_op) ///< [in] Binary scan operator + { + T block_aggregate; + ExclusiveScan(input, exclusive_output, initial_value, scan_op, block_aggregate); + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. With no initial value, the output computed for thread0 is undefined. + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + WarpScanT my_warp_scan(temp_storage.warp_scan[warp_id]); + + // Compute warp scan in each warp. The exclusive output from each lane0 is invalid. + T inclusive_output; + my_warp_scan.Scan(input, inclusive_output, exclusive_output, scan_op); + + // Compute the warp-wide prefix and block-wide aggregate for each warp. Warp prefix for warp0 is invalid. +// T warp_prefix = ComputeWarpPrefix(scan_op, inclusive_output, block_aggregate); + +//-------------------------------------------------- + // Last lane in each warp shares its warp-aggregate + if (lane_id == WARP_THREADS - 1) + temp_storage.warp_aggregates[warp_id] = inclusive_output; + + CTA_SYNC(); + + // Get the warp scan partial + T warp_inclusive, warp_prefix; + if (lane_id < WARPS) + { + // Scan the warpscan partials + T warp_val = temp_storage.warp_aggregates[lane_id]; + WarpAggregateScanT(temp_storage.inner_scan[warp_id]).Scan(warp_val, warp_inclusive, warp_prefix, scan_op); + } + + warp_prefix = my_warp_scan.Broadcast(warp_prefix, warp_id); + block_aggregate = my_warp_scan.Broadcast(warp_inclusive, WARPS - 1); +//-------------------------------------------------- + + // Apply warp prefix to our lane's partial + if (warp_id != 0) + { + exclusive_output = scan_op(warp_prefix, exclusive_output); + if (lane_id == 0) + exclusive_output = warp_prefix; + } + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. 
+ template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input items + T &exclusive_output, ///< [out] Calling thread's output items (may be aliased to \p input) + const T &initial_value, ///< [in] Initial value to seed the exclusive scan + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + WarpScanT my_warp_scan(temp_storage.warp_scan[warp_id]); + + // Compute warp scan in each warp. The exclusive output from each lane0 is invalid. + T inclusive_output; + my_warp_scan.Scan(input, inclusive_output, exclusive_output, scan_op); + + // Compute the warp-wide prefix and block-wide aggregate for each warp +// T warp_prefix = ComputeWarpPrefix(scan_op, inclusive_output, block_aggregate, initial_value); + +//-------------------------------------------------- + // Last lane in each warp shares its warp-aggregate + if (lane_id == WARP_THREADS - 1) + temp_storage.warp_aggregates[warp_id] = inclusive_output; + + CTA_SYNC(); + + // Get the warp scan partial + T warp_inclusive, warp_prefix; + if (lane_id < WARPS) + { + // Scan the warpscan partials + T warp_val = temp_storage.warp_aggregates[lane_id]; + WarpAggregateScanT(temp_storage.inner_scan[warp_id]).Scan(warp_val, warp_inclusive, warp_prefix, initial_value, scan_op); + } + + warp_prefix = my_warp_scan.Broadcast(warp_prefix, warp_id); + block_aggregate = my_warp_scan.Broadcast(warp_inclusive, WARPS - 1); +//-------------------------------------------------- + + // Apply warp prefix to our lane's partial + exclusive_output = scan_op(warp_prefix, exclusive_output); + if (lane_id == 0) + exclusive_output = warp_prefix; + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template < + typename ScanOp, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide prefix to be applied to all inputs. + { + // Compute block-wide exclusive scan. The exclusive output from tid0 is invalid. 
+ T block_aggregate; + ExclusiveScan(input, exclusive_output, scan_op, block_aggregate); + + // Use the first warp to determine the thread block prefix, returning the result in lane0 + if (warp_id == 0) + { + T block_prefix = block_prefix_callback_op(block_aggregate); + if (lane_id == 0) + { + // Share the prefix with all threads + temp_storage.block_prefix = block_prefix; + exclusive_output = block_prefix; // The block prefix is the exclusive output for tid0 + } + } + + CTA_SYNC(); + + // Incorporate thread block prefix into outputs + T block_prefix = temp_storage.block_prefix; + if (linear_tid > 0) + { + exclusive_output = scan_op(block_prefix, exclusive_output); + } + } + + + //--------------------------------------------------------------------- + // Inclusive scans + //--------------------------------------------------------------------- + + /// Computes an inclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &inclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator + { + T block_aggregate; + InclusiveScan(input, inclusive_output, scan_op, block_aggregate); + } + + + /// Computes an inclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &inclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + WarpScanT(temp_storage.warp_scan[warp_id]).InclusiveScan(input, inclusive_output, scan_op); + + // Compute the warp-wide prefix and block-wide aggregate for each warp. Warp prefix for warp0 is invalid. + T warp_prefix = ComputeWarpPrefix(scan_op, inclusive_output, block_aggregate); + + // Apply warp prefix to our lane's partial + if (warp_id != 0) + { + inclusive_output = scan_op(warp_prefix, inclusive_output); + } + } + + + /// Computes an inclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template < + typename ScanOp, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide prefix to be applied to all inputs. 
+ { + T block_aggregate; + InclusiveScan(input, exclusive_output, scan_op, block_aggregate); + + // Use the first warp to determine the thread block prefix, returning the result in lane0 + if (warp_id == 0) + { + T block_prefix = block_prefix_callback_op(block_aggregate); + if (lane_id == 0) + { + // Share the prefix with all threads + temp_storage.block_prefix = block_prefix; + } + } + + CTA_SYNC(); + + // Incorporate thread block prefix into outputs + T block_prefix = temp_storage.block_prefix; + exclusive_output = scan_op(block_prefix, exclusive_output); + } + + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/block/specializations/block_scan_warp_scans3.cuh b/fastertransformer/cuda/cub/block/specializations/block_scan_warp_scans3.cuh new file mode 100644 index 000000000..147ca4c5a --- /dev/null +++ b/fastertransformer/cuda/cub/block/specializations/block_scan_warp_scans3.cuh @@ -0,0 +1,418 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::BlockScanWarpscans provides warpscan-based variants of parallel prefix scan across a CUDA thread block. + */ + +#pragma once + +#include "../../util_arch.cuh" +#include "../../util_ptx.cuh" +#include "../../warp/warp_scan.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief BlockScanWarpScans provides warpscan-based variants of parallel prefix scan across a CUDA thread block. 
+ */ +template < + typename T, + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + int BLOCK_DIM_Y, ///< The thread block length in threads along the Y dimension + int BLOCK_DIM_Z, ///< The thread block length in threads along the Z dimension + int PTX_ARCH> ///< The PTX compute capability for which to to specialize this collective +struct BlockScanWarpScans +{ + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// Constants + enum + { + /// The thread block size in threads + BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z, + + /// Number of warp threads + INNER_WARP_THREADS = CUB_WARP_THREADS(PTX_ARCH), + OUTER_WARP_THREADS = BLOCK_THREADS / INNER_WARP_THREADS, + + /// Number of outer scan warps + OUTER_WARPS = INNER_WARP_THREADS + }; + + /// Outer WarpScan utility type + typedef WarpScan OuterWarpScanT; + + /// Inner WarpScan utility type + typedef WarpScan InnerWarpScanT; + + typedef typename OuterWarpScanT::TempStorage OuterScanArray[OUTER_WARPS]; + + + /// Shared memory storage layout type + struct _TempStorage + { + union Aliasable + { + Uninitialized outer_warp_scan; ///< Buffer for warp-synchronous outer scans + typename InnerWarpScanT::TempStorage inner_warp_scan; ///< Buffer for warp-synchronous inner scan + + } aliasable; + + T warp_aggregates[OUTER_WARPS]; + + T block_aggregate; ///< Shared prefix for the entire thread block + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + // Thread fields + _TempStorage &temp_storage; + unsigned int linear_tid; + unsigned int warp_id; + unsigned int lane_id; + + + //--------------------------------------------------------------------- + // Constructors + //--------------------------------------------------------------------- + + /// Constructor + __device__ __forceinline__ BlockScanWarpScans( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)), + warp_id((OUTER_WARPS == 1) ? 0 : linear_tid / OUTER_WARP_THREADS), + lane_id((OUTER_WARPS == 1) ? linear_tid : linear_tid % OUTER_WARP_THREADS) + {} + + + //--------------------------------------------------------------------- + // Exclusive scans + //--------------------------------------------------------------------- + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. With no initial value, the output computed for thread0 is undefined. + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator + { + // Compute block-wide exclusive scan. The exclusive output from tid0 is invalid. + T block_aggregate; + ExclusiveScan(input, exclusive_output, scan_op, block_aggregate); + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. 
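As a worked example of the constants and thread mapping above (assuming a hypothetical 256-thread, one-dimensional block on a 32-thread-warp architecture):

// INNER_WARP_THREADS = 32
// OUTER_WARP_THREADS = 256 / 32 = 8   (threads per outer logical warp)
// OUTER_WARPS        = 32             (one aggregate per outer warp, scanned by a single inner warp)
// warp_id = linear_tid / 8,  lane_id = linear_tid % 8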
+ template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input items + T &exclusive_output, ///< [out] Calling thread's output items (may be aliased to \p input) + const T &initial_value, ///< [in] Initial value to seed the exclusive scan + ScanOp scan_op) ///< [in] Binary scan operator + { + T block_aggregate; + ExclusiveScan(input, exclusive_output, initial_value, scan_op, block_aggregate); + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. With no initial value, the output computed for thread0 is undefined. + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + // Compute warp scan in each warp. The exclusive output from each lane0 is invalid. + T inclusive_output; + OuterWarpScanT(temp_storage.aliasable.outer_warp_scan.Alias()[warp_id]).Scan( + input, inclusive_output, exclusive_output, scan_op); + + // Share outer warp total + if (lane_id == OUTER_WARP_THREADS - 1) + temp_storage.warp_aggregates[warp_id] = inclusive_output; + + CTA_SYNC(); + + if (linear_tid < INNER_WARP_THREADS) + { + T outer_warp_input = temp_storage.warp_aggregates[linear_tid]; + T outer_warp_exclusive; + + InnerWarpScanT(temp_storage.aliasable.inner_warp_scan).ExclusiveScan( + outer_warp_input, outer_warp_exclusive, scan_op, block_aggregate); + + temp_storage.block_aggregate = block_aggregate; + temp_storage.warp_aggregates[linear_tid] = outer_warp_exclusive; + } + + CTA_SYNC(); + + if (warp_id != 0) + { + // Retrieve block aggregate + block_aggregate = temp_storage.block_aggregate; + + // Apply warp prefix to our lane's partial + T outer_warp_exclusive = temp_storage.warp_aggregates[warp_id]; + exclusive_output = scan_op(outer_warp_exclusive, exclusive_output); + if (lane_id == 0) + exclusive_output = outer_warp_exclusive; + } + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input items + T &exclusive_output, ///< [out] Calling thread's output items (may be aliased to \p input) + const T &initial_value, ///< [in] Initial value to seed the exclusive scan + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + // Compute warp scan in each warp. The exclusive output from each lane0 is invalid. 
+ T inclusive_output; + OuterWarpScanT(temp_storage.aliasable.outer_warp_scan.Alias()[warp_id]).Scan( + input, inclusive_output, exclusive_output, scan_op); + + // Share outer warp total + if (lane_id == OUTER_WARP_THREADS - 1) + { + temp_storage.warp_aggregates[warp_id] = inclusive_output; + } + + CTA_SYNC(); + + if (linear_tid < INNER_WARP_THREADS) + { + T outer_warp_input = temp_storage.warp_aggregates[linear_tid]; + T outer_warp_exclusive; + + InnerWarpScanT(temp_storage.aliasable.inner_warp_scan).ExclusiveScan( + outer_warp_input, outer_warp_exclusive, initial_value, scan_op, block_aggregate); + + temp_storage.block_aggregate = block_aggregate; + temp_storage.warp_aggregates[linear_tid] = outer_warp_exclusive; + } + + CTA_SYNC(); + + // Retrieve block aggregate + block_aggregate = temp_storage.block_aggregate; + + // Apply warp prefix to our lane's partial + T outer_warp_exclusive = temp_storage.warp_aggregates[warp_id]; + exclusive_output = scan_op(outer_warp_exclusive, exclusive_output); + if (lane_id == 0) + exclusive_output = outer_warp_exclusive; + } + + + /// Computes an exclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. The call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + template < + typename ScanOp, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide prefix to be applied to all inputs. + { + // Compute warp scan in each warp. The exclusive output from each lane0 is invalid. 
+ T inclusive_output; + OuterWarpScanT(temp_storage.aliasable.outer_warp_scan.Alias()[warp_id]).Scan( + input, inclusive_output, exclusive_output, scan_op); + + // Share outer warp total + if (lane_id == OUTER_WARP_THREADS - 1) + temp_storage.warp_aggregates[warp_id] = inclusive_output; + + CTA_SYNC(); + + if (linear_tid < INNER_WARP_THREADS) + { + InnerWarpScanT inner_scan(temp_storage.aliasable.inner_warp_scan); + + T upsweep = temp_storage.warp_aggregates[linear_tid]; + T downsweep_prefix, block_aggregate; + + inner_scan.ExclusiveScan(upsweep, downsweep_prefix, scan_op, block_aggregate); + + // Use callback functor to get block prefix in lane0 and then broadcast to other lanes + T block_prefix = block_prefix_callback_op(block_aggregate); + block_prefix = inner_scan.Broadcast(block_prefix, 0); + + downsweep_prefix = scan_op(block_prefix, downsweep_prefix); + if (linear_tid == 0) + downsweep_prefix = block_prefix; + + temp_storage.warp_aggregates[linear_tid] = downsweep_prefix; + } + + CTA_SYNC(); + + // Apply warp prefix to our lane's partial (or assign it if partial is invalid) + T outer_warp_exclusive = temp_storage.warp_aggregates[warp_id]; + exclusive_output = scan_op(outer_warp_exclusive, exclusive_output); + if (lane_id == 0) + exclusive_output = outer_warp_exclusive; + } + + + //--------------------------------------------------------------------- + // Inclusive scans + //--------------------------------------------------------------------- + + /// Computes an inclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &inclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator + { + T block_aggregate; + InclusiveScan(input, inclusive_output, scan_op, block_aggregate); + } + + + /// Computes an inclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &inclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T &block_aggregate) ///< [out] Threadblock-wide aggregate reduction of input items + { + // Compute warp scan in each warp. The exclusive output from each lane0 is invalid. 
+ OuterWarpScanT(temp_storage.aliasable.outer_warp_scan.Alias()[warp_id]).InclusiveScan( + input, inclusive_output, scan_op); + + // Share outer warp total + if (lane_id == OUTER_WARP_THREADS - 1) + temp_storage.warp_aggregates[warp_id] = inclusive_output; + + CTA_SYNC(); + + if (linear_tid < INNER_WARP_THREADS) + { + T outer_warp_input = temp_storage.warp_aggregates[linear_tid]; + T outer_warp_exclusive; + + InnerWarpScanT(temp_storage.aliasable.inner_warp_scan).ExclusiveScan( + outer_warp_input, outer_warp_exclusive, scan_op, block_aggregate); + + temp_storage.block_aggregate = block_aggregate; + temp_storage.warp_aggregates[linear_tid] = outer_warp_exclusive; + } + + CTA_SYNC(); + + if (warp_id != 0) + { + // Retrieve block aggregate + block_aggregate = temp_storage.block_aggregate; + + // Apply warp prefix to our lane's partial + T outer_warp_exclusive = temp_storage.warp_aggregates[warp_id]; + inclusive_output = scan_op(outer_warp_exclusive, inclusive_output); + } + } + + + /// Computes an inclusive thread block-wide prefix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_prefix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + template < + typename ScanOp, + typename BlockPrefixCallbackOp> + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item + T &inclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPrefixCallbackOp &block_prefix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide prefix to be applied to all inputs. + { + // Compute warp scan in each warp. The exclusive output from each lane0 is invalid. + OuterWarpScanT(temp_storage.aliasable.outer_warp_scan.Alias()[warp_id]).InclusiveScan( + input, inclusive_output, scan_op); + + // Share outer warp total + if (lane_id == OUTER_WARP_THREADS - 1) + temp_storage.warp_aggregates[warp_id] = inclusive_output; + + CTA_SYNC(); + + if (linear_tid < INNER_WARP_THREADS) + { + InnerWarpScanT inner_scan(temp_storage.aliasable.inner_warp_scan); + + T upsweep = temp_storage.warp_aggregates[linear_tid]; + T downsweep_prefix, block_aggregate; + inner_scan.ExclusiveScan(upsweep, downsweep_prefix, scan_op, block_aggregate); + + // Use callback functor to get block prefix in lane0 and then broadcast to other lanes + T block_prefix = block_prefix_callback_op(block_aggregate); + block_prefix = inner_scan.Broadcast(block_prefix, 0); + + downsweep_prefix = scan_op(block_prefix, downsweep_prefix); + if (linear_tid == 0) + downsweep_prefix = block_prefix; + + temp_storage.warp_aggregates[linear_tid] = downsweep_prefix; + } + + CTA_SYNC(); + + // Apply warp prefix to our lane's partial + T outer_warp_exclusive = temp_storage.warp_aggregates[warp_id]; + inclusive_output = scan_op(outer_warp_exclusive, inclusive_output); + } + + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/cub.cuh b/fastertransformer/cuda/cub/cub.cuh new file mode 100644 index 000000000..3ece0f658 --- /dev/null +++ b/fastertransformer/cuda/cub/cub.cuh @@ -0,0 +1,95 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. 
+ * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * CUB umbrella include file + */ + +#pragma once + + +// Block +#include "block/block_histogram.cuh" +#include "block/block_discontinuity.cuh" +#include "block/block_exchange.cuh" +#include "block/block_load.cuh" +#include "block/block_radix_rank.cuh" +#include "block/block_radix_sort.cuh" +#include "block/block_reduce.cuh" +#include "block/block_scan.cuh" +#include "block/block_store.cuh" +//#include "block/block_shift.cuh" + +// Device +#include "device/device_histogram.cuh" +#include "device/device_partition.cuh" +#include "device/device_radix_sort.cuh" +#include "device/device_reduce.cuh" +#include "device/device_run_length_encode.cuh" +#include "device/device_scan.cuh" +#include "device/device_segmented_radix_sort.cuh" +#include "device/device_segmented_reduce.cuh" +#include "device/device_select.cuh" +#include "device/device_spmv.cuh" + +// Grid +//#include "grid/grid_barrier.cuh" +#include "grid/grid_even_share.cuh" +#include "grid/grid_mapping.cuh" +#include "grid/grid_queue.cuh" + +// Thread +#include "thread/thread_load.cuh" +#include "thread/thread_operators.cuh" +#include "thread/thread_reduce.cuh" +#include "thread/thread_scan.cuh" +#include "thread/thread_store.cuh" + +// Warp +#include "warp/warp_reduce.cuh" +#include "warp/warp_scan.cuh" + +// Iterator +#include "iterator/arg_index_input_iterator.cuh" +#include "iterator/cache_modified_input_iterator.cuh" +#include "iterator/cache_modified_output_iterator.cuh" +#include "iterator/constant_input_iterator.cuh" +#include "iterator/counting_input_iterator.cuh" +#include "iterator/tex_obj_input_iterator.cuh" +#include "iterator/tex_ref_input_iterator.cuh" +#include "iterator/transform_input_iterator.cuh" + +// Util +#include "util_arch.cuh" +#include "util_debug.cuh" +#include "util_device.cuh" +#include "util_macro.cuh" +#include "util_ptx.cuh" +#include "util_type.cuh" + diff --git 
a/fastertransformer/cuda/cub/device/device_histogram.cuh b/fastertransformer/cuda/cub/device/device_histogram.cuh new file mode 100644 index 000000000..a2556a6b8 --- /dev/null +++ b/fastertransformer/cuda/cub/device/device_histogram.cuh @@ -0,0 +1,866 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceHistogram provides device-wide parallel operations for constructing histogram(s) from a sequence of samples data residing within device-accessible memory. + */ + +#pragma once + +#include +#include +#include + +#include "dispatch/dispatch_histogram.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief DeviceHistogram provides device-wide parallel operations for constructing histogram(s) from a sequence of samples data residing within device-accessible memory. ![](histogram_logo.png) + * \ingroup SingleModule + * + * \par Overview + * A histogram + * counts the number of observations that fall into each of the disjoint categories (known as bins). + * + * \par Usage Considerations + * \cdp_class{DeviceHistogram} + * + */ +struct DeviceHistogram +{ + /******************************************************************//** + * \name Evenly-segmented bin ranges + *********************************************************************/ + //@{ + + /** + * \brief Computes an intensity histogram from a sequence of data samples using equal-width bins. 
+ * + * \par + * - The number of histogram bins is (\p num_levels - 1) + * - All bins comprise the same width of sample values: (\p upper_level - \p lower_level) / (\p num_levels - 1) + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the computation of a six-bin histogram + * from a sequence of float samples + * + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input samples and + * // output histogram + * int num_samples; // e.g., 10 + * float* d_samples; // e.g., [2.2, 6.0, 7.1, 2.9, 3.5, 0.3, 2.9, 2.0, 6.1, 999.5] + * int* d_histogram; // e.g., [ -, -, -, -, -, -, -, -] + * int num_levels; // e.g., 7 (seven level boundaries for six bins) + * float lower_level; // e.g., 0.0 (lower sample value boundary of lowest bin) + * float upper_level; // e.g., 12.0 (upper sample value boundary of upper bin) + * ... + * + * // Determine temporary device storage requirements + * void* d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceHistogram::HistogramEven(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, lower_level, upper_level, num_samples); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Compute histograms + * cub::DeviceHistogram::HistogramEven(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, lower_level, upper_level, num_samples); + * + * // d_histogram <-- [1, 0, 5, 0, 3, 0, 0, 0]; + * + * \endcode + * + * \tparam SampleIteratorT [inferred] Random-access input iterator type for reading input samples. \iterator + * \tparam CounterT [inferred] Integer type for histogram bin counters + * \tparam LevelT [inferred] Type for specifying boundaries (levels) + * \tparam OffsetT [inferred] Signed integer type for sequence offsets, list lengths, pointer differences, etc. \offset_size1 + */ + template < + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT> + CUB_RUNTIME_FUNCTION + static cudaError_t HistogramEven( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the input sequence of data samples. + CounterT* d_histogram, ///< [out] The pointer to the histogram counter output array of length num_levels - 1. + int num_levels, ///< [in] The number of boundaries (levels) for delineating histogram samples. Implies that the number of bins is num_levels - 1. + LevelT lower_level, ///< [in] The lower sample value bound (inclusive) for the lowest histogram bin. + LevelT upper_level, ///< [in] The upper sample value bound (exclusive) for the highest histogram bin. + OffsetT num_samples, ///< [in] The number of input samples (i.e., the length of \p d_samples) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. 
+ { + /// The sample value type of the input iterator + typedef typename std::iterator_traits::value_type SampleT; + + CounterT* d_histogram1[1] = {d_histogram}; + int num_levels1[1] = {num_levels}; + LevelT lower_level1[1] = {lower_level}; + LevelT upper_level1[1] = {upper_level}; + + return MultiHistogramEven<1, 1>( + d_temp_storage, + temp_storage_bytes, + d_samples, + d_histogram1, + num_levels1, + lower_level1, + upper_level1, + num_samples, + 1, + sizeof(SampleT) * num_samples, + stream, + debug_synchronous); + } + + + /** + * \brief Computes an intensity histogram from a sequence of data samples using equal-width bins. + * + * \par + * - A two-dimensional region of interest within \p d_samples can be specified + * using the \p num_row_samples, num_rows, and \p row_stride_bytes parameters. + * - The row stride must be a whole multiple of the sample data type + * size, i.e., (row_stride_bytes % sizeof(SampleT)) == 0. + * - The number of histogram bins is (\p num_levels - 1) + * - All bins comprise the same width of sample values: (\p upper_level - \p lower_level) / (\p num_levels - 1) + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the computation of a six-bin histogram + * from a 2x5 region of interest within a flattened 2x7 array of float samples. + * + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input samples and + * // output histogram + * int num_row_samples; // e.g., 5 + * int num_rows; // e.g., 2; + * size_t row_stride_bytes; // e.g., 7 * sizeof(float) + * float* d_samples; // e.g., [2.2, 6.0, 7.1, 2.9, 3.5, -, -, + * // 0.3, 2.9, 2.0, 6.1, 999.5, -, -] + * int* d_histogram; // e.g., [ -, -, -, -, -, -, -, -] + * int num_levels; // e.g., 7 (seven level boundaries for six bins) + * float lower_level; // e.g., 0.0 (lower sample value boundary of lowest bin) + * float upper_level; // e.g., 12.0 (upper sample value boundary of upper bin) + * ... + * + * // Determine temporary device storage requirements + * void* d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceHistogram::HistogramEven(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, lower_level, upper_level, + * num_row_samples, num_rows, row_stride_bytes); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Compute histograms + * cub::DeviceHistogram::HistogramEven(d_temp_storage, temp_storage_bytes, d_samples, d_histogram, + * d_samples, d_histogram, num_levels, lower_level, upper_level, + * num_row_samples, num_rows, row_stride_bytes); + * + * // d_histogram <-- [1, 0, 5, 0, 3, 0, 0, 0]; + * + * \endcode + * + * \tparam SampleIteratorT [inferred] Random-access input iterator type for reading input samples. \iterator + * \tparam CounterT [inferred] Integer type for histogram bin counters + * \tparam LevelT [inferred] Type for specifying boundaries (levels) + * \tparam OffsetT [inferred] Signed integer type for sequence offsets, list lengths, pointer differences, etc. \offset_size1 + */ + template < + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT> + CUB_RUNTIME_FUNCTION + static cudaError_t HistogramEven( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the input sequence of data samples. + CounterT* d_histogram, ///< [out] The pointer to the histogram counter output array of length num_levels - 1. + int num_levels, ///< [in] The number of boundaries (levels) for delineating histogram samples. Implies that the number of bins is num_levels - 1. + LevelT lower_level, ///< [in] The lower sample value bound (inclusive) for the lowest histogram bin. + LevelT upper_level, ///< [in] The upper sample value bound (exclusive) for the highest histogram bin. + OffsetT num_row_samples, ///< [in] The number of data samples per row in the region of interest + OffsetT num_rows, ///< [in] The number of rows in the region of interest + size_t row_stride_bytes, ///< [in] The number of bytes between starts of consecutive rows in the region of interest + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + CounterT* d_histogram1[1] = {d_histogram}; + int num_levels1[1] = {num_levels}; + LevelT lower_level1[1] = {lower_level}; + LevelT upper_level1[1] = {upper_level}; + + return MultiHistogramEven<1, 1>( + d_temp_storage, + temp_storage_bytes, + d_samples, + d_histogram1, + num_levels1, + lower_level1, + upper_level1, + num_row_samples, + num_rows, + row_stride_bytes, + stream, + debug_synchronous); + } + + /** + * \brief Computes per-channel intensity histograms from a sequence of multi-channel "pixel" data samples using equal-width bins. + * + * \par + * - The input is a sequence of pixel structures, where each pixel comprises + * a record of \p NUM_CHANNELS consecutive data samples (e.g., an RGBA pixel). + * - Of the \p NUM_CHANNELS specified, the function will only compute histograms + * for the first \p NUM_ACTIVE_CHANNELS (e.g., only RGB histograms from RGBA + * pixel samples). + * - The number of histogram bins for channeli is num_levels[i] - 1. + * - For channeli, the range of values for all histogram bins + * have the same width: (upper_level[i] - lower_level[i]) / ( num_levels[i] - 1) + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the computation of three 256-bin RGB histograms + * from a quad-channel sequence of RGBA pixels (8 bits per channel per pixel) + * + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input samples + * // and output histograms + * int num_pixels; // e.g., 5 + * unsigned char* d_samples; // e.g., [(2, 6, 7, 5), (3, 0, 2, 1), (7, 0, 6, 2), + * // (0, 6, 7, 5), (3, 0, 2, 6)] + * int* d_histogram[3]; // e.g., three device pointers to three device buffers, + * // each allocated with 256 integer counters + * int num_levels[3]; // e.g., {257, 257, 257}; + * unsigned int lower_level[3]; // e.g., {0, 0, 0}; + * unsigned int upper_level[3]; // e.g., {256, 256, 256}; + * ... 
+ * + * // Determine temporary device storage requirements + * void* d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceHistogram::MultiHistogramEven<4, 3>(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, lower_level, upper_level, num_pixels); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Compute histograms + * cub::DeviceHistogram::MultiHistogramEven<4, 3>(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, lower_level, upper_level, num_pixels); + * + * // d_histogram <-- [ [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0, ..., 0], + * // [0, 3, 0, 0, 0, 0, 2, 0, 0, 0, 0, ..., 0], + * // [0, 0, 2, 0, 0, 0, 1, 2, 0, 0, 0, ..., 0] ] + * + * \endcode + * + * \tparam NUM_CHANNELS Number of channels interleaved in the input data (may be greater than the number of channels being actively histogrammed) + * \tparam NUM_ACTIVE_CHANNELS [inferred] Number of channels actively being histogrammed + * \tparam SampleIteratorT [inferred] Random-access input iterator type for reading input samples. \iterator + * \tparam CounterT [inferred] Integer type for histogram bin counters + * \tparam LevelT [inferred] Type for specifying boundaries (levels) + * \tparam OffsetT [inferred] Signed integer type for sequence offsets, list lengths, pointer differences, etc. \offset_size1 + */ + template < + int NUM_CHANNELS, + int NUM_ACTIVE_CHANNELS, + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT> + CUB_RUNTIME_FUNCTION + static cudaError_t MultiHistogramEven( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + CounterT* d_histogram[NUM_ACTIVE_CHANNELS], ///< [out] The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of d_histogram[i] should be num_levels[i] - 1. + int num_levels[NUM_ACTIVE_CHANNELS], ///< [in] The number of boundaries (levels) for delineating histogram samples in each active channel. Implies that the number of bins for channeli is num_levels[i] - 1. + LevelT lower_level[NUM_ACTIVE_CHANNELS], ///< [in] The lower sample value bound (inclusive) for the lowest histogram bin in each active channel. + LevelT upper_level[NUM_ACTIVE_CHANNELS], ///< [in] The upper sample value bound (exclusive) for the highest histogram bin in each active channel. + OffsetT num_pixels, ///< [in] The number of multi-channel pixels (i.e., the length of \p d_samples / NUM_CHANNELS) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. 
+ { + /// The sample value type of the input iterator + typedef typename std::iterator_traits::value_type SampleT; + + return MultiHistogramEven( + d_temp_storage, + temp_storage_bytes, + d_samples, + d_histogram, + num_levels, + lower_level, + upper_level, + num_pixels, + 1, + sizeof(SampleT) * NUM_CHANNELS * num_pixels, + stream, + debug_synchronous); + } + + + /** + * \brief Computes per-channel intensity histograms from a sequence of multi-channel "pixel" data samples using equal-width bins. + * + * \par + * - The input is a sequence of pixel structures, where each pixel comprises + * a record of \p NUM_CHANNELS consecutive data samples (e.g., an RGBA pixel). + * - Of the \p NUM_CHANNELS specified, the function will only compute histograms + * for the first \p NUM_ACTIVE_CHANNELS (e.g., only RGB histograms from RGBA + * pixel samples). + * - A two-dimensional region of interest within \p d_samples can be specified + * using the \p num_row_samples, num_rows, and \p row_stride_bytes parameters. + * - The row stride must be a whole multiple of the sample data type + * size, i.e., (row_stride_bytes % sizeof(SampleT)) == 0. + * - The number of histogram bins for channeli is num_levels[i] - 1. + * - For channeli, the range of values for all histogram bins + * have the same width: (upper_level[i] - lower_level[i]) / ( num_levels[i] - 1) + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the computation of three 256-bin RGB histograms from a 2x3 region of + * interest of within a flattened 2x4 array of quad-channel RGBA pixels (8 bits per channel per pixel). + * + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input samples + * // and output histograms + * int num_row_pixels; // e.g., 3 + * int num_rows; // e.g., 2 + * size_t row_stride_bytes; // e.g., 4 * sizeof(unsigned char) * NUM_CHANNELS + * unsigned char* d_samples; // e.g., [(2, 6, 7, 5), (3, 0, 2, 1), (7, 0, 6, 2), (-, -, -, -), + * // (0, 6, 7, 5), (3, 0, 2, 6), (1, 1, 1, 1), (-, -, -, -)] + * int* d_histogram[3]; // e.g., three device pointers to three device buffers, + * // each allocated with 256 integer counters + * int num_levels[3]; // e.g., {257, 257, 257}; + * unsigned int lower_level[3]; // e.g., {0, 0, 0}; + * unsigned int upper_level[3]; // e.g., {256, 256, 256}; + * ... 
+ * + * // Determine temporary device storage requirements + * void* d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceHistogram::MultiHistogramEven<4, 3>(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, lower_level, upper_level, + * num_row_pixels, num_rows, row_stride_bytes); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Compute histograms + * cub::DeviceHistogram::MultiHistogramEven<4, 3>(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, lower_level, upper_level, + * num_row_pixels, num_rows, row_stride_bytes); + * + * // d_histogram <-- [ [1, 1, 1, 2, 0, 0, 0, 1, 0, 0, 0, ..., 0], + * // [0, 4, 0, 0, 0, 0, 2, 0, 0, 0, 0, ..., 0], + * // [0, 1, 2, 0, 0, 0, 1, 2, 0, 0, 0, ..., 0] ] + * + * \endcode + * + * \tparam NUM_CHANNELS Number of channels interleaved in the input data (may be greater than the number of channels being actively histogrammed) + * \tparam NUM_ACTIVE_CHANNELS [inferred] Number of channels actively being histogrammed + * \tparam SampleIteratorT [inferred] Random-access input iterator type for reading input samples. \iterator + * \tparam CounterT [inferred] Integer type for histogram bin counters + * \tparam LevelT [inferred] Type for specifying boundaries (levels) + * \tparam OffsetT [inferred] Signed integer type for sequence offsets, list lengths, pointer differences, etc. \offset_size1 + */ + template < + int NUM_CHANNELS, + int NUM_ACTIVE_CHANNELS, + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT> + CUB_RUNTIME_FUNCTION + static cudaError_t MultiHistogramEven( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + CounterT* d_histogram[NUM_ACTIVE_CHANNELS], ///< [out] The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of d_histogram[i] should be num_levels[i] - 1. + int num_levels[NUM_ACTIVE_CHANNELS], ///< [in] The number of boundaries (levels) for delineating histogram samples in each active channel. Implies that the number of bins for channeli is num_levels[i] - 1. + LevelT lower_level[NUM_ACTIVE_CHANNELS], ///< [in] The lower sample value bound (inclusive) for the lowest histogram bin in each active channel. + LevelT upper_level[NUM_ACTIVE_CHANNELS], ///< [in] The upper sample value bound (exclusive) for the highest histogram bin in each active channel. + OffsetT num_row_pixels, ///< [in] The number of multi-channel pixels per row in the region of interest + OffsetT num_rows, ///< [in] The number of rows in the region of interest + size_t row_stride_bytes, ///< [in] The number of bytes between starts of consecutive rows in the region of interest + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. 
+ { + /// The sample value type of the input iterator + typedef typename std::iterator_traits::value_type SampleT; + Int2Type is_byte_sample; + + if ((sizeof(OffsetT) > sizeof(int)) && + ((unsigned long long) (num_rows * row_stride_bytes) < (unsigned long long) std::numeric_limits::max())) + { + // Down-convert OffsetT data type + + + return DipatchHistogram::DispatchEven( + d_temp_storage, temp_storage_bytes, d_samples, d_histogram, num_levels, lower_level, upper_level, + (int) num_row_pixels, (int) num_rows, (int) (row_stride_bytes / sizeof(SampleT)), + stream, debug_synchronous, is_byte_sample); + } + + return DipatchHistogram::DispatchEven( + d_temp_storage, temp_storage_bytes, d_samples, d_histogram, num_levels, lower_level, upper_level, + num_row_pixels, num_rows, (OffsetT) (row_stride_bytes / sizeof(SampleT)), + stream, debug_synchronous, is_byte_sample); + } + + + //@} end member group + /******************************************************************//** + * \name Custom bin ranges + *********************************************************************/ + //@{ + + /** + * \brief Computes an intensity histogram from a sequence of data samples using the specified bin boundary levels. + * + * \par + * - The number of histogram bins is (\p num_levels - 1) + * - The value range for bini is [level[i], level[i+1]) + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the computation of an six-bin histogram + * from a sequence of float samples + * + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input samples and + * // output histogram + * int num_samples; // e.g., 10 + * float* d_samples; // e.g., [2.2, 6.0, 7.1, 2.9, 3.5, 0.3, 2.9, 2.0, 6.1, 999.5] + * int* d_histogram; // e.g., [ -, -, -, -, -, -, -, -] + * int num_levels // e.g., 7 (seven level boundaries for six bins) + * float* d_levels; // e.g., [0.0, 2.0, 4.0, 6.0, 8.0, 12.0, 16.0] + * ... + * + * // Determine temporary device storage requirements + * void* d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceHistogram::HistogramRange(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, d_levels, num_samples); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Compute histograms + * cub::DeviceHistogram::HistogramRange(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, d_levels, num_samples); + * + * // d_histogram <-- [1, 0, 5, 0, 3, 0, 0, 0]; + * + * \endcode + * + * \tparam SampleIteratorT [inferred] Random-access input iterator type for reading input samples. \iterator + * \tparam CounterT [inferred] Integer type for histogram bin counters + * \tparam LevelT [inferred] Type for specifying boundaries (levels) + * \tparam OffsetT [inferred] Signed integer type for sequence offsets, list lengths, pointer differences, etc. \offset_size1 + */ + template < + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT> + CUB_RUNTIME_FUNCTION + static cudaError_t HistogramRange( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the input sequence of data samples. 
+ CounterT* d_histogram, ///< [out] The pointer to the histogram counter output array of length num_levels - 1. + int num_levels, ///< [in] The number of boundaries (levels) for delineating histogram samples. Implies that the number of bins is num_levels - 1. + LevelT* d_levels, ///< [in] The pointer to the array of boundaries (levels). Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + OffsetT num_samples, ///< [in] The number of data samples per row in the region of interest + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + /// The sample value type of the input iterator + typedef typename std::iterator_traits::value_type SampleT; + + CounterT* d_histogram1[1] = {d_histogram}; + int num_levels1[1] = {num_levels}; + LevelT* d_levels1[1] = {d_levels}; + + return MultiHistogramRange<1, 1>( + d_temp_storage, + temp_storage_bytes, + d_samples, + d_histogram1, + num_levels1, + d_levels1, + num_samples, + 1, + sizeof(SampleT) * num_samples, + stream, + debug_synchronous); + } + + + /** + * \brief Computes an intensity histogram from a sequence of data samples using the specified bin boundary levels. + * + * \par + * - A two-dimensional region of interest within \p d_samples can be specified + * using the \p num_row_samples, num_rows, and \p row_stride_bytes parameters. + * - The row stride must be a whole multiple of the sample data type + * size, i.e., (row_stride_bytes % sizeof(SampleT)) == 0. + * - The number of histogram bins is (\p num_levels - 1) + * - The value range for bini is [level[i], level[i+1]) + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the computation of a six-bin histogram + * from a 2x5 region of interest within a flattened 2x7 array of float samples. + * + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input samples and + * // output histogram + * int num_row_samples; // e.g., 5 + * int num_rows; // e.g., 2; + * int row_stride_bytes; // e.g., 7 * sizeof(float) + * float* d_samples; // e.g., [2.2, 6.0, 7.1, 2.9, 3.5, -, -, + * // 0.3, 2.9, 2.0, 6.1, 999.5, -, -] + * int* d_histogram; // e.g., [ , , , , , , , ] + * int num_levels // e.g., 7 (seven level boundaries for six bins) + * float *d_levels; // e.g., [0.0, 2.0, 4.0, 6.0, 8.0, 12.0, 16.0] + * ... + * + * // Determine temporary device storage requirements + * void* d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceHistogram::HistogramRange(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, d_levels, + * num_row_samples, num_rows, row_stride_bytes); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Compute histograms + * cub::DeviceHistogram::HistogramRange(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, d_levels, + * num_row_samples, num_rows, row_stride_bytes); + * + * // d_histogram <-- [1, 0, 5, 0, 3, 0, 0, 0]; + * + * \endcode + * + * \tparam SampleIteratorT [inferred] Random-access input iterator type for reading input samples. 
\iterator + * \tparam CounterT [inferred] Integer type for histogram bin counters + * \tparam LevelT [inferred] Type for specifying boundaries (levels) + * \tparam OffsetT [inferred] Signed integer type for sequence offsets, list lengths, pointer differences, etc. \offset_size1 + */ + template < + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT> + CUB_RUNTIME_FUNCTION + static cudaError_t HistogramRange( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the input sequence of data samples. + CounterT* d_histogram, ///< [out] The pointer to the histogram counter output array of length num_levels - 1. + int num_levels, ///< [in] The number of boundaries (levels) for delineating histogram samples. Implies that the number of bins is num_levels - 1. + LevelT* d_levels, ///< [in] The pointer to the array of boundaries (levels). Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + OffsetT num_row_samples, ///< [in] The number of data samples per row in the region of interest + OffsetT num_rows, ///< [in] The number of rows in the region of interest + size_t row_stride_bytes, ///< [in] The number of bytes between starts of consecutive rows in the region of interest + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + CounterT* d_histogram1[1] = {d_histogram}; + int num_levels1[1] = {num_levels}; + LevelT* d_levels1[1] = {d_levels}; + + return MultiHistogramRange<1, 1>( + d_temp_storage, + temp_storage_bytes, + d_samples, + d_histogram1, + num_levels1, + d_levels1, + num_row_samples, + num_rows, + row_stride_bytes, + stream, + debug_synchronous); + } + + /** + * \brief Computes per-channel intensity histograms from a sequence of multi-channel "pixel" data samples using the specified bin boundary levels. + * + * \par + * - The input is a sequence of pixel structures, where each pixel comprises + * a record of \p NUM_CHANNELS consecutive data samples (e.g., an RGBA pixel). + * - Of the \p NUM_CHANNELS specified, the function will only compute histograms + * for the first \p NUM_ACTIVE_CHANNELS (e.g., RGB histograms from RGBA + * pixel samples). + * - The number of histogram bins for channeli is num_levels[i] - 1. 
+ * - For channeli, the range of values for all histogram bins + * have the same width: (upper_level[i] - lower_level[i]) / ( num_levels[i] - 1) + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the computation of three 4-bin RGB histograms + * from a quad-channel sequence of RGBA pixels (8 bits per channel per pixel) + * + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input samples + * // and output histograms + * int num_pixels; // e.g., 5 + * unsigned char *d_samples; // e.g., [(2, 6, 7, 5),(3, 0, 2, 1),(7, 0, 6, 2), + * // (0, 6, 7, 5),(3, 0, 2, 6)] + * unsigned int *d_histogram[3]; // e.g., [[ -, -, -, -],[ -, -, -, -],[ -, -, -, -]]; + * int num_levels[3]; // e.g., {5, 5, 5}; + * unsigned int *d_levels[3]; // e.g., [ [0, 2, 4, 6, 8], + * // [0, 2, 4, 6, 8], + * // [0, 2, 4, 6, 8] ]; + * ... + * + * // Determine temporary device storage requirements + * void* d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceHistogram::MultiHistogramRange<4, 3>(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, d_levels, num_pixels); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Compute histograms + * cub::DeviceHistogram::MultiHistogramRange<4, 3>(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, d_levels, num_pixels); + * + * // d_histogram <-- [ [1, 3, 0, 1], + * // [3, 0, 0, 2], + * // [0, 2, 0, 3] ] + * + * \endcode + * + * \tparam NUM_CHANNELS Number of channels interleaved in the input data (may be greater than the number of channels being actively histogrammed) + * \tparam NUM_ACTIVE_CHANNELS [inferred] Number of channels actively being histogrammed + * \tparam SampleIteratorT [inferred] Random-access input iterator type for reading input samples. \iterator + * \tparam CounterT [inferred] Integer type for histogram bin counters + * \tparam LevelT [inferred] Type for specifying boundaries (levels) + * \tparam OffsetT [inferred] Signed integer type for sequence offsets, list lengths, pointer differences, etc. \offset_size1 + */ + template < + int NUM_CHANNELS, + int NUM_ACTIVE_CHANNELS, + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT> + CUB_RUNTIME_FUNCTION + static cudaError_t MultiHistogramRange( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + CounterT* d_histogram[NUM_ACTIVE_CHANNELS], ///< [out] The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of d_histogram[i] should be num_levels[i] - 1. + int num_levels[NUM_ACTIVE_CHANNELS], ///< [in] The number of boundaries (levels) for delineating histogram samples in each active channel. Implies that the number of bins for channeli is num_levels[i] - 1. + LevelT* d_levels[NUM_ACTIVE_CHANNELS], ///< [in] The pointers to the arrays of boundaries (levels), one for each active channel. 
Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + OffsetT num_pixels, ///< [in] The number of multi-channel pixels (i.e., the length of \p d_samples / NUM_CHANNELS) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + /// The sample value type of the input iterator + typedef typename std::iterator_traits::value_type SampleT; + + return MultiHistogramRange( + d_temp_storage, + temp_storage_bytes, + d_samples, + d_histogram, + num_levels, + d_levels, + num_pixels, + 1, + sizeof(SampleT) * NUM_CHANNELS * num_pixels, + stream, + debug_synchronous); + } + + + /** + * \brief Computes per-channel intensity histograms from a sequence of multi-channel "pixel" data samples using the specified bin boundary levels. + * + * \par + * - The input is a sequence of pixel structures, where each pixel comprises + * a record of \p NUM_CHANNELS consecutive data samples (e.g., an RGBA pixel). + * - Of the \p NUM_CHANNELS specified, the function will only compute histograms + * for the first \p NUM_ACTIVE_CHANNELS (e.g., RGB histograms from RGBA + * pixel samples). + * - A two-dimensional region of interest within \p d_samples can be specified + * using the \p num_row_samples, num_rows, and \p row_stride_bytes parameters. + * - The row stride must be a whole multiple of the sample data type + * size, i.e., (row_stride_bytes % sizeof(SampleT)) == 0. + * - The number of histogram bins for channeli is num_levels[i] - 1. + * - For channeli, the range of values for all histogram bins + * have the same width: (upper_level[i] - lower_level[i]) / ( num_levels[i] - 1) + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the computation of three 4-bin RGB histograms from a 2x3 region of + * interest of within a flattened 2x4 array of quad-channel RGBA pixels (8 bits per channel per pixel). + * + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input samples + * // and output histograms + * int num_row_pixels; // e.g., 3 + * int num_rows; // e.g., 2 + * size_t row_stride_bytes; // e.g., 4 * sizeof(unsigned char) * NUM_CHANNELS + * unsigned char* d_samples; // e.g., [(2, 6, 7, 5),(3, 0, 2, 1),(1, 1, 1, 1),(-, -, -, -), + * // (7, 0, 6, 2),(0, 6, 7, 5),(3, 0, 2, 6),(-, -, -, -)] + * int* d_histogram[3]; // e.g., [[ -, -, -, -],[ -, -, -, -],[ -, -, -, -]]; + * int num_levels[3]; // e.g., {5, 5, 5}; + * unsigned int* d_levels[3]; // e.g., [ [0, 2, 4, 6, 8], + * // [0, 2, 4, 6, 8], + * // [0, 2, 4, 6, 8] ]; + * ... 
+ * + * // Determine temporary device storage requirements + * void* d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceHistogram::MultiHistogramRange<4, 3>(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, d_levels, num_row_pixels, num_rows, row_stride_bytes); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Compute histograms + * cub::DeviceHistogram::MultiHistogramRange<4, 3>(d_temp_storage, temp_storage_bytes, + * d_samples, d_histogram, num_levels, d_levels, num_row_pixels, num_rows, row_stride_bytes); + * + * // d_histogram <-- [ [2, 3, 0, 1], + * // [3, 0, 0, 2], + * // [1, 2, 0, 3] ] + * + * \endcode + * + * \tparam NUM_CHANNELS Number of channels interleaved in the input data (may be greater than the number of channels being actively histogrammed) + * \tparam NUM_ACTIVE_CHANNELS [inferred] Number of channels actively being histogrammed + * \tparam SampleIteratorT [inferred] Random-access input iterator type for reading input samples. \iterator + * \tparam CounterT [inferred] Integer type for histogram bin counters + * \tparam LevelT [inferred] Type for specifying boundaries (levels) + * \tparam OffsetT [inferred] Signed integer type for sequence offsets, list lengths, pointer differences, etc. \offset_size1 + */ + template < + int NUM_CHANNELS, + int NUM_ACTIVE_CHANNELS, + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT> + CUB_RUNTIME_FUNCTION + static cudaError_t MultiHistogramRange( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + CounterT* d_histogram[NUM_ACTIVE_CHANNELS], ///< [out] The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of d_histogram[i] should be num_levels[i] - 1. + int num_levels[NUM_ACTIVE_CHANNELS], ///< [in] The number of boundaries (levels) for delineating histogram samples in each active channel. Implies that the number of bins for channeli is num_levels[i] - 1. + LevelT* d_levels[NUM_ACTIVE_CHANNELS], ///< [in] The pointers to the arrays of boundaries (levels), one for each active channel. Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + OffsetT num_row_pixels, ///< [in] The number of multi-channel pixels per row in the region of interest + OffsetT num_rows, ///< [in] The number of rows in the region of interest + size_t row_stride_bytes, ///< [in] The number of bytes between starts of consecutive rows in the region of interest + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. 
+ { + /// The sample value type of the input iterator + typedef typename std::iterator_traits::value_type SampleT; + Int2Type is_byte_sample; + + if ((sizeof(OffsetT) > sizeof(int)) && + ((unsigned long long) (num_rows * row_stride_bytes) < (unsigned long long) std::numeric_limits::max())) + { + // Down-convert OffsetT data type + return DipatchHistogram::DispatchRange( + d_temp_storage, temp_storage_bytes, d_samples, d_histogram, num_levels, d_levels, + (int) num_row_pixels, (int) num_rows, (int) (row_stride_bytes / sizeof(SampleT)), + stream, debug_synchronous, is_byte_sample); + } + + return DipatchHistogram::DispatchRange( + d_temp_storage, temp_storage_bytes, d_samples, d_histogram, num_levels, d_levels, + num_row_pixels, num_rows, (OffsetT) (row_stride_bytes / sizeof(SampleT)), + stream, debug_synchronous, is_byte_sample); + } + + + + //@} end member group +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/device_partition.cuh b/fastertransformer/cuda/cub/device/device_partition.cuh new file mode 100644 index 000000000..505354007 --- /dev/null +++ b/fastertransformer/cuda/cub/device/device_partition.cuh @@ -0,0 +1,273 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DevicePartition provides device-wide, parallel operations for partitioning sequences of data items residing within device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "dispatch/dispatch_select_if.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief DevicePartition provides device-wide, parallel operations for partitioning sequences of data items residing within device-accessible memory. 
![](partition_logo.png) + * \ingroup SingleModule + * + * \par Overview + * These operations apply a selection criterion to construct a partitioned output sequence from items selected/unselected from + * a specified input sequence. + * + * \par Usage Considerations + * \cdp_class{DevicePartition} + * + * \par Performance + * \linear_performance{partition} + * + * \par + * The following chart illustrates DevicePartition::If + * performance across different CUDA architectures for \p int32 items, + * where 50% of the items are randomly selected for the first partition. + * \plots_below + * + * \image html partition_if_int32_50_percent.png + * + */ +struct DevicePartition +{ + /** + * \brief Uses the \p d_flags sequence to split the corresponding items from \p d_in into a partitioned sequence \p d_out. The total number of items copied into the first partition is written to \p d_num_selected_out. ![](partition_flags_logo.png) + * + * \par + * - The value type of \p d_flags must be castable to \p bool (e.g., \p bool, \p char, \p int, etc.). + * - Copies of the selected items are compacted into \p d_out and maintain their original + * relative ordering, however copies of the unselected items are compacted into the + * rear of \p d_out in reverse order. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the compaction of items selected from an \p int device vector. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input, flags, and output + * int num_items; // e.g., 8 + * int *d_in; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] + * char *d_flags; // e.g., [1, 0, 0, 1, 0, 1, 1, 0] + * int *d_out; // e.g., [ , , , , , , , ] + * int *d_num_selected_out; // e.g., [ ] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DevicePartition::Flagged(d_temp_storage, temp_storage_bytes, d_in, d_flags, d_out, d_num_selected_out, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run selection + * cub::DevicePartition::Flagged(d_temp_storage, temp_storage_bytes, d_in, d_flags, d_out, d_num_selected_out, num_items); + * + * // d_out <-- [1, 4, 6, 7, 8, 5, 3, 2] + * // d_num_selected_out <-- [4] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam FlagIterator [inferred] Random-access input iterator type for reading selection flags \iterator + * \tparam OutputIteratorT [inferred] Random-access output iterator type for writing output items \iterator + * \tparam NumSelectedIteratorT [inferred] Output iterator type for recording the number of items selected \iterator + */ + template < + typename InputIteratorT, + typename FlagIterator, + typename OutputIteratorT, + typename NumSelectedIteratorT> + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Flagged( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + FlagIterator d_flags, ///< [in] Pointer to the input sequence of selection flags + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of partitioned data items + NumSelectedIteratorT d_num_selected_out, ///< [out] Pointer to the output total number of items selected (i.e., the offset of the unselected partition) + int num_items, ///< [in] Total number of items to select from + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + typedef int OffsetT; // Signed integer type for global offsets + typedef NullType SelectOp; // Selection op (not used) + typedef NullType EqualityOp; // Equality operator (not used) + + return DispatchSelectIf::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_flags, + d_out, + d_num_selected_out, + SelectOp(), + EqualityOp(), + num_items, + stream, + debug_synchronous); + } + + + /** + * \brief Uses the \p select_op functor to split the corresponding items from \p d_in into a partitioned sequence \p d_out. The total number of items copied into the first partition is written to \p d_num_selected_out. ![](partition_logo.png) + * + * \par + * - Copies of the selected items are compacted into \p d_out and maintain their original + * relative ordering, however copies of the unselected items are compacted into the + * rear of \p d_out in reverse order. + * - \devicestorage + * + * \par Performance + * The following charts illustrate saturated partition-if performance across different + * CUDA architectures for \p int32 and \p int64 items, respectively. Items are + * selected for the first partition with 50% probability. + * + * \image html partition_if_int32_50_percent.png + * \image html partition_if_int64_50_percent.png + * + * \par + * The following charts are similar, but 5% selection probability for the first partition: + * + * \image html partition_if_int32_5_percent.png + * \image html partition_if_int64_5_percent.png + * + * \par Snippet + * The code snippet below illustrates the compaction of items selected from an \p int device vector. + * \par + * \code + * #include // or equivalently + * + * // Functor type for selecting values less than some criteria + * struct LessThan + * { + * int compare; + * + * CUB_RUNTIME_FUNCTION __forceinline__ + * LessThan(int compare) : compare(compare) {} + * + * CUB_RUNTIME_FUNCTION __forceinline__ + * bool operator()(const int &a) const { + * return (a < compare); + * } + * }; + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 8 + * int *d_in; // e.g., [0, 2, 3, 9, 5, 2, 81, 8] + * int *d_out; // e.g., [ , , , , , , , ] + * int *d_num_selected_out; // e.g., [ ] + * LessThan select_op(7); + * ... 
+ * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run selection + * cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op); + * + * // d_out <-- [0, 2, 3, 5, 2, 8, 81, 9] + * // d_num_selected_out <-- [5] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam OutputIteratorT [inferred] Random-access output iterator type for writing output items \iterator + * \tparam NumSelectedIteratorT [inferred] Output iterator type for recording the number of items selected \iterator + * \tparam SelectOp [inferred] Selection functor type having member bool operator()(const T &a) + */ + template < + typename InputIteratorT, + typename OutputIteratorT, + typename NumSelectedIteratorT, + typename SelectOp> + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t If( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of partitioned data items + NumSelectedIteratorT d_num_selected_out, ///< [out] Pointer to the output total number of items selected (i.e., the offset of the unselected partition) + int num_items, ///< [in] Total number of items to select from + SelectOp select_op, ///< [in] Unary selection operator + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + typedef int OffsetT; // Signed integer type for global offsets + typedef NullType* FlagIterator; // FlagT iterator type (not used) + typedef NullType EqualityOp; // Equality operator (not used) + + return DispatchSelectIf::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + NULL, + d_out, + d_num_selected_out, + select_op, + EqualityOp(), + num_items, + stream, + debug_synchronous); + } + +}; + +/** + * \example example_device_partition_flagged.cu + * \example example_device_partition_if.cu + */ + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/device_radix_sort.cuh b/fastertransformer/cuda/cub/device/device_radix_sort.cuh new file mode 100644 index 000000000..1c0bdbea1 --- /dev/null +++ b/fastertransformer/cuda/cub/device/device_radix_sort.cuh @@ -0,0 +1,797 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceRadixSort provides device-wide, parallel operations for computing a radix sort across a sequence of data items residing within device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "dispatch/dispatch_radix_sort.cuh" +#include "../util_arch.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief DeviceRadixSort provides device-wide, parallel operations for computing a radix sort across a sequence of data items residing within device-accessible memory. ![](sorting_logo.png) + * \ingroup SingleModule + * + * \par Overview + * The [radix sorting method](http://en.wikipedia.org/wiki/Radix_sort) arranges + * items into ascending (or descending) order. The algorithm relies upon a positional representation for + * keys, i.e., each key is comprised of an ordered sequence of symbols (e.g., digits, + * characters, etc.) specified from least-significant to most-significant. For a + * given input sequence of keys and a set of rules specifying a total ordering + * of the symbolic alphabet, the radix sorting method produces a lexicographic + * ordering of those keys. + * + * \par + * DeviceRadixSort can sort all of the built-in C++ numeric primitive types + * (unsigned char, \p int, \p double, etc.) as well as CUDA's \p __half + * half-precision floating-point type. Although the direct radix sorting + * method can only be applied to unsigned integral types, DeviceRadixSort + * is able to sort signed and floating-point types via simple bit-wise transformations + * that ensure lexicographic key ordering. + * + * \par Usage Considerations + * \cdp_class{DeviceRadixSort} + * + * \par Performance + * \linear_performance{radix sort} The following chart illustrates DeviceRadixSort::SortKeys + * performance across different CUDA architectures for uniform-random \p uint32 keys. 
+ * \plots_below + * + * \image html lsb_radix_sort_int32_keys.png + * + */ +struct DeviceRadixSort +{ + + /******************************************************************//** + * \name KeyT-value pairs + *********************************************************************/ + //@{ + + /** + * \brief Sorts key-value pairs into ascending order. (~2N auxiliary storage required) + * + * \par + * - The contents of the input data are not altered by the sorting operation + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageNP For sorting using only O(P) temporary storage, see the sorting interface using DoubleBuffer wrappers below. + * - \devicestorage + * + * \par Performance + * The following charts illustrate saturated sorting performance across different + * CUDA architectures for uniform-random uint32,uint32 and + * uint64,uint64 pairs, respectively. + * + * \image html lsb_radix_sort_int32_pairs.png + * \image html lsb_radix_sort_int64_pairs.png + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys + * with associated vector of \p int values. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_keys_out; // e.g., [ ... ] + * int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] + * int *d_values_out; // e.g., [ ... ] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, + * d_keys_in, d_keys_out, d_values_in, d_values_out, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, + * d_keys_in, d_keys_out, d_values_in, d_values_out, num_items); + * + * // d_keys_out <-- [0, 3, 5, 6, 7, 8, 9] + * // d_values_out <-- [5, 4, 3, 1, 2, 0, 6] + * + * \endcode + * + * \tparam KeyT [inferred] KeyT type + * \tparam ValueT [inferred] ValueT type + */ + template < + typename KeyT, + typename ValueT> + CUB_RUNTIME_FUNCTION + static cudaError_t SortPairs( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + const KeyT *d_keys_in, ///< [in] Pointer to the input data of key data to sort + KeyT *d_keys_out, ///< [out] Pointer to the sorted output sequence of key data + const ValueT *d_values_in, ///< [in] Pointer to the corresponding input sequence of associated value items + ValueT *d_values_out, ///< [out] Pointer to the correspondingly-reordered output sequence of associated value items + int num_items, ///< [in] Number of items to sort + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. 
Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); + DoubleBuffer d_values(const_cast(d_values_in), d_values_out); + + return DispatchRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + begin_bit, + end_bit, + false, + stream, + debug_synchronous); + } + + + /** + * \brief Sorts key-value pairs into ascending order. (~N auxiliary storage required) + * + * \par + * - The sorting operation is given a pair of key buffers and a corresponding + * pair of associated value buffers. Each pair is managed by a DoubleBuffer + * structure that indicates which of the two buffers is "current" (and thus + * contains the input data to be sorted). + * - The contents of both buffers within each pair may be altered by the sorting + * operation. + * - Upon completion, the sorting operation will update the "current" indicator + * within each DoubleBuffer wrapper to reference which of the two buffers + * now contains the sorted output sequence (a function of the number of key bits + * specified and the targeted device architecture). + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageP + * - \devicestorage + * + * \par Performance + * The following charts illustrate saturated sorting performance across different + * CUDA architectures for uniform-random uint32,uint32 and + * uint64,uint64 pairs, respectively. + * + * \image html lsb_radix_sort_int32_pairs.png + * \image html lsb_radix_sort_int64_pairs.png + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys + * with associated vector of \p int values. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_key_alt_buf; // e.g., [ ... ] + * int *d_value_buf; // e.g., [0, 1, 2, 3, 4, 5, 6] + * int *d_value_alt_buf; // e.g., [ ... ] + * ... + * + * // Create a set of DoubleBuffers to wrap pairs of device pointers + * cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + * cub::DoubleBuffer d_values(d_value_buf, d_value_alt_buf); + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items); + * + * // d_keys.Current() <-- [0, 3, 5, 6, 7, 8, 9] + * // d_values.Current() <-- [5, 4, 3, 1, 2, 0, 6] + * + * \endcode + * + * \tparam KeyT [inferred] KeyT type + * \tparam ValueT [inferred] ValueT type + */ + template < + typename KeyT, + typename ValueT> + CUB_RUNTIME_FUNCTION + static cudaError_t SortPairs( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. 
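For the raw-pointer SortPairs overload above, the snippet's #include target was likewise stripped. A hedged, self-contained sketch of the documented two-pass idiom (size query with a NULL d_temp_storage, then the actual sort), assuming <cub/cub.cuh> and an illustrative helper name:

#include <cub/cub.cuh>   // assumed include

// Sorts 7 (key, value) pairs out-of-place, mirroring the documentation snippet above.
void sort_pairs_example()
{
    const int num_items = 7;
    int h_keys[num_items]   = {8, 6, 7, 5, 3, 0, 9};
    int h_values[num_items] = {0, 1, 2, 3, 4, 5, 6};

    int *d_keys_in, *d_keys_out, *d_values_in, *d_values_out;
    cudaMalloc(&d_keys_in,    num_items * sizeof(int));
    cudaMalloc(&d_keys_out,   num_items * sizeof(int));
    cudaMalloc(&d_values_in,  num_items * sizeof(int));
    cudaMalloc(&d_values_out, num_items * sizeof(int));
    cudaMemcpy(d_keys_in,   h_keys,   num_items * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_values_in, h_values, num_items * sizeof(int), cudaMemcpyHostToDevice);

    // Query pass (d_temp_storage == NULL) writes the required size; the second call sorts.
    void *d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                    d_keys_in, d_keys_out, d_values_in, d_values_out, num_items);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                    d_keys_in, d_keys_out, d_values_in, d_values_out, num_items);

    // Expected: d_keys_out = [0, 3, 5, 6, 7, 8, 9], d_values_out = [5, 4, 3, 1, 2, 0, 6]
}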
When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + DoubleBuffer &d_keys, ///< [in,out] Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + DoubleBuffer &d_values, ///< [in,out] Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + int num_items, ///< [in] Number of items to sort + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + return DispatchRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + begin_bit, + end_bit, + true, + stream, + debug_synchronous); + } + + + /** + * \brief Sorts key-value pairs into descending order. (~2N auxiliary storage required). + * + * \par + * - The contents of the input data are not altered by the sorting operation + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageNP For sorting using only O(P) temporary storage, see the sorting interface using DoubleBuffer wrappers below. + * - \devicestorage + * + * \par Performance + * Performance is similar to DeviceRadixSort::SortPairs. + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys + * with associated vector of \p int values. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_keys_out; // e.g., [ ... ] + * int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] + * int *d_values_out; // e.g., [ ... ] + * ... 
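In the DoubleBuffer overload documented above, the snippet's cub::DoubleBuffer declarations lost their <int> template arguments during extraction. A sketch restoring them, assuming device buffers prepared as in the earlier SortPairs sketch, and showing how Current() identifies the buffer that holds the result after the ~N-storage sort:

#include <cub/cub.cuh>   // assumed include

// In-place-style sort using DoubleBuffer wrappers (~N auxiliary storage).
// The four pointers are device buffers of num_items ints, with the input in
// d_key_buf / d_value_buf.
void sort_pairs_double_buffer(int *d_key_buf, int *d_key_alt_buf,
                              int *d_value_buf, int *d_value_alt_buf, int num_items)
{
    cub::DoubleBuffer<int> d_keys(d_key_buf, d_key_alt_buf);
    cub::DoubleBuffer<int> d_values(d_value_buf, d_value_alt_buf);

    void *d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items);

    // Either buffer of each pair may now hold the sorted data; Current() says which.
    int *sorted_keys   = d_keys.Current();
    int *sorted_values = d_values.Current();
    (void)sorted_keys; (void)sorted_values;
}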
+ * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, + * d_keys_in, d_keys_out, d_values_in, d_values_out, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, + * d_keys_in, d_keys_out, d_values_in, d_values_out, num_items); + * + * // d_keys_out <-- [9, 8, 7, 6, 5, 3, 0] + * // d_values_out <-- [6, 0, 2, 1, 3, 4, 5] + * + * \endcode + * + * \tparam KeyT [inferred] KeyT type + * \tparam ValueT [inferred] ValueT type + */ + template < + typename KeyT, + typename ValueT> + CUB_RUNTIME_FUNCTION + static cudaError_t SortPairsDescending( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + const KeyT *d_keys_in, ///< [in] Pointer to the input data of key data to sort + KeyT *d_keys_out, ///< [out] Pointer to the sorted output sequence of key data + const ValueT *d_values_in, ///< [in] Pointer to the corresponding input sequence of associated value items + ValueT *d_values_out, ///< [out] Pointer to the correspondingly-reordered output sequence of associated value items + int num_items, ///< [in] Number of items to sort + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); + DoubleBuffer d_values(const_cast(d_values_in), d_values_out); + + return DispatchRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + begin_bit, + end_bit, + false, + stream, + debug_synchronous); + } + + + /** + * \brief Sorts key-value pairs into descending order. (~N auxiliary storage required). + * + * \par + * - The sorting operation is given a pair of key buffers and a corresponding + * pair of associated value buffers. Each pair is managed by a DoubleBuffer + * structure that indicates which of the two buffers is "current" (and thus + * contains the input data to be sorted). + * - The contents of both buffers within each pair may be altered by the sorting + * operation. + * - Upon completion, the sorting operation will update the "current" indicator + * within each DoubleBuffer wrapper to reference which of the two buffers + * now contains the sorted output sequence (a function of the number of key bits + * specified and the targeted device architecture). + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. 
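The optional [begin_bit, end_bit) subrange mentioned throughout these interfaces is not exercised by any of the embedded snippets. A hedged sketch (illustrative helper name; device buffers prepared as in the earlier sketches) that sorts descending on only the low 16 bits of unsigned 32-bit keys:

#include <cub/cub.cuh>   // assumed include

// Sort unsigned 32-bit keys on their low 16 bits only, descending, using the
// optional [begin_bit, end_bit) arguments described above.
void sort_pairs_descending_low16(const unsigned int *d_keys_in, unsigned int *d_keys_out,
                                 const int *d_values_in, int *d_values_out, int num_items)
{
    const int begin_bit = 0;    // least-significant bit (inclusive)
    const int end_bit   = 16;   // most-significant bit (exclusive): the high half-word is ignored

    void *d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
                                              d_keys_in, d_keys_out, d_values_in, d_values_out,
                                              num_items, begin_bit, end_bit);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    cub::DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
                                              d_keys_in, d_keys_out, d_values_in, d_values_out,
                                              num_items, begin_bit, end_bit);
    // Keys are now ordered by their low 16 bits only, which can reduce sorting cost
    // when the high bits carry no ordering information.
}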
+ * - \devicestorageP + * - \devicestorage + * + * \par Performance + * Performance is similar to DeviceRadixSort::SortPairs. + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys + * with associated vector of \p int values. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_key_alt_buf; // e.g., [ ... ] + * int *d_value_buf; // e.g., [0, 1, 2, 3, 4, 5, 6] + * int *d_value_alt_buf; // e.g., [ ... ] + * ... + * + * // Create a set of DoubleBuffers to wrap pairs of device pointers + * cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + * cub::DoubleBuffer d_values(d_value_buf, d_value_alt_buf); + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items); + * + * // d_keys.Current() <-- [9, 8, 7, 6, 5, 3, 0] + * // d_values.Current() <-- [6, 0, 2, 1, 3, 4, 5] + * + * \endcode + * + * \tparam KeyT [inferred] KeyT type + * \tparam ValueT [inferred] ValueT type + */ + template < + typename KeyT, + typename ValueT> + CUB_RUNTIME_FUNCTION + static cudaError_t SortPairsDescending( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + DoubleBuffer &d_keys, ///< [in,out] Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + DoubleBuffer &d_values, ///< [in,out] Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + int num_items, ///< [in] Number of items to sort + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + return DispatchRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + begin_bit, + end_bit, + true, + stream, + debug_synchronous); + } + + + //@} end member group + /******************************************************************//** + * \name Keys-only + *********************************************************************/ + //@{ + + + /** + * \brief Sorts keys into ascending order. 
(~2N auxiliary storage required) + * + * \par + * - The contents of the input data are not altered by the sorting operation + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageNP For sorting using only O(P) temporary storage, see the sorting interface using DoubleBuffer wrappers below. + * - \devicestorage + * + * \par Performance + * The following charts illustrate saturated sorting performance across different + * CUDA architectures for uniform-random \p uint32 and \p uint64 keys, respectively. + * + * \image html lsb_radix_sort_int32_keys.png + * \image html lsb_radix_sort_int64_keys.png + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_keys_out; // e.g., [ ... ] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items); + * + * // d_keys_out <-- [0, 3, 5, 6, 7, 8, 9] + * + * \endcode + * + * \tparam KeyT [inferred] KeyT type + */ + template + CUB_RUNTIME_FUNCTION + static cudaError_t SortKeys( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + const KeyT *d_keys_in, ///< [in] Pointer to the input data of key data to sort + KeyT *d_keys_out, ///< [out] Pointer to the sorted output sequence of key data + int num_items, ///< [in] Number of items to sort + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // Null value type + DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); + DoubleBuffer d_values; + + return DispatchRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + begin_bit, + end_bit, + false, + stream, + debug_synchronous); + } + + + /** + * \brief Sorts keys into ascending order. (~N auxiliary storage required). + * + * \par + * - The sorting operation is given a pair of key buffers managed by a + * DoubleBuffer structure that indicates which of the two buffers is + * "current" (and thus contains the input data to be sorted). 
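A compact sketch of the keys-only ascending sort documented above (assumed <cub/cub.cuh> include; device buffers prepared as in the earlier sketches):

#include <cub/cub.cuh>   // assumed include

// Keys-only ascending sort (raw-pointer overload, ~2N auxiliary storage).
void sort_keys_example(const int *d_keys_in, int *d_keys_out, int num_items)
{
    void *d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    cub::DeviceRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items);
    // With the documentation's input [8, 6, 7, 5, 3, 0, 9]: d_keys_out = [0, 3, 5, 6, 7, 8, 9]
}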
+ * - The contents of both buffers may be altered by the sorting operation. + * - Upon completion, the sorting operation will update the "current" indicator + * within the DoubleBuffer wrapper to reference which of the two buffers + * now contains the sorted output sequence (a function of the number of key bits + * specified and the targeted device architecture). + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageP + * - \devicestorage + * + * \par Performance + * The following charts illustrate saturated sorting performance across different + * CUDA architectures for uniform-random \p uint32 and \p uint64 keys, respectively. + * + * \image html lsb_radix_sort_int32_keys.png + * \image html lsb_radix_sort_int64_keys.png + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_key_alt_buf; // e.g., [ ... ] + * ... + * + * // Create a DoubleBuffer to wrap the pair of device pointers + * cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys, num_items); + * + * // d_keys.Current() <-- [0, 3, 5, 6, 7, 8, 9] + * + * \endcode + * + * \tparam KeyT [inferred] KeyT type + */ + template + CUB_RUNTIME_FUNCTION + static cudaError_t SortKeys( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + DoubleBuffer &d_keys, ///< [in,out] Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + int num_items, ///< [in] Number of items to sort + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // Null value type + DoubleBuffer d_values; + + return DispatchRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + begin_bit, + end_bit, + true, + stream, + debug_synchronous); + } + + /** + * \brief Sorts keys into descending order. (~2N auxiliary storage required). 
+ * + * \par + * - The contents of the input data are not altered by the sorting operation + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageNP For sorting using only O(P) temporary storage, see the sorting interface using DoubleBuffer wrappers below. + * - \devicestorage + * + * \par Performance + * Performance is similar to DeviceRadixSort::SortKeys. + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_keys_out; // e.g., [ ... ] + * ... + * + * // Create a DoubleBuffer to wrap the pair of device pointers + * cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceRadixSort::SortKeysDescending(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceRadixSort::SortKeysDescending(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items); + * + * // d_keys_out <-- [9, 8, 7, 6, 5, 3, 0]s + * + * \endcode + * + * \tparam KeyT [inferred] KeyT type + */ + template + CUB_RUNTIME_FUNCTION + static cudaError_t SortKeysDescending( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + const KeyT *d_keys_in, ///< [in] Pointer to the input data of key data to sort + KeyT *d_keys_out, ///< [out] Pointer to the sorted output sequence of key data + int num_items, ///< [in] Number of items to sort + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); + DoubleBuffer d_values; + + return DispatchRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + begin_bit, + end_bit, + false, + stream, + debug_synchronous); + } + + + /** + * \brief Sorts keys into descending order. (~N auxiliary storage required). + * + * \par + * - The sorting operation is given a pair of key buffers managed by a + * DoubleBuffer structure that indicates which of the two buffers is + * "current" (and thus contains the input data to be sorted). + * - The contents of both buffers may be altered by the sorting operation. 
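The SortKeysDescending snippet above constructs a cub::DoubleBuffer that the raw-pointer overload never uses, and its expected-output comment carries a stray trailing 's'. A consistent sketch of the pointer overload, under the same assumptions as the previous sketches:

#include <cub/cub.cuh>   // assumed include

// Keys-only descending sort (raw-pointer overload); no DoubleBuffer is needed here.
void sort_keys_descending_example(const int *d_keys_in, int *d_keys_out, int num_items)
{
    void *d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceRadixSort::SortKeysDescending(d_temp_storage, temp_storage_bytes,
                                             d_keys_in, d_keys_out, num_items);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    cub::DeviceRadixSort::SortKeysDescending(d_temp_storage, temp_storage_bytes,
                                             d_keys_in, d_keys_out, num_items);
    // With the documentation's input [8, 6, 7, 5, 3, 0, 9]: d_keys_out = [9, 8, 7, 6, 5, 3, 0]
}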
+ * - Upon completion, the sorting operation will update the "current" indicator + * within the DoubleBuffer wrapper to reference which of the two buffers + * now contains the sorted output sequence (a function of the number of key bits + * specified and the targeted device architecture). + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageP + * - \devicestorage + * + * \par Performance + * Performance is similar to DeviceRadixSort::SortKeys. + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_key_alt_buf; // e.g., [ ... ] + * ... + * + * // Create a DoubleBuffer to wrap the pair of device pointers + * cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceRadixSort::SortKeysDescending(d_temp_storage, temp_storage_bytes, d_keys, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceRadixSort::SortKeysDescending(d_temp_storage, temp_storage_bytes, d_keys, num_items); + * + * // d_keys.Current() <-- [9, 8, 7, 6, 5, 3, 0] + * + * \endcode + * + * \tparam KeyT [inferred] KeyT type + */ + template + CUB_RUNTIME_FUNCTION + static cudaError_t SortKeysDescending( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + DoubleBuffer &d_keys, ///< [in,out] Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + int num_items, ///< [in] Number of items to sort + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
+ { + // Signed integer type for global offsets + typedef int OffsetT; + + // Null value type + DoubleBuffer d_values; + + return DispatchRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + begin_bit, + end_bit, + true, + stream, + debug_synchronous); + } + + + //@} end member group + + +}; + +/** + * \example example_device_radix_sort.cu + */ + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/device_reduce.cuh b/fastertransformer/cuda/cub/device/device_reduce.cuh new file mode 100644 index 000000000..13c7a72d1 --- /dev/null +++ b/fastertransformer/cuda/cub/device/device_reduce.cuh @@ -0,0 +1,734 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceReduce provides device-wide, parallel operations for computing a reduction across a sequence of data items residing within device-accessible memory. + */ + +#pragma once + +#include +#include +#include + +#include "../iterator/arg_index_input_iterator.cuh" +#include "dispatch/dispatch_reduce.cuh" +#include "dispatch/dispatch_reduce_by_key.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief DeviceReduce provides device-wide, parallel operations for computing a reduction across a sequence of data items residing within device-accessible memory. ![](reduce_logo.png) + * \ingroup SingleModule + * + * \par Overview + * A reduction (or fold) + * uses a binary combining operator to compute a single aggregate from a sequence of input elements. 
+ * + * \par Usage Considerations + * \cdp_class{DeviceReduce} + * + * \par Performance + * \linear_performance{reduction, reduce-by-key, and run-length encode} + * + * \par + * The following chart illustrates DeviceReduce::Sum + * performance across different CUDA architectures for \p int32 keys. + * + * \image html reduce_int32.png + * + * \par + * The following chart illustrates DeviceReduce::ReduceByKey (summation) + * performance across different CUDA architectures for \p fp32 + * values. Segments are identified by \p int32 keys, and have lengths uniformly sampled from [1,1000]. + * + * \image html reduce_by_key_fp32_len_500.png + * + * \par + * \plots_below + * + */ +struct DeviceReduce +{ + /** + * \brief Computes a device-wide reduction using the specified binary \p reduction_op functor and initial value \p init. + * + * \par + * - Does not support binary reduction operators that are non-commutative. + * - Provides "run-to-run" determinism for pseudo-associative reduction + * (e.g., addition of floating point types) on the same GPU device. + * However, results for pseudo-associative reduction may be inconsistent + * from one device to a another device of a different compute-capability + * because CUB can employ different tile-sizing for different architectures. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates a user-defined min-reduction of a device vector of \p int data elements. + * \par + * \code + * #include // or equivalently + * + * // CustomMin functor + * struct CustomMin + * { + * template + * __device__ __forceinline__ + * T operator()(const T &a, const T &b) const { + * return (b < a) ? b : a; + * } + * }; + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 7 + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_out; // e.g., [-] + * CustomMin min_op; + * int init; // e.g., INT_MAX + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceReduce::Reduce(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, min_op, init); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run reduction + * cub::DeviceReduce::Reduce(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, min_op, init); + * + * // d_out <-- [0] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam OutputIteratorT [inferred] Output iterator type for recording the reduced aggregate \iterator + * \tparam ReductionOpT [inferred] Binary reduction functor type having member T operator()(const T &a, const T &b) + * \tparam T [inferred] Data element type that is convertible to the \p value type of \p InputIteratorT + */ + template < + typename InputIteratorT, + typename OutputIteratorT, + typename ReductionOpT, + typename T> + CUB_RUNTIME_FUNCTION + static cudaError_t Reduce( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + int num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + ReductionOpT reduction_op, ///< [in] Binary reduction functor + T init, ///< [in] Initial value of the reduction + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + return DispatchReduce::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + num_items, + reduction_op, + init, + stream, + debug_synchronous); + } + + + /** + * \brief Computes a device-wide sum using the addition (\p +) operator. + * + * \par + * - Uses \p 0 as the initial value of the reduction. + * - Does not support \p + operators that are non-commutative.. + * - Provides "run-to-run" determinism for pseudo-associative reduction + * (e.g., addition of floating point types) on the same GPU device. + * However, results for pseudo-associative reduction may be inconsistent + * from one device to a another device of a different compute-capability + * because CUB can employ different tile-sizing for different architectures. + * - \devicestorage + * + * \par Performance + * The following charts illustrate saturated sum-reduction performance across different + * CUDA architectures for \p int32 and \p int64 items, respectively. + * + * \image html reduce_int32.png + * \image html reduce_int64.png + * + * \par Snippet + * The code snippet below illustrates the sum-reduction of a device vector of \p int data elements. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 7 + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_out; // e.g., [-] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sum-reduction + * cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items); + * + * // d_out <-- [38] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam OutputIteratorT [inferred] Output iterator type for recording the reduced aggregate \iterator + */ + template < + typename InputIteratorT, + typename OutputIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t Sum( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
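A self-contained sketch of the generic Reduce entry point documented above, with a user-defined min operator and an explicit initial value (assumed <cub/cub.cuh> include; d_in and d_out are device buffers prepared as in the earlier sketches):

#include <cub/cub.cuh>   // assumed include
#include <climits>

// User-defined commutative min operator, as in the documentation snippet above.
struct CustomMin
{
    template <typename T>
    __host__ __device__ __forceinline__
    T operator()(const T &a, const T &b) const { return (b < a) ? b : a; }
};

// d_in: device buffer of num_items ints; d_out: device buffer of one int.
void reduce_min_example(const int *d_in, int *d_out, int num_items)
{
    CustomMin min_op;
    int init = INT_MAX;   // identity for min over int

    void *d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceReduce::Reduce(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, min_op, init);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    cub::DeviceReduce::Reduce(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, min_op, init);
    // With the documentation's input [8, 6, 7, 5, 3, 0, 9]: d_out = [0]
}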
+ size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + int num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // The output value type + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type OutputT; // ... else the output iterator's value type + + return DispatchReduce::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + num_items, + cub::Sum(), + OutputT(), // zero-initialize + stream, + debug_synchronous); + } + + + /** + * \brief Computes a device-wide minimum using the less-than ('<') operator. + * + * \par + * - Uses std::numeric_limits::max() as the initial value of the reduction. + * - Does not support \p < operators that are non-commutative. + * - Provides "run-to-run" determinism for pseudo-associative reduction + * (e.g., addition of floating point types) on the same GPU device. + * However, results for pseudo-associative reduction may be inconsistent + * from one device to a another device of a different compute-capability + * because CUB can employ different tile-sizing for different architectures. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the min-reduction of a device vector of \p int data elements. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 7 + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_out; // e.g., [-] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceReduce::Min(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run min-reduction + * cub::DeviceReduce::Min(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items); + * + * // d_out <-- [0] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam OutputIteratorT [inferred] Output iterator type for recording the reduced aggregate \iterator + */ + template < + typename InputIteratorT, + typename OutputIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t Min( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
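The typedef at the top of Sum's implementation (its template arguments were stripped by extraction) picks the accumulation type from the output iterator's value type when that type is not void, and falls back to the input iterator's value type otherwise. A hedged sketch that relies on this to accumulate 32-bit inputs into a 64-bit output, assuming the deduction behaves as the comment describes:

#include <cub/cub.cuh>   // assumed include

// Sum int inputs into a long long output; the accumulator follows the output value type.
void sum_example(const int *d_in, long long *d_out, int num_items)
{
    void *d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);
    // With the documentation's input [8, 6, 7, 5, 3, 0, 9]: d_out = [38]
}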
+ size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + int num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // The input value type + typedef typename std::iterator_traits::value_type InputT; + + return DispatchReduce::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + num_items, + cub::Min(), + Traits::Max(), // replace with std::numeric_limits::max() when C++11 support is more prevalent + stream, + debug_synchronous); + } + + + /** + * \brief Finds the first device-wide minimum using the less-than ('<') operator, also returning the index of that item. + * + * \par + * - The output value type of \p d_out is cub::KeyValuePair (assuming the value type of \p d_in is \p T) + * - The minimum is written to d_out.value and its offset in the input array is written to d_out.key. + * - The {1, std::numeric_limits::max()} tuple is produced for zero-length inputs + * - Does not support \p < operators that are non-commutative. + * - Provides "run-to-run" determinism for pseudo-associative reduction + * (e.g., addition of floating point types) on the same GPU device. + * However, results for pseudo-associative reduction may be inconsistent + * from one device to a another device of a different compute-capability + * because CUB can employ different tile-sizing for different architectures. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the argmin-reduction of a device vector of \p int data elements. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 7 + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * KeyValuePair *d_out; // e.g., [{-,-}] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_argmin, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run argmin-reduction + * cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_argmin, num_items); + * + * // d_out <-- [{5, 0}] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items (of some type \p T) \iterator + * \tparam OutputIteratorT [inferred] Output iterator type for recording the reduced aggregate (having value type cub::KeyValuePair) \iterator + */ + template < + typename InputIteratorT, + typename OutputIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t ArgMin( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + int num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // The input type + typedef typename std::iterator_traits::value_type InputValueT; + + // The output tuple type + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + KeyValuePair, // ... then the key value pair OffsetT + InputValueT + typename std::iterator_traits::value_type>::Type OutputTupleT; // ... else the output iterator's value type + + // The output value type + typedef typename OutputTupleT::Value OutputValueT; + + // Wrapped input iterator to produce index-value tuples + typedef ArgIndexInputIterator ArgIndexInputIteratorT; + ArgIndexInputIteratorT d_indexed_in(d_in); + + // Initial value + OutputTupleT initial_value(1, Traits::Max()); // replace with std::numeric_limits::max() when C++11 support is more prevalent + + return DispatchReduce::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_indexed_in, + d_out, + num_items, + cub::ArgMin(), + initial_value, + stream, + debug_synchronous); + } + + + /** + * \brief Computes a device-wide maximum using the greater-than ('>') operator. + * + * \par + * - Uses std::numeric_limits::lowest() as the initial value of the reduction. + * - Does not support \p > operators that are non-commutative. + * - Provides "run-to-run" determinism for pseudo-associative reduction + * (e.g., addition of floating point types) on the same GPU device. + * However, results for pseudo-associative reduction may be inconsistent + * from one device to a another device of a different compute-capability + * because CUB can employ different tile-sizing for different architectures. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the max-reduction of a device vector of \p int data elements. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 7 + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_out; // e.g., [-] + * ... 
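For ArgMin above, the snippet declares d_out but passes d_argmin, and the KeyValuePair declaration lost its template arguments. A consistent sketch using one name and the restored cub::KeyValuePair<int, int> output type (offset in .key, minimum in .value):

#include <cub/cub.cuh>   // assumed include

// Argmin-reduction: writes the offset and value of the first minimum to *d_argmin.
void argmin_example(const int *d_in, cub::KeyValuePair<int, int> *d_argmin, int num_items)
{
    void *d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_argmin, num_items);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_argmin, num_items);
    // With the documentation's input [8, 6, 7, 5, 3, 0, 9]: *d_argmin = {5, 0} (offset 5, value 0)
}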
+ * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_max, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run max-reduction + * cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_max, num_items); + * + * // d_out <-- [9] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam OutputIteratorT [inferred] Output iterator type for recording the reduced aggregate \iterator + */ + template < + typename InputIteratorT, + typename OutputIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t Max( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + int num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // The input value type + typedef typename std::iterator_traits::value_type InputT; + + return DispatchReduce::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + num_items, + cub::Max(), + Traits::Lowest(), // replace with std::numeric_limits::lowest() when C++11 support is more prevalent + stream, + debug_synchronous); + } + + + /** + * \brief Finds the first device-wide maximum using the greater-than ('>') operator, also returning the index of that item + * + * \par + * - The output value type of \p d_out is cub::KeyValuePair (assuming the value type of \p d_in is \p T) + * - The maximum is written to d_out.value and its offset in the input array is written to d_out.key. + * - The {1, std::numeric_limits::lowest()} tuple is produced for zero-length inputs + * - Does not support \p > operators that are non-commutative. + * - Provides "run-to-run" determinism for pseudo-associative reduction + * (e.g., addition of floating point types) on the same GPU device. + * However, results for pseudo-associative reduction may be inconsistent + * from one device to a another device of a different compute-capability + * because CUB can employ different tile-sizing for different architectures. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the argmax-reduction of a device vector of \p int data elements. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 7 + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * KeyValuePair *d_out; // e.g., [{-,-}] + * ... 
+ * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_argmax, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run argmax-reduction + * cub::DeviceReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_argmax, num_items); + * + * // d_out <-- [{6, 9}] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items (of some type \p T) \iterator + * \tparam OutputIteratorT [inferred] Output iterator type for recording the reduced aggregate (having value type cub::KeyValuePair) \iterator + */ + template < + typename InputIteratorT, + typename OutputIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t ArgMax( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + int num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // The input type + typedef typename std::iterator_traits::value_type InputValueT; + + // The output tuple type + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + KeyValuePair, // ... then the key value pair OffsetT + InputValueT + typename std::iterator_traits::value_type>::Type OutputTupleT; // ... else the output iterator's value type + + // The output value type + typedef typename OutputTupleT::Value OutputValueT; + + // Wrapped input iterator to produce index-value tuples + typedef ArgIndexInputIterator ArgIndexInputIteratorT; + ArgIndexInputIteratorT d_indexed_in(d_in); + + // Initial value + OutputTupleT initial_value(1, Traits::Lowest()); // replace with std::numeric_limits::lowest() when C++11 support is more prevalent + + return DispatchReduce::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_indexed_in, + d_out, + num_items, + cub::ArgMax(), + initial_value, + stream, + debug_synchronous); + } + + + /** + * \brief Reduces segments of values, where segments are demarcated by corresponding runs of identical keys. + * + * \par + * This operation computes segmented reductions within \p d_values_in using + * the specified binary \p reduction_op functor. The segments are identified by + * "runs" of corresponding keys in \p d_keys_in, where runs are maximal ranges of + * consecutive, identical keys. For the ith run encountered, + * the first key of the run and the corresponding value aggregate of that run are + * written to d_unique_out[i] and d_aggregates_out[i], + * respectively. The total number of runs encountered is written to \p d_num_runs_out. 
+ * + * \par + * - The == equality operator is used to determine whether keys are equivalent + * - Provides "run-to-run" determinism for pseudo-associative reduction + * (e.g., addition of floating point types) on the same GPU device. + * However, results for pseudo-associative reduction may be inconsistent + * from one device to a another device of a different compute-capability + * because CUB can employ different tile-sizing for different architectures. + * - \devicestorage + * + * \par Performance + * The following chart illustrates reduction-by-key (sum) performance across + * different CUDA architectures for \p fp32 and \p fp64 values, respectively. Segments + * are identified by \p int32 keys, and have lengths uniformly sampled from [1,1000]. + * + * \image html reduce_by_key_fp32_len_500.png + * \image html reduce_by_key_fp64_len_500.png + * + * \par + * The following charts are similar, but with segment lengths uniformly sampled from [1,10]: + * + * \image html reduce_by_key_fp32_len_5.png + * \image html reduce_by_key_fp64_len_5.png + * + * \par Snippet + * The code snippet below illustrates the segmented reduction of \p int values grouped + * by runs of associated \p int keys. + * \par + * \code + * #include // or equivalently + * + * // CustomMin functor + * struct CustomMin + * { + * template + * CUB_RUNTIME_FUNCTION __forceinline__ + * T operator()(const T &a, const T &b) const { + * return (b < a) ? b : a; + * } + * }; + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 8 + * int *d_keys_in; // e.g., [0, 2, 2, 9, 5, 5, 5, 8] + * int *d_values_in; // e.g., [0, 7, 1, 6, 2, 5, 3, 4] + * int *d_unique_out; // e.g., [-, -, -, -, -, -, -, -] + * int *d_aggregates_out; // e.g., [-, -, -, -, -, -, -, -] + * int *d_num_runs_out; // e.g., [-] + * CustomMin reduction_op; + * ... 
+ * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceReduce::ReduceByKey(d_temp_storage, temp_storage_bytes, d_keys_in, d_unique_out, d_values_in, d_aggregates_out, d_num_runs_out, reduction_op, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run reduce-by-key + * cub::DeviceReduce::ReduceByKey(d_temp_storage, temp_storage_bytes, d_keys_in, d_unique_out, d_values_in, d_aggregates_out, d_num_runs_out, reduction_op, num_items); + * + * // d_unique_out <-- [0, 2, 9, 5, 8] + * // d_aggregates_out <-- [0, 1, 6, 2, 4] + * // d_num_runs_out <-- [5] + * + * \endcode + * + * \tparam KeysInputIteratorT [inferred] Random-access input iterator type for reading input keys \iterator + * \tparam UniqueOutputIteratorT [inferred] Random-access output iterator type for writing unique output keys \iterator + * \tparam ValuesInputIteratorT [inferred] Random-access input iterator type for reading input values \iterator + * \tparam AggregatesOutputIterator [inferred] Random-access output iterator type for writing output value aggregates \iterator + * \tparam NumRunsOutputIteratorT [inferred] Output iterator type for recording the number of runs encountered \iterator + * \tparam ReductionOpT [inferred] Binary reduction functor type having member T operator()(const T &a, const T &b) + */ + template < + typename KeysInputIteratorT, + typename UniqueOutputIteratorT, + typename ValuesInputIteratorT, + typename AggregatesOutputIteratorT, + typename NumRunsOutputIteratorT, + typename ReductionOpT> + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t ReduceByKey( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + KeysInputIteratorT d_keys_in, ///< [in] Pointer to the input sequence of keys + UniqueOutputIteratorT d_unique_out, ///< [out] Pointer to the output sequence of unique keys (one key per run) + ValuesInputIteratorT d_values_in, ///< [in] Pointer to the input sequence of corresponding values + AggregatesOutputIteratorT d_aggregates_out, ///< [out] Pointer to the output sequence of value aggregates (one aggregate per run) + NumRunsOutputIteratorT d_num_runs_out, ///< [out] Pointer to total number of runs encountered (i.e., the length of d_unique_out) + ReductionOpT reduction_op, ///< [in] Binary reduction functor + int num_items, ///< [in] Total number of associated key+value pairs (i.e., the length of \p d_in_keys and \p d_in_values) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. 
+ { + // Signed integer type for global offsets + typedef int OffsetT; + + // FlagT iterator type (not used) + + // Selection op (not used) + + // Default == operator + typedef Equality EqualityOp; + + return DispatchReduceByKey::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_unique_out, + d_values_in, + d_aggregates_out, + d_num_runs_out, + EqualityOp(), + reduction_op, + num_items, + stream, + debug_synchronous); + } + +}; + +/** + * \example example_device_reduce.cu + */ + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/device_run_length_encode.cuh b/fastertransformer/cuda/cub/device/device_run_length_encode.cuh new file mode 100644 index 000000000..7a2e82d9d --- /dev/null +++ b/fastertransformer/cuda/cub/device/device_run_length_encode.cuh @@ -0,0 +1,278 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceRunLengthEncode provides device-wide, parallel operations for computing a run-length encoding across a sequence of data items residing within device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "dispatch/dispatch_rle.cuh" +#include "dispatch/dispatch_reduce_by_key.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief DeviceRunLengthEncode provides device-wide, parallel operations for demarcating "runs" of same-valued items within a sequence residing within device-accessible memory. 
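As a complement to the CustomMin snippet above, the following sketch (not part of the upstream header) runs DeviceReduce::ReduceByKey with the built-in cub::Sum functor over the same keys and values; the data, names, and include path are illustrative assumptions.

#include <cub/device/device_reduce.cuh>   // path depends on how the vendored cub copy is exposed
#include <cuda_runtime.h>
#include <cstdio>

int main()
{
    const int num_items = 8;
    const int h_keys[num_items]   = {0, 2, 2, 9, 5, 5, 5, 8};
    const int h_values[num_items] = {0, 7, 1, 6, 2, 5, 3, 4};

    int *d_keys_in, *d_values_in, *d_unique_out, *d_aggregates_out, *d_num_runs_out;
    cudaMalloc(&d_keys_in,        num_items * sizeof(int));
    cudaMalloc(&d_values_in,      num_items * sizeof(int));
    cudaMalloc(&d_unique_out,     num_items * sizeof(int));
    cudaMalloc(&d_aggregates_out, num_items * sizeof(int));
    cudaMalloc(&d_num_runs_out,   sizeof(int));
    cudaMemcpy(d_keys_in,   h_keys,   num_items * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_values_in, h_values, num_items * sizeof(int), cudaMemcpyHostToDevice);

    // Size, allocate, then run a sum-per-run reduction (instead of the CustomMin shown above).
    void *d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceReduce::ReduceByKey(d_temp_storage, temp_storage_bytes,
        d_keys_in, d_unique_out, d_values_in, d_aggregates_out, d_num_runs_out,
        cub::Sum(), num_items);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    cub::DeviceReduce::ReduceByKey(d_temp_storage, temp_storage_bytes,
        d_keys_in, d_unique_out, d_values_in, d_aggregates_out, d_num_runs_out,
        cub::Sum(), num_items);

    int h_num_runs = 0;
    cudaMemcpy(&h_num_runs, d_num_runs_out, sizeof(int), cudaMemcpyDeviceToHost);
    printf("num runs = %d\n", h_num_runs);   // expected: 5 (per-run sums: 0, 8, 6, 10, 4)
    return 0;   // device frees omitted for brevity
}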
![](run_length_encode_logo.png) + * \ingroup SingleModule + * + * \par Overview + * A run-length encoding + * computes a simple compressed representation of a sequence of input elements such that each + * maximal "run" of consecutive same-valued data items is encoded as a single data value along with a + * count of the elements in that run. + * + * \par Usage Considerations + * \cdp_class{DeviceRunLengthEncode} + * + * \par Performance + * \linear_performance{run-length encode} + * + * \par + * The following chart illustrates DeviceRunLengthEncode::RunLengthEncode performance across + * different CUDA architectures for \p int32 items. + * Segments have lengths uniformly sampled from [1,1000]. + * + * \image html rle_int32_len_500.png + * + * \par + * \plots_below + * + */ +struct DeviceRunLengthEncode +{ + + /** + * \brief Computes a run-length encoding of the sequence \p d_in. + * + * \par + * - For the ith run encountered, the first key of the run and its length are written to + * d_unique_out[i] and d_counts_out[i], + * respectively. + * - The total number of runs encountered is written to \p d_num_runs_out. + * - The == equality operator is used to determine whether values are equivalent + * - \devicestorage + * + * \par Performance + * The following charts illustrate saturated encode performance across different + * CUDA architectures for \p int32 and \p int64 items, respectively. Segments have + * lengths uniformly sampled from [1,1000]. + * + * \image html rle_int32_len_500.png + * \image html rle_int64_len_500.png + * + * \par + * The following charts are similar, but with segment lengths uniformly sampled from [1,10]: + * + * \image html rle_int32_len_5.png + * \image html rle_int64_len_5.png + * + * \par Snippet + * The code snippet below illustrates the run-length encoding of a sequence of \p int values. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 8 + * int *d_in; // e.g., [0, 2, 2, 9, 5, 5, 5, 8] + * int *d_unique_out; // e.g., [ , , , , , , , ] + * int *d_counts_out; // e.g., [ , , , , , , , ] + * int *d_num_runs_out; // e.g., [ ] + * ... 
+ * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceRunLengthEncode::Encode(d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run encoding + * cub::DeviceRunLengthEncode::Encode(d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items); + * + * // d_unique_out <-- [0, 2, 9, 5, 8] + * // d_counts_out <-- [1, 2, 1, 3, 1] + * // d_num_runs_out <-- [5] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam UniqueOutputIteratorT [inferred] Random-access output iterator type for writing unique output items \iterator + * \tparam LengthsOutputIteratorT [inferred] Random-access output iterator type for writing output counts \iterator + * \tparam NumRunsOutputIteratorT [inferred] Output iterator type for recording the number of runs encountered \iterator + */ + template < + typename InputIteratorT, + typename UniqueOutputIteratorT, + typename LengthsOutputIteratorT, + typename NumRunsOutputIteratorT> + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Encode( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of keys + UniqueOutputIteratorT d_unique_out, ///< [out] Pointer to the output sequence of unique keys (one key per run) + LengthsOutputIteratorT d_counts_out, ///< [out] Pointer to the output sequence of run-lengths (one count per run) + NumRunsOutputIteratorT d_num_runs_out, ///< [out] Pointer to total number of runs + int num_items, ///< [in] Total number of associated key+value pairs (i.e., the length of \p d_in_keys and \p d_in_values) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + typedef int OffsetT; // Signed integer type for global offsets + typedef NullType* FlagIterator; // FlagT iterator type (not used) + typedef NullType SelectOp; // Selection op (not used) + typedef Equality EqualityOp; // Default == operator + typedef cub::Sum ReductionOp; // Value reduction operator + + // The lengths output value type + typedef typename If<(Equals::value_type, void>::VALUE), // LengthT = (if output iterator's value type is void) ? + OffsetT, // ... then the OffsetT type, + typename std::iterator_traits::value_type>::Type LengthT; // ... 
else the output iterator's value type + + // Generator type for providing 1s values for run-length reduction + typedef ConstantInputIterator LengthsInputIteratorT; + + return DispatchReduceByKey::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_unique_out, + LengthsInputIteratorT((LengthT) 1), + d_counts_out, + d_num_runs_out, + EqualityOp(), + ReductionOp(), + num_items, + stream, + debug_synchronous); + } + + + /** + * \brief Enumerates the starting offsets and lengths of all non-trivial runs (of length > 1) of same-valued keys in the sequence \p d_in. + * + * \par + * - For the ith non-trivial run, the run's starting offset + * and its length are written to d_offsets_out[i] and + * d_lengths_out[i], respectively. + * - The total number of runs encountered is written to \p d_num_runs_out. + * - The == equality operator is used to determine whether values are equivalent + * - \devicestorage + * + * \par Performance + * + * \par Snippet + * The code snippet below illustrates the identification of non-trivial runs within a sequence of \p int values. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 8 + * int *d_in; // e.g., [0, 2, 2, 9, 5, 5, 5, 8] + * int *d_offsets_out; // e.g., [ , , , , , , , ] + * int *d_lengths_out; // e.g., [ , , , , , , , ] + * int *d_num_runs_out; // e.g., [ ] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceRunLengthEncode::NonTrivialRuns(d_temp_storage, temp_storage_bytes, d_in, d_offsets_out, d_lengths_out, d_num_runs_out, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run encoding + * cub::DeviceRunLengthEncode::NonTrivialRuns(d_temp_storage, temp_storage_bytes, d_in, d_offsets_out, d_lengths_out, d_num_runs_out, num_items); + * + * // d_offsets_out <-- [1, 4] + * // d_lengths_out <-- [2, 3] + * // d_num_runs_out <-- [2] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam OffsetsOutputIteratorT [inferred] Random-access output iterator type for writing run-offset values \iterator + * \tparam LengthsOutputIteratorT [inferred] Random-access output iterator type for writing run-length values \iterator + * \tparam NumRunsOutputIteratorT [inferred] Output iterator type for recording the number of runs encountered \iterator + */ + template < + typename InputIteratorT, + typename OffsetsOutputIteratorT, + typename LengthsOutputIteratorT, + typename NumRunsOutputIteratorT> + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t NonTrivialRuns( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to input sequence of data items + OffsetsOutputIteratorT d_offsets_out, ///< [out] Pointer to output sequence of run-offsets (one offset per non-trivial run) + LengthsOutputIteratorT d_lengths_out, ///< [out] Pointer to output sequence of run-lengths (one count per non-trivial run) + NumRunsOutputIteratorT d_num_runs_out, ///< [out] Pointer to total number of runs (i.e., length of \p d_offsets_out) + int num_items, ///< [in] Total number of associated key+value pairs (i.e., the length of \p d_in_keys and \p d_in_values) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + typedef int OffsetT; // Signed integer type for global offsets + typedef Equality EqualityOp; // Default == operator + + return DeviceRleDispatch::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_offsets_out, + d_lengths_out, + d_num_runs_out, + EqualityOp(), + num_items, + stream, + debug_synchronous); + } + + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/device_scan.cuh b/fastertransformer/cuda/cub/device/device_scan.cuh new file mode 100644 index 000000000..e86fefe3c --- /dev/null +++ b/fastertransformer/cuda/cub/device/device_scan.cuh @@ -0,0 +1,443 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
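The following sketch (not part of the upstream header) drives both DeviceRunLengthEncode::Encode and DeviceRunLengthEncode::NonTrivialRuns over the sequence used in the snippets above; data, names, and the include path are illustrative assumptions.

#include <cub/device/device_run_length_encode.cuh>   // path depends on how the vendored cub copy is exposed
#include <cuda_runtime.h>
#include <cstdio>

int main()
{
    const int num_items = 8;
    const int h_in[num_items] = {0, 2, 2, 9, 5, 5, 5, 8};

    int *d_in, *d_unique_out, *d_counts_out, *d_offsets_out, *d_lengths_out, *d_num_runs_out;
    cudaMalloc(&d_in,           num_items * sizeof(int));
    cudaMalloc(&d_unique_out,   num_items * sizeof(int));
    cudaMalloc(&d_counts_out,   num_items * sizeof(int));
    cudaMalloc(&d_offsets_out,  num_items * sizeof(int));
    cudaMalloc(&d_lengths_out,  num_items * sizeof(int));
    cudaMalloc(&d_num_runs_out, sizeof(int));
    cudaMemcpy(d_in, h_in, num_items * sizeof(int), cudaMemcpyHostToDevice);

    // Full run-length encoding: every run, including length-1 runs, is reported.
    void *d_temp = NULL;  size_t temp_bytes = 0;
    cub::DeviceRunLengthEncode::Encode(d_temp, temp_bytes,
        d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items);
    cudaMalloc(&d_temp, temp_bytes);
    cub::DeviceRunLengthEncode::Encode(d_temp, temp_bytes,
        d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items);
    // expected: d_unique_out = [0, 2, 9, 5, 8], d_counts_out = [1, 2, 1, 3, 1], num_runs = 5

    // Non-trivial runs only (length > 1): reports starting offsets and lengths.
    void *d_temp2 = NULL;  size_t temp_bytes2 = 0;
    cub::DeviceRunLengthEncode::NonTrivialRuns(d_temp2, temp_bytes2,
        d_in, d_offsets_out, d_lengths_out, d_num_runs_out, num_items);
    cudaMalloc(&d_temp2, temp_bytes2);
    cub::DeviceRunLengthEncode::NonTrivialRuns(d_temp2, temp_bytes2,
        d_in, d_offsets_out, d_lengths_out, d_num_runs_out, num_items);
    // expected: d_offsets_out = [1, 4], d_lengths_out = [2, 3], num_runs = 2

    int h_num_runs = 0;
    cudaMemcpy(&h_num_runs, d_num_runs_out, sizeof(int), cudaMemcpyDeviceToHost);
    printf("non-trivial runs = %d\n", h_num_runs);
    return 0;   // device frees omitted for brevity
}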
+ * + ******************************************************************************/ + +/** + * \file + * cub::DeviceScan provides device-wide, parallel operations for computing a prefix scan across a sequence of data items residing within device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "dispatch/dispatch_scan.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief DeviceScan provides device-wide, parallel operations for computing a prefix scan across a sequence of data items residing within device-accessible memory. ![](device_scan.png) + * \ingroup SingleModule + * + * \par Overview + * Given a sequence of input elements and a binary reduction operator, a [prefix scan](http://en.wikipedia.org/wiki/Prefix_sum) + * produces an output sequence where each element is computed to be the reduction + * of the elements occurring earlier in the input sequence. Prefix sum + * connotes a prefix scan with the addition operator. The term \em inclusive indicates + * that the ith output reduction incorporates the ith input. + * The term \em exclusive indicates the ith input is not incorporated into + * the ith output reduction. + * + * \par + * As of CUB 1.0.1 (2013), CUB's device-wide scan APIs have implemented our "decoupled look-back" algorithm + * for performing global prefix scan with only a single pass through the + * input data, as described in our 2016 technical report [1]. The central + * idea is to leverage a small, constant factor of redundant work in order to overlap the latencies + * of global prefix propagation with local computation. As such, our algorithm requires only + * ~2n data movement (n inputs are read, n outputs are written), and typically + * proceeds at "memcpy" speeds. + * + * \par + * [1] [Duane Merrill and Michael Garland. "Single-pass Parallel Prefix Scan with Decoupled Look-back", NVIDIA Technical Report NVR-2016-002, 2016.](https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back) + * + * \par Usage Considerations + * \cdp_class{DeviceScan} + * + * \par Performance + * \linear_performance{prefix scan} + * + * \par + * The following chart illustrates DeviceScan::ExclusiveSum + * performance across different CUDA architectures for \p int32 keys. + * \plots_below + * + * \image html scan_int32.png + * + */ +struct DeviceScan +{ + /******************************************************************//** + * \name Exclusive scans + *********************************************************************/ + //@{ + + /** + * \brief Computes a device-wide exclusive prefix sum. The value of 0 is applied as the initial value, and is assigned to *d_out. + * + * \par + * - Supports non-commutative sum operators. + * - Provides "run-to-run" determinism for pseudo-associative reduction + * (e.g., addition of floating point types) on the same GPU device. + * However, results for pseudo-associative reduction may be inconsistent + * from one device to a another device of a different compute-capability + * because CUB can employ different tile-sizing for different architectures. + * - \devicestorage + * + * \par Performance + * The following charts illustrate saturated exclusive sum performance across different + * CUDA architectures for \p int32 and \p int64 items, respectively. 
+ * + * \image html scan_int32.png + * \image html scan_int64.png + * + * \par Snippet + * The code snippet below illustrates the exclusive prefix sum of an \p int device vector. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 7 + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_out; // e.g., [ , , , , , , ] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run exclusive prefix sum + * cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items); + * + * // d_out s<-- [0, 8, 14, 21, 26, 29, 29] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading scan inputs \iterator + * \tparam OutputIteratorT [inferred] Random-access output iterator type for writing scan outputs \iterator + */ + template < + typename InputIteratorT, + typename OutputIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t ExclusiveSum( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of data items + int num_items, ///< [in] Total number of input items (i.e., the length of \p d_in) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // The output value type + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type OutputT; // ... else the output iterator's value type + + // Initial value + OutputT init_value = 0; + + return DispatchScan::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + Sum(), + init_value, + num_items, + stream, + debug_synchronous); + } + + + /** + * \brief Computes a device-wide exclusive prefix scan using the specified binary \p scan_op functor. The \p init_value value is applied as the initial value, and is assigned to *d_out. + * + * \par + * - Supports non-commutative scan operators. + * - Provides "run-to-run" determinism for pseudo-associative reduction + * (e.g., addition of floating point types) on the same GPU device. + * However, results for pseudo-associative reduction may be inconsistent + * from one device to a another device of a different compute-capability + * because CUB can employ different tile-sizing for different architectures. 
+ * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the exclusive prefix min-scan of an \p int device vector + * \par + * \code + * #include // or equivalently + * + * // CustomMin functor + * struct CustomMin + * { + * template + * CUB_RUNTIME_FUNCTION __forceinline__ + * T operator()(const T &a, const T &b) const { + * return (b < a) ? b : a; + * } + * }; + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 7 + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_out; // e.g., [ , , , , , , ] + * CustomMin min_op + * ... + * + * // Determine temporary device storage requirements for exclusive prefix scan + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceScan::ExclusiveScan(d_temp_storage, temp_storage_bytes, d_in, d_out, min_op, (int) MAX_INT, num_items); + * + * // Allocate temporary storage for exclusive prefix scan + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run exclusive prefix min-scan + * cub::DeviceScan::ExclusiveScan(d_temp_storage, temp_storage_bytes, d_in, d_out, min_op, (int) MAX_INT, num_items); + * + * // d_out <-- [2147483647, 8, 6, 6, 5, 3, 0] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading scan inputs \iterator + * \tparam OutputIteratorT [inferred] Random-access output iterator type for writing scan outputs \iterator + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + * \tparam Identity [inferred] Type of the \p identity value used Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template < + typename InputIteratorT, + typename OutputIteratorT, + typename ScanOpT, + typename InitValueT> + CUB_RUNTIME_FUNCTION + static cudaError_t ExclusiveScan( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of data items + ScanOpT scan_op, ///< [in] Binary scan functor + InitValueT init_value, ///< [in] Initial value to seed the exclusive scan (and is assigned to *d_out) + int num_items, ///< [in] Total number of input items (i.e., the length of \p d_in) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + return DispatchScan::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + scan_op, + init_value, + num_items, + stream, + debug_synchronous); + } + + + //@} end member group + /******************************************************************//** + * \name Inclusive scans + *********************************************************************/ + //@{ + + + /** + * \brief Computes a device-wide inclusive prefix sum. + * + * \par + * - Supports non-commutative sum operators. 
+ * - Provides "run-to-run" determinism for pseudo-associative reduction + * (e.g., addition of floating point types) on the same GPU device. + * However, results for pseudo-associative reduction may be inconsistent + * from one device to a another device of a different compute-capability + * because CUB can employ different tile-sizing for different architectures. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the inclusive prefix sum of an \p int device vector. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 7 + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_out; // e.g., [ , , , , , , ] + * ... + * + * // Determine temporary device storage requirements for inclusive prefix sum + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items); + * + * // Allocate temporary storage for inclusive prefix sum + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run inclusive prefix sum + * cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items); + * + * // d_out <-- [8, 14, 21, 26, 29, 29, 38] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading scan inputs \iterator + * \tparam OutputIteratorT [inferred] Random-access output iterator type for writing scan outputs \iterator + */ + template < + typename InputIteratorT, + typename OutputIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t InclusiveSum( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of data items + int num_items, ///< [in] Total number of input items (i.e., the length of \p d_in) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + return DispatchScan::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + Sum(), + NullType(), + num_items, + stream, + debug_synchronous); + } + + + /** + * \brief Computes a device-wide inclusive prefix scan using the specified binary \p scan_op functor. + * + * \par + * - Supports non-commutative scan operators. + * - Provides "run-to-run" determinism for pseudo-associative reduction + * (e.g., addition of floating point types) on the same GPU device. + * However, results for pseudo-associative reduction may be inconsistent + * from one device to a another device of a different compute-capability + * because CUB can employ different tile-sizing for different architectures. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the inclusive prefix min-scan of an \p int device vector. 
+ * \par + * \code + * #include // or equivalently + * + * // CustomMin functor + * struct CustomMin + * { + * template + * CUB_RUNTIME_FUNCTION __forceinline__ + * T operator()(const T &a, const T &b) const { + * return (b < a) ? b : a; + * } + * }; + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 7 + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_out; // e.g., [ , , , , , , ] + * CustomMin min_op; + * ... + * + * // Determine temporary device storage requirements for inclusive prefix scan + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceScan::InclusiveScan(d_temp_storage, temp_storage_bytes, d_in, d_out, min_op, num_items); + * + * // Allocate temporary storage for inclusive prefix scan + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run inclusive prefix min-scan + * cub::DeviceScan::InclusiveScan(d_temp_storage, temp_storage_bytes, d_in, d_out, min_op, num_items); + * + * // d_out <-- [8, 6, 6, 5, 3, 0, 0] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading scan inputs \iterator + * \tparam OutputIteratorT [inferred] Random-access output iterator type for writing scan outputs \iterator + * \tparam ScanOp [inferred] Binary scan functor type having member T operator()(const T &a, const T &b) + */ + template < + typename InputIteratorT, + typename OutputIteratorT, + typename ScanOpT> + CUB_RUNTIME_FUNCTION + static cudaError_t InclusiveScan( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of data items + ScanOpT scan_op, ///< [in] Binary scan functor + int num_items, ///< [in] Total number of input items (i.e., the length of \p d_in) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + return DispatchScan::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + scan_op, + NullType(), + num_items, + stream, + debug_synchronous); + } + + //@} end member group + +}; + +/** + * \example example_device_scan.cu + */ + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/device_segmented_radix_sort.cuh b/fastertransformer/cuda/cub/device/device_segmented_radix_sort.cuh new file mode 100644 index 000000000..0d3607627 --- /dev/null +++ b/fastertransformer/cuda/cub/device/device_segmented_radix_sort.cuh @@ -0,0 +1,876 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 
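The following sketch (not part of the upstream header) combines the DeviceScan entry points documented above, ExclusiveSum, ExclusiveScan with a custom functor, and InclusiveScan, in one program. It assumes INT_MAX from <climits> as the min-scan identity (the snippet above writes "(int) MAX_INT"); data, names, and the include path are illustrative assumptions.

#include <cub/device/device_scan.cuh>   // path depends on how the vendored cub copy is exposed
#include <cuda_runtime.h>
#include <climits>
#include <cstdio>

// Min functor mirroring the CustomMin shown in the snippets above.
struct CustomMin
{
    template <typename T>
    __host__ __device__ __forceinline__
    T operator()(const T &a, const T &b) const { return (b < a) ? b : a; }
};

int main()
{
    const int num_items = 7;
    const int h_in[num_items] = {8, 6, 7, 5, 3, 0, 9};

    int *d_in, *d_out;
    cudaMalloc(&d_in,  num_items * sizeof(int));
    cudaMalloc(&d_out, num_items * sizeof(int));
    cudaMemcpy(d_in, h_in, num_items * sizeof(int), cudaMemcpyHostToDevice);

    // Exclusive prefix sum: size, allocate, run.
    void *d_temp = NULL;  size_t temp_bytes = 0;
    cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_in, d_out, num_items);
    cudaMalloc(&d_temp, temp_bytes);
    cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_in, d_out, num_items);
    // expected d_out: [0, 8, 14, 21, 26, 29, 29]

    // Exclusive min-scan seeded with INT_MAX (the identity for min over int).
    void *d_temp2 = NULL;  size_t temp_bytes2 = 0;
    cub::DeviceScan::ExclusiveScan(d_temp2, temp_bytes2, d_in, d_out, CustomMin(), INT_MAX, num_items);
    cudaMalloc(&d_temp2, temp_bytes2);
    cub::DeviceScan::ExclusiveScan(d_temp2, temp_bytes2, d_in, d_out, CustomMin(), INT_MAX, num_items);
    // expected d_out: [2147483647, 8, 6, 6, 5, 3, 0]

    // Inclusive min-scan: no initial value is taken.
    void *d_temp3 = NULL;  size_t temp_bytes3 = 0;
    cub::DeviceScan::InclusiveScan(d_temp3, temp_bytes3, d_in, d_out, CustomMin(), num_items);
    cudaMalloc(&d_temp3, temp_bytes3);
    cub::DeviceScan::InclusiveScan(d_temp3, temp_bytes3, d_in, d_out, CustomMin(), num_items);
    // expected d_out: [8, 6, 6, 5, 3, 0, 0]

    int h_out[num_items];
    cudaMemcpy(h_out, d_out, num_items * sizeof(int), cudaMemcpyDeviceToHost);
    for (int i = 0; i < num_items; ++i) printf("%d ", h_out[i]);
    printf("\n");
    return 0;   // device frees omitted for brevity
}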
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceSegmentedRadixSort provides device-wide, parallel operations for computing a batched radix sort across multiple, non-overlapping sequences of data items residing within device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "dispatch/dispatch_radix_sort.cuh" +#include "../util_arch.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief DeviceSegmentedRadixSort provides device-wide, parallel operations for computing a batched radix sort across multiple, non-overlapping sequences of data items residing within device-accessible memory. ![](segmented_sorting_logo.png) + * \ingroup SegmentedModule + * + * \par Overview + * The [radix sorting method](http://en.wikipedia.org/wiki/Radix_sort) arranges + * items into ascending (or descending) order. The algorithm relies upon a positional representation for + * keys, i.e., each key is comprised of an ordered sequence of symbols (e.g., digits, + * characters, etc.) specified from least-significant to most-significant. For a + * given input sequence of keys and a set of rules specifying a total ordering + * of the symbolic alphabet, the radix sorting method produces a lexicographic + * ordering of those keys. + * + * \par + * DeviceSegmentedRadixSort can sort all of the built-in C++ numeric primitive types + * (unsigned char, \p int, \p double, etc.) as well as CUDA's \p __half + * half-precision floating-point type. Although the direct radix sorting + * method can only be applied to unsigned integral types, DeviceSegmentedRadixSort + * is able to sort signed and floating-point types via simple bit-wise transformations + * that ensure lexicographic key ordering. 
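Ahead of the individual overloads documented below, the following sketch (not part of the upstream header) shows the basic batched-sort call pattern: segments packed back-to-back in one buffer, a single offsets array of length num_segments+1 aliased as both begin and end offsets, and the usual two-phase temporary-storage protocol. Data values are taken from the snippets below; names and the include path are illustrative assumptions.

#include <cub/device/device_segmented_radix_sort.cuh>   // path depends on how the vendored cub copy is exposed
#include <cuda_runtime.h>

int main()
{
    // Three segments packed back-to-back (segment 1 is empty), as in the snippets below.
    const int num_items    = 7;
    const int num_segments = 3;
    const int h_offsets[num_segments + 1] = {0, 3, 3, 7};
    const int h_keys[num_items]   = {8, 6, 7, 5, 3, 0, 9};
    const int h_values[num_items] = {0, 1, 2, 3, 4, 5, 6};

    int *d_offsets, *d_keys_in, *d_keys_out, *d_values_in, *d_values_out;
    cudaMalloc(&d_offsets,    (num_segments + 1) * sizeof(int));
    cudaMalloc(&d_keys_in,    num_items * sizeof(int));
    cudaMalloc(&d_keys_out,   num_items * sizeof(int));
    cudaMalloc(&d_values_in,  num_items * sizeof(int));
    cudaMalloc(&d_values_out, num_items * sizeof(int));
    cudaMemcpy(d_offsets,   h_offsets, (num_segments + 1) * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_keys_in,   h_keys,    num_items * sizeof(int),          cudaMemcpyHostToDevice);
    cudaMemcpy(d_values_in, h_values,  num_items * sizeof(int),          cudaMemcpyHostToDevice);

    // One offsets array of length num_segments + 1 serves as both begin and end offsets:
    // d_offsets for the former, d_offsets + 1 for the latter.
    void *d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
        d_keys_in, d_keys_out, d_values_in, d_values_out,
        num_items, num_segments, d_offsets, d_offsets + 1);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
        d_keys_in, d_keys_out, d_values_in, d_values_out,
        num_items, num_segments, d_offsets, d_offsets + 1);
    // expected: d_keys_out = [6, 7, 8, 0, 3, 5, 9], d_values_out = [1, 2, 0, 5, 4, 3, 6]

    return 0;   // device frees omitted for brevity
}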
+ * + * \par Usage Considerations + * \cdp_class{DeviceSegmentedRadixSort} + * + */ +struct DeviceSegmentedRadixSort +{ + + /******************************************************************//** + * \name Key-value pairs + *********************************************************************/ + //@{ + + /** + * \brief Sorts segments of key-value pairs into ascending order. (~2N auxiliary storage required) + * + * \par + * - The contents of the input data are not altered by the sorting operation + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageNP For sorting using only O(P) temporary storage, see the sorting interface using DoubleBuffer wrappers below. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of \p int keys + * with associated vector of \p int values. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_keys_out; // e.g., [-, -, -, -, -, -, -] + * int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] + * int *d_values_out; // e.g., [-, -, -, -, -, -, -] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, + * d_keys_in, d_keys_out, d_values_in, d_values_out, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, + * d_keys_in, d_keys_out, d_values_in, d_values_out, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // d_keys_out <-- [6, 7, 8, 0, 3, 5, 9] + * // d_values_out <-- [1, 2, 0, 5, 4, 3, 6] + * + * \endcode + * + * \tparam KeyT [inferred] Key type + * \tparam ValueT [inferred] Value type + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + */ + template < + typename KeyT, + typename ValueT, + typename OffsetIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t SortPairs( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + const KeyT *d_keys_in, ///< [in] %Device-accessible pointer to the input data of key data to sort + KeyT *d_keys_out, ///< [out] %Device-accessible pointer to the sorted output sequence of key data + const ValueT *d_values_in, ///< [in] %Device-accessible pointer to the corresponding input sequence of associated value items + ValueT *d_values_out, ///< [out] %Device-accessible pointer to the correspondingly-reordered output sequence of associated value items + int num_items, ///< [in] The total number of items to sort (across all segments) + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); + DoubleBuffer d_values(const_cast(d_values_in), d_values_out); + + return DispatchSegmentedRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + num_segments, + d_begin_offsets, + d_end_offsets, + begin_bit, + end_bit, + false, + stream, + debug_synchronous); + } + + + /** + * \brief Sorts segments of key-value pairs into ascending order. (~N auxiliary storage required) + * + * \par + * - The sorting operation is given a pair of key buffers and a corresponding + * pair of associated value buffers. Each pair is managed by a DoubleBuffer + * structure that indicates which of the two buffers is "current" (and thus + * contains the input data to be sorted). + * - The contents of both buffers within each pair may be altered by the sorting + * operation. + * - Upon completion, the sorting operation will update the "current" indicator + * within each DoubleBuffer wrapper to reference which of the two buffers + * now contains the sorted output sequence (a function of the number of key bits + * specified and the targeted device architecture). + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. 
+ * - \devicestorageP + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of \p int keys + * with associated vector of \p int values. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] + * int *d_value_buf; // e.g., [0, 1, 2, 3, 4, 5, 6] + * int *d_value_alt_buf; // e.g., [-, -, -, -, -, -, -] + * ... + * + * // Create a set of DoubleBuffers to wrap pairs of device pointers + * cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + * cub::DoubleBuffer d_values(d_value_buf, d_value_alt_buf); + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // d_keys.Current() <-- [6, 7, 8, 0, 3, 5, 9] + * // d_values.Current() <-- [5, 4, 3, 1, 2, 0, 6] + * + * \endcode + * + * \tparam KeyT [inferred] Key type + * \tparam ValueT [inferred] Value type + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + */ + template < + typename KeyT, + typename ValueT, + typename OffsetIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t SortPairs( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + DoubleBuffer &d_keys, ///< [in,out] Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + DoubleBuffer &d_values, ///< [in,out] Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + int num_items, ///< [in] The total number of items to sort (across all segments) + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. 
+ int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + return DispatchSegmentedRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + num_segments, + d_begin_offsets, + d_end_offsets, + begin_bit, + end_bit, + true, + stream, + debug_synchronous); + } + + + /** + * \brief Sorts segments of key-value pairs into descending order. (~2N auxiliary storage required). + * + * \par + * - The contents of the input data are not altered by the sorting operation + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageNP For sorting using only O(P) temporary storage, see the sorting interface using DoubleBuffer wrappers below. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of \p int keys + * with associated vector of \p int values. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_keys_out; // e.g., [-, -, -, -, -, -, -] + * int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] + * int *d_values_out; // e.g., [-, -, -, -, -, -, -] + * ... 
+ * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, + * d_keys_in, d_keys_out, d_values_in, d_values_out, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, + * d_keys_in, d_keys_out, d_values_in, d_values_out, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // d_keys_out <-- [8, 7, 6, 9, 5, 3, 0] + * // d_values_out <-- [0, 2, 1, 6, 3, 4, 5] + * + * \endcode + * + * \tparam KeyT [inferred] Key type + * \tparam ValueT [inferred] Value type + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + */ + template < + typename KeyT, + typename ValueT, + typename OffsetIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t SortPairsDescending( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + const KeyT *d_keys_in, ///< [in] %Device-accessible pointer to the input data of key data to sort + KeyT *d_keys_out, ///< [out] %Device-accessible pointer to the sorted output sequence of key data + const ValueT *d_values_in, ///< [in] %Device-accessible pointer to the corresponding input sequence of associated value items + ValueT *d_values_out, ///< [out] %Device-accessible pointer to the correspondingly-reordered output sequence of associated value items + int num_items, ///< [in] The total number of items to sort (across all segments) + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
+ { + // Signed integer type for global offsets + typedef int OffsetT; + + DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); + DoubleBuffer d_values(const_cast(d_values_in), d_values_out); + + return DispatchSegmentedRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + num_segments, + d_begin_offsets, + d_end_offsets, + begin_bit, + end_bit, + false, + stream, + debug_synchronous); + } + + + /** + * \brief Sorts segments of key-value pairs into descending order. (~N auxiliary storage required). + * + * \par + * - The sorting operation is given a pair of key buffers and a corresponding + * pair of associated value buffers. Each pair is managed by a DoubleBuffer + * structure that indicates which of the two buffers is "current" (and thus + * contains the input data to be sorted). + * - The contents of both buffers within each pair may be altered by the sorting + * operation. + * - Upon completion, the sorting operation will update the "current" indicator + * within each DoubleBuffer wrapper to reference which of the two buffers + * now contains the sorted output sequence (a function of the number of key bits + * specified and the targeted device architecture). + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageP + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of \p int keys + * with associated vector of \p int values. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] + * int *d_value_buf; // e.g., [0, 1, 2, 3, 4, 5, 6] + * int *d_value_alt_buf; // e.g., [-, -, -, -, -, -, -] + * ... 
+ * + * // Create a set of DoubleBuffers to wrap pairs of device pointers + * cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + * cub::DoubleBuffer d_values(d_value_buf, d_value_alt_buf); + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, d_keys, d_values, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, d_keys, d_values, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // d_keys.Current() <-- [8, 7, 6, 9, 5, 3, 0] + * // d_values.Current() <-- [0, 2, 1, 6, 3, 4, 5] + * + * \endcode + * + * \tparam KeyT [inferred] Key type + * \tparam ValueT [inferred] Value type + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + */ + template < + typename KeyT, + typename ValueT, + typename OffsetIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t SortPairsDescending( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + DoubleBuffer &d_keys, ///< [in,out] Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + DoubleBuffer &d_values, ///< [in,out] Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + int num_items, ///< [in] The total number of items to sort (across all segments) + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
+ { + // Signed integer type for global offsets + typedef int OffsetT; + + return DispatchSegmentedRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + num_segments, + d_begin_offsets, + d_end_offsets, + begin_bit, + end_bit, + true, + stream, + debug_synchronous); + } + + + //@} end member group + /******************************************************************//** + * \name Keys-only + *********************************************************************/ + //@{ + + + /** + * \brief Sorts segments of keys into ascending order. (~2N auxiliary storage required) + * + * \par + * - The contents of the input data are not altered by the sorting operation + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - \devicestorageNP For sorting using only O(P) temporary storage, see the sorting interface using DoubleBuffer wrappers below. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of \p int keys. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_keys_out; // e.g., [-, -, -, -, -, -, -] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceSegmentedRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // d_keys_out <-- [6, 7, 8, 0, 3, 5, 9] + * + * \endcode + * + * \tparam KeyT [inferred] Key type + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + */ + template < + typename KeyT, + typename OffsetIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t SortKeys( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + const KeyT *d_keys_in, ///< [in] %Device-accessible pointer to the input data of key data to sort + KeyT *d_keys_out, ///< [out] %Device-accessible pointer to the sorted output sequence of key data + int num_items, ///< [in] The total number of items to sort (across all segments) + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // Null value type + DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); + DoubleBuffer d_values; + + return DispatchSegmentedRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + num_segments, + d_begin_offsets, + d_end_offsets, + begin_bit, + end_bit, + false, + stream, + debug_synchronous); + } + + + /** + * \brief Sorts segments of keys into ascending order. (~N auxiliary storage required). + * + * \par + * - The sorting operation is given a pair of key buffers managed by a + * DoubleBuffer structure that indicates which of the two buffers is + * "current" (and thus contains the input data to be sorted). + * - The contents of both buffers may be altered by the sorting operation. + * - Upon completion, the sorting operation will update the "current" indicator + * within the DoubleBuffer wrapper to reference which of the two buffers + * now contains the sorted output sequence (a function of the number of key bits + * specified and the targeted device architecture). + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageP + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of \p int keys. 
+ * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] + * ... + * + * // Create a DoubleBuffer to wrap the pair of device pointers + * cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceSegmentedRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // d_keys.Current() <-- [6, 7, 8, 0, 3, 5, 9] + * + * \endcode + * + * \tparam KeyT [inferred] Key type + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + */ + template < + typename KeyT, + typename OffsetIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t SortKeys( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + DoubleBuffer &d_keys, ///< [in,out] Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + int num_items, ///< [in] The total number of items to sort (across all segments) + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
+ { + // Signed integer type for global offsets + typedef int OffsetT; + + // Null value type + DoubleBuffer d_values; + + return DispatchSegmentedRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + num_segments, + d_begin_offsets, + d_end_offsets, + begin_bit, + end_bit, + true, + stream, + debug_synchronous); + } + + /** + * \brief Sorts segments of keys into descending order. (~2N auxiliary storage required). + * + * \par + * - The contents of the input data are not altered by the sorting operation + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageNP For sorting using only O(P) temporary storage, see the sorting interface using DoubleBuffer wrappers below. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of \p int keys. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_keys_out; // e.g., [-, -, -, -, -, -, -] + * ... + * + * // Create a DoubleBuffer to wrap the pair of device pointers + * cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedRadixSort::SortKeysDescending(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceSegmentedRadixSort::SortKeysDescending(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // d_keys_out <-- [8, 7, 6, 9, 5, 3, 0] + * + * \endcode + * + * \tparam KeyT [inferred] Key type + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + */ + template < + typename KeyT, + typename OffsetIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t SortKeysDescending( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + const KeyT *d_keys_in, ///< [in] %Device-accessible pointer to the input data of key data to sort + KeyT *d_keys_out, ///< [out] %Device-accessible pointer to the sorted output sequence of key data + int num_items, ///< [in] The total number of items to sort (across all segments) + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); + DoubleBuffer d_values; + + return DispatchSegmentedRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + num_segments, + d_begin_offsets, + d_end_offsets, + begin_bit, + end_bit, + false, + stream, + debug_synchronous); + } + + + /** + * \brief Sorts segments of keys into descending order. (~N auxiliary storage required). + * + * \par + * - The sorting operation is given a pair of key buffers managed by a + * DoubleBuffer structure that indicates which of the two buffers is + * "current" (and thus contains the input data to be sorted). + * - The contents of both buffers may be altered by the sorting operation. + * - Upon completion, the sorting operation will update the "current" indicator + * within the DoubleBuffer wrapper to reference which of the two buffers + * now contains the sorted output sequence (a function of the number of key bits + * specified and the targeted device architecture). + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - An optional bit subrange [begin_bit, end_bit) of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + * - \devicestorageP + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of \p int keys. 
+ * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] + * ... + * + * // Create a DoubleBuffer to wrap the pair of device pointers + * cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedRadixSort::SortKeysDescending(d_temp_storage, temp_storage_bytes, d_keys, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceSegmentedRadixSort::SortKeysDescending(d_temp_storage, temp_storage_bytes, d_keys, + * num_items, num_segments, d_offsets, d_offsets + 1); + * + * // d_keys.Current() <-- [8, 7, 6, 9, 5, 3, 0] + * + * \endcode + * + * \tparam KeyT [inferred] Key type + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + */ + template < + typename KeyT, + typename OffsetIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t SortKeysDescending( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + DoubleBuffer &d_keys, ///< [in,out] Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + int num_items, ///< [in] The total number of items to sort (across all segments) + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison + int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
+ { + // Signed integer type for global offsets + typedef int OffsetT; + + // Null value type + DoubleBuffer d_values; + + return DispatchSegmentedRadixSort::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + num_items, + num_segments, + d_begin_offsets, + d_end_offsets, + begin_bit, + end_bit, + true, + stream, + debug_synchronous); + } + + + //@} end member group + + +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/device_segmented_reduce.cuh b/fastertransformer/cuda/cub/device/device_segmented_reduce.cuh new file mode 100644 index 000000000..6c3b54a03 --- /dev/null +++ b/fastertransformer/cuda/cub/device/device_segmented_reduce.cuh @@ -0,0 +1,619 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceSegmentedReduce provides device-wide, parallel operations for computing a batched reduction across multiple sequences of data items residing within device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "../iterator/arg_index_input_iterator.cuh" +#include "dispatch/dispatch_reduce.cuh" +#include "dispatch/dispatch_reduce_by_key.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief DeviceSegmentedReduce provides device-wide, parallel operations for computing a reduction across multiple sequences of data items residing within device-accessible memory. ![](reduce_logo.png) + * \ingroup SegmentedModule + * + * \par Overview + * A reduction (or fold) + * uses a binary combining operator to compute a single aggregate from a sequence of input elements. 
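+ *
+ * \par
+ * The members below all follow the same two-phase temporary-storage idiom. A minimal
+ * host-side sketch of that idiom (a hypothetical helper, assuming <cub/cub.cuh> is on
+ * the include path and the device arrays are already populated) looks like:
+ * \code
+ * // Sum each of num_segments segments of d_in, writing one aggregate per segment to d_out
+ * cudaError_t SegmentedSumExample(const int *d_in, int *d_out, const int *d_offsets, int num_segments)
+ * {
+ *     // First call with a NULL allocation only computes the required scratch size
+ *     void   *d_temp_storage = NULL;
+ *     size_t temp_storage_bytes = 0;
+ *     cub::DeviceSegmentedReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out,
+ *         num_segments, d_offsets, d_offsets + 1);
+ *
+ *     // Second call with the allocated scratch space performs the reduction
+ *     cudaMalloc(&d_temp_storage, temp_storage_bytes);
+ *     cudaError_t error = cub::DeviceSegmentedReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out,
+ *         num_segments, d_offsets, d_offsets + 1);
+ *     cudaFree(d_temp_storage);
+ *     return error;
+ * }
+ * \endcode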
+ * + * \par Usage Considerations + * \cdp_class{DeviceSegmentedReduce} + * + */ +struct DeviceSegmentedReduce +{ + /** + * \brief Computes a device-wide segmented reduction using the specified binary \p reduction_op functor. + * + * \par + * - Does not support binary reduction operators that are non-commutative. + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates a custom min-reduction of a device vector of \p int data elements. + * \par + * \code + * #include // or equivalently + * + * // CustomMin functor + * struct CustomMin + * { + * template + * CUB_RUNTIME_FUNCTION __forceinline__ + * T operator()(const T &a, const T &b) const { + * return (b < a) ? b : a; + * } + * }; + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_out; // e.g., [-, -, -] + * CustomMin min_op; + * int initial_value; // e.g., INT_MAX + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedReduce::Reduce(d_temp_storage, temp_storage_bytes, d_in, d_out, + * num_segments, d_offsets, d_offsets + 1, min_op, initial_value); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run reduction + * cub::DeviceSegmentedReduce::Reduce(d_temp_storage, temp_storage_bytes, d_in, d_out, + * num_segments, d_offsets, d_offsets + 1, min_op, initial_value); + * + * // d_out <-- [6, INT_MAX, 0] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam OutputIteratorT [inferred] Output iterator type for recording the reduced aggregate \iterator + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + * \tparam ReductionOp [inferred] Binary reduction functor type having member T operator()(const T &a, const T &b) + * \tparam T [inferred] Data element type that is convertible to the \p value type of \p InputIteratorT + */ + template < + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT, + typename ReductionOp, + typename T> + CUB_RUNTIME_FUNCTION + static cudaError_t Reduce( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + ReductionOp reduction_op, ///< [in] Binary reduction functor + T initial_value, ///< [in] Initial value of the reduction for each segment + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + return DispatchSegmentedReduce::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + num_segments, + d_begin_offsets, + d_end_offsets, + reduction_op, + initial_value, + stream, + debug_synchronous); + } + + + /** + * \brief Computes a device-wide segmented sum using the addition ('+') operator. + * + * \par + * - Uses \p 0 as the initial value of the reduction for each segment. + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - Does not support \p + operators that are non-commutative.. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the sum reduction of a device vector of \p int data elements. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_out; // e.g., [-, -, -] + * ... 
+     *
+     * // Determine temporary device storage requirements
+     * void     *d_temp_storage = NULL;
+     * size_t   temp_storage_bytes = 0;
+     * cub::DeviceSegmentedReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out,
+     *     num_segments, d_offsets, d_offsets + 1);
+     *
+     * // Allocate temporary storage
+     * cudaMalloc(&d_temp_storage, temp_storage_bytes);
+     *
+     * // Run sum-reduction
+     * cub::DeviceSegmentedReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out,
+     *     num_segments, d_offsets, d_offsets + 1);
+     *
+     * // d_out <-- [21, 0, 17]
+     *
+     * \endcode
+     *
+     * \tparam InputIteratorT     [inferred] Random-access input iterator type for reading input items \iterator
+     * \tparam OutputIteratorT    [inferred] Output iterator type for recording the reduced aggregate \iterator
+     * \tparam OffsetIteratorT    [inferred] Random-access input iterator type for reading segment offsets \iterator
+     */
+    template <
+        typename            InputIteratorT,
+        typename            OutputIteratorT,
+        typename            OffsetIteratorT>
+    CUB_RUNTIME_FUNCTION
+    static cudaError_t Sum(
+        void                *d_temp_storage,            ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
+        size_t              &temp_storage_bytes,        ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
+        InputIteratorT      d_in,                       ///< [in] Pointer to the input sequence of data items
+        OutputIteratorT     d_out,                      ///< [out] Pointer to the output aggregate
+        int                 num_segments,               ///< [in] The number of segments that comprise the sorting data
+        OffsetIteratorT     d_begin_offsets,            ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_*
+        OffsetIteratorT     d_end_offsets,              ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty.
+        cudaStream_t        stream = 0,                 ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0.
+        bool                debug_synchronous = false)  ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
+    {
+        // Signed integer type for global offsets
+        typedef int OffsetT;
+
+        // The output value type
+        typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE),     // OutputT = (if output iterator's value type is void) ?
+            typename std::iterator_traits<InputIteratorT>::value_type,                                             // ... then the input iterator's value type,
+            typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT;                             // ... else the output iterator's value type
+
+        return DispatchSegmentedReduce<InputIteratorT, OutputIteratorT, OffsetIteratorT, OffsetT, cub::Sum>::Dispatch(
+            d_temp_storage,
+            temp_storage_bytes,
+            d_in,
+            d_out,
+            num_segments,
+            d_begin_offsets,
+            d_end_offsets,
+            cub::Sum(),
+            OutputT(),            // zero-initialize
+            stream,
+            debug_synchronous);
+    }
+
+
+    /**
+     * \brief Computes a device-wide segmented minimum using the less-than ('<') operator.
+     *
+     * \par
+     * - Uses std::numeric_limits<T>::max() as the initial value of the reduction for each segment.
+     * - When input a contiguous sequence of segments, a single sequence
+     *   \p segment_offsets (of length num_segments+1) can be aliased
+     *   for both the \p d_begin_offsets and \p d_end_offsets parameters (where
+     *   the latter is specified as segment_offsets+1).
+ * - Does not support \p < operators that are non-commutative. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the min-reduction of a device vector of \p int data elements. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_out; // e.g., [-, -, -] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedReduce::Min(d_temp_storage, temp_storage_bytes, d_in, d_out, + * num_segments, d_offsets, d_offsets + 1); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run min-reduction + * cub::DeviceSegmentedReduce::Min(d_temp_storage, temp_storage_bytes, d_in, d_out, + * num_segments, d_offsets, d_offsets + 1); + * + * // d_out <-- [6, INT_MAX, 0] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam OutputIteratorT [inferred] Output iterator type for recording the reduced aggregate \iterator + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + */ + template < + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t Min( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // The input value type + typedef typename std::iterator_traits::value_type InputT; + + return DispatchSegmentedReduce::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + num_segments, + d_begin_offsets, + d_end_offsets, + cub::Min(), + Traits::Max(), // replace with std::numeric_limits::max() when C++11 support is more prevalent + stream, + debug_synchronous); + } + + + /** + * \brief Finds the first device-wide minimum in each segment using the less-than ('<') operator, also returning the in-segment index of that item. 
+ * + * \par + * - The output value type of \p d_out is cub::KeyValuePair (assuming the value type of \p d_in is \p T) + * - The minimum of the ith segment is written to d_out[i].value and its offset in that segment is written to d_out[i].key. + * - The {1, std::numeric_limits::max()} tuple is produced for zero-length inputs + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - Does not support \p < operators that are non-commutative. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the argmin-reduction of a device vector of \p int data elements. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * KeyValuePair *d_out; // e.g., [{-,-}, {-,-}, {-,-}] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_out, + * num_segments, d_offsets, d_offsets + 1); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run argmin-reduction + * cub::DeviceSegmentedReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_out, + * num_segments, d_offsets, d_offsets + 1); + * + * // d_out <-- [{1,6}, {1,INT_MAX}, {2,0}] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items (of some type \p T) \iterator + * \tparam OutputIteratorT [inferred] Output iterator type for recording the reduced aggregate (having value type KeyValuePair) \iterator + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + */ + template < + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t ArgMin( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. 
Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // The input type + typedef typename std::iterator_traits::value_type InputValueT; + + // The output tuple type + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + KeyValuePair, // ... then the key value pair OffsetT + InputValueT + typename std::iterator_traits::value_type>::Type OutputTupleT; // ... else the output iterator's value type + + // The output value type + typedef typename OutputTupleT::Value OutputValueT; + + // Wrapped input iterator to produce index-value tuples + typedef ArgIndexInputIterator ArgIndexInputIteratorT; + ArgIndexInputIteratorT d_indexed_in(d_in); + + // Initial value + OutputTupleT initial_value(1, Traits::Max()); // replace with std::numeric_limits::max() when C++11 support is more prevalent + + return DispatchSegmentedReduce::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_indexed_in, + d_out, + num_segments, + d_begin_offsets, + d_end_offsets, + cub::ArgMin(), + initial_value, + stream, + debug_synchronous); + } + + + /** + * \brief Computes a device-wide segmented maximum using the greater-than ('>') operator. + * + * \par + * - Uses std::numeric_limits::lowest() as the initial value of the reduction. + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - Does not support \p > operators that are non-commutative. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the max-reduction of a device vector of \p int data elements. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * int *d_out; // e.g., [-, -, -] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, + * num_segments, d_offsets, d_offsets + 1); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run max-reduction + * cub::DeviceSegmentedReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, + * num_segments, d_offsets, d_offsets + 1); + * + * // d_out <-- [8, INT_MIN, 9] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam OutputIteratorT [inferred] Output iterator type for recording the reduced aggregate \iterator + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + */ + template < + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t Max( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // The input value type + typedef typename std::iterator_traits::value_type InputT; + + return DispatchSegmentedReduce::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + num_segments, + d_begin_offsets, + d_end_offsets, + cub::Max(), + Traits::Lowest(), // replace with std::numeric_limits::lowest() when C++11 support is more prevalent + stream, + debug_synchronous); + } + + + /** + * \brief Finds the first device-wide maximum in each segment using the greater-than ('>') operator, also returning the in-segment index of that item + * + * \par + * - The output value type of \p d_out is cub::KeyValuePair (assuming the value type of \p d_in is \p T) + * - The maximum of the ith segment is written to d_out[i].value and its offset in that segment is written to d_out[i].key. + * - The {1, std::numeric_limits::lowest()} tuple is produced for zero-length inputs + * - When input a contiguous sequence of segments, a single sequence + * \p segment_offsets (of length num_segments+1) can be aliased + * for both the \p d_begin_offsets and \p d_end_offsets parameters (where + * the latter is specified as segment_offsets+1). + * - Does not support \p > operators that are non-commutative. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the argmax-reduction of a device vector of \p int data elements. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_segments; // e.g., 3 + * int *d_offsets; // e.g., [0, 3, 3, 7] + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * KeyValuePair *d_out; // e.g., [{-,-}, {-,-}, {-,-}] + * ... 
+ * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSegmentedReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_out, + * num_segments, d_offsets, d_offsets + 1); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run argmax-reduction + * cub::DeviceSegmentedReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_out, + * num_segments, d_offsets, d_offsets + 1); + * + * // d_out <-- [{0,8}, {1,INT_MIN}, {3,9}] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items (of some type \p T) \iterator + * \tparam OutputIteratorT [inferred] Output iterator type for recording the reduced aggregate (having value type KeyValuePair) \iterator + * \tparam OffsetIteratorT [inferred] Random-access input iterator type for reading segment offsets \iterator + */ + template < + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT> + CUB_RUNTIME_FUNCTION + static cudaError_t ArgMax( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + // Signed integer type for global offsets + typedef int OffsetT; + + // The input type + typedef typename std::iterator_traits::value_type InputValueT; + + // The output tuple type + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + KeyValuePair, // ... then the key value pair OffsetT + InputValueT + typename std::iterator_traits::value_type>::Type OutputTupleT; // ... 
else the output iterator's value type + + // The output value type + typedef typename OutputTupleT::Value OutputValueT; + + // Wrapped input iterator to produce index-value tuples + typedef ArgIndexInputIterator ArgIndexInputIteratorT; + ArgIndexInputIteratorT d_indexed_in(d_in); + + // Initial value + OutputTupleT initial_value(1, Traits::Lowest()); // replace with std::numeric_limits::lowest() when C++11 support is more prevalent + + return DispatchSegmentedReduce::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_indexed_in, + d_out, + num_segments, + d_begin_offsets, + d_end_offsets, + cub::ArgMax(), + initial_value, + stream, + debug_synchronous); + } + +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/device_select.cuh b/fastertransformer/cuda/cub/device/device_select.cuh new file mode 100644 index 000000000..52a3e126d --- /dev/null +++ b/fastertransformer/cuda/cub/device/device_select.cuh @@ -0,0 +1,369 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceSelect provides device-wide, parallel operations for compacting selected items from sequences of data items residing within device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "dispatch/dispatch_select_if.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief DeviceSelect provides device-wide, parallel operations for compacting selected items from sequences of data items residing within device-accessible memory. 
![](select_logo.png) + * \ingroup SingleModule + * + * \par Overview + * These operations apply a selection criterion to selectively copy + * items from a specified input sequence to a compact output sequence. + * + * \par Usage Considerations + * \cdp_class{DeviceSelect} + * + * \par Performance + * \linear_performance{select-flagged, select-if, and select-unique} + * + * \par + * The following chart illustrates DeviceSelect::If + * performance across different CUDA architectures for \p int32 items, + * where 50% of the items are randomly selected. + * + * \image html select_if_int32_50_percent.png + * + * \par + * The following chart illustrates DeviceSelect::Unique + * performance across different CUDA architectures for \p int32 items + * where segments have lengths uniformly sampled from [1,1000]. + * + * \image html select_unique_int32_len_500.png + * + * \par + * \plots_below + * + */ +struct DeviceSelect +{ + /** + * \brief Uses the \p d_flags sequence to selectively copy the corresponding items from \p d_in into \p d_out. The total number of items selected is written to \p d_num_selected_out. ![](select_flags_logo.png) + * + * \par + * - The value type of \p d_flags must be castable to \p bool (e.g., \p bool, \p char, \p int, etc.). + * - Copies of the selected items are compacted into \p d_out and maintain their original relative ordering. + * - \devicestorage + * + * \par Snippet + * The code snippet below illustrates the compaction of items selected from an \p int device vector. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input, flags, and output + * int num_items; // e.g., 8 + * int *d_in; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] + * char *d_flags; // e.g., [1, 0, 0, 1, 0, 1, 1, 0] + * int *d_out; // e.g., [ , , , , , , , ] + * int *d_num_selected_out; // e.g., [ ] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, d_in, d_flags, d_out, d_num_selected_out, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run selection + * cub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, d_in, d_flags, d_out, d_num_selected_out, num_items); + * + * // d_out <-- [1, 4, 6, 7] + * // d_num_selected_out <-- [4] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam FlagIterator [inferred] Random-access input iterator type for reading selection flags \iterator + * \tparam OutputIteratorT [inferred] Random-access output iterator type for writing selected items \iterator + * \tparam NumSelectedIteratorT [inferred] Output iterator type for recording the number of items selected \iterator + */ + template < + typename InputIteratorT, + typename FlagIterator, + typename OutputIteratorT, + typename NumSelectedIteratorT> + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Flagged( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + FlagIterator d_flags, ///< [in] Pointer to the input sequence of selection flags + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of selected data items + NumSelectedIteratorT d_num_selected_out, ///< [out] Pointer to the output total number of items selected (i.e., length of \p d_out) + int num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + typedef int OffsetT; // Signed integer type for global offsets + typedef NullType SelectOp; // Selection op (not used) + typedef NullType EqualityOp; // Equality operator (not used) + + return DispatchSelectIf::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_flags, + d_out, + d_num_selected_out, + SelectOp(), + EqualityOp(), + num_items, + stream, + debug_synchronous); + } + + + /** + * \brief Uses the \p select_op functor to selectively copy items from \p d_in into \p d_out. The total number of items selected is written to \p d_num_selected_out. ![](select_logo.png) + * + * \par + * - Copies of the selected items are compacted into \p d_out and maintain their original relative ordering. + * - \devicestorage + * + * \par Performance + * The following charts illustrate saturated select-if performance across different + * CUDA architectures for \p int32 and \p int64 items, respectively. Items are + * selected with 50% probability. + * + * \image html select_if_int32_50_percent.png + * \image html select_if_int64_50_percent.png + * + * \par + * The following charts are similar, but 5% selection probability: + * + * \image html select_if_int32_5_percent.png + * \image html select_if_int64_5_percent.png + * + * \par Snippet + * The code snippet below illustrates the compaction of items selected from an \p int device vector. + * \par + * \code + * #include // or equivalently + * + * // Functor type for selecting values less than some criteria + * struct LessThan + * { + * int compare; + * + * CUB_RUNTIME_FUNCTION __forceinline__ + * LessThan(int compare) : compare(compare) {} + * + * CUB_RUNTIME_FUNCTION __forceinline__ + * bool operator()(const int &a) const { + * return (a < compare); + * } + * }; + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 8 + * int *d_in; // e.g., [0, 2, 3, 9, 5, 2, 81, 8] + * int *d_out; // e.g., [ , , , , , , , ] + * int *d_num_selected_out; // e.g., [ ] + * LessThan select_op(7); + * ... 
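+ * // (with compare == 7 above, select_op(2) returns true and select_op(9) returns false;
+ * //  any functor exposing bool operator()(const T&) const could be supplied instead)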
+ * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run selection + * cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op); + * + * // d_out <-- [0, 2, 3, 5, 2] + * // d_num_selected_out <-- [5] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam OutputIteratorT [inferred] Random-access output iterator type for writing selected items \iterator + * \tparam NumSelectedIteratorT [inferred] Output iterator type for recording the number of items selected \iterator + * \tparam SelectOp [inferred] Selection operator type having member bool operator()(const T &a) + */ + template < + typename InputIteratorT, + typename OutputIteratorT, + typename NumSelectedIteratorT, + typename SelectOp> + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t If( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of selected data items + NumSelectedIteratorT d_num_selected_out, ///< [out] Pointer to the output total number of items selected (i.e., length of \p d_out) + int num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + SelectOp select_op, ///< [in] Unary selection operator + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + typedef int OffsetT; // Signed integer type for global offsets + typedef NullType* FlagIterator; // FlagT iterator type (not used) + typedef NullType EqualityOp; // Equality operator (not used) + + return DispatchSelectIf::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + NULL, + d_out, + d_num_selected_out, + select_op, + EqualityOp(), + num_items, + stream, + debug_synchronous); + } + + + /** + * \brief Given an input sequence \p d_in having runs of consecutive equal-valued keys, only the first key from each run is selectively copied to \p d_out. The total number of items selected is written to \p d_num_selected_out. ![](unique_logo.png) + * + * \par + * - The == equality operator is used to determine whether keys are equivalent + * - Copies of the selected items are compacted into \p d_out and maintain their original relative ordering. + * - \devicestorage + * + * \par Performance + * The following charts illustrate saturated select-unique performance across different + * CUDA architectures for \p int32 and \p int64 items, respectively. Segments have + * lengths uniformly sampled from [1,1000]. 
+ * + * \image html select_unique_int32_len_500.png + * \image html select_unique_int64_len_500.png + * + * \par + * The following charts are similar, but with segment lengths uniformly sampled from [1,10]: + * + * \image html select_unique_int32_len_5.png + * \image html select_unique_int64_len_5.png + * + * \par Snippet + * The code snippet below illustrates the compaction of items selected from an \p int device vector. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input and output + * int num_items; // e.g., 8 + * int *d_in; // e.g., [0, 2, 2, 9, 5, 5, 5, 8] + * int *d_out; // e.g., [ , , , , , , , ] + * int *d_num_selected_out; // e.g., [ ] + * ... + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSelect::Unique(d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run selection + * cub::DeviceSelect::Unique(d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items); + * + * // d_out <-- [0, 2, 9, 5, 8] + * // d_num_selected_out <-- [5] + * + * \endcode + * + * \tparam InputIteratorT [inferred] Random-access input iterator type for reading input items \iterator + * \tparam OutputIteratorT [inferred] Random-access output iterator type for writing selected items \iterator + * \tparam NumSelectedIteratorT [inferred] Output iterator type for recording the number of items selected \iterator + */ + template < + typename InputIteratorT, + typename OutputIteratorT, + typename NumSelectedIteratorT> + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Unique( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of selected data items + NumSelectedIteratorT d_num_selected_out, ///< [out] Pointer to the output total number of items selected (i.e., length of \p d_out) + int num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. 
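+        // Implementation note: Unique() forwards to DispatchSelectIf with a NULL flag
+        // iterator, no selection functor, and cub::Equality as the run-boundary test,
+        // so only the first item of each run of consecutive equal keys reaches d_out.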
+ { + typedef int OffsetT; // Signed integer type for global offsets + typedef NullType* FlagIterator; // FlagT iterator type (not used) + typedef NullType SelectOp; // Selection op (not used) + typedef Equality EqualityOp; // Default == operator + + return DispatchSelectIf::Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + NULL, + d_out, + d_num_selected_out, + SelectOp(), + EqualityOp(), + num_items, + stream, + debug_synchronous); + } + +}; + +/** + * \example example_device_select_flagged.cu + * \example example_device_select_if.cu + * \example example_device_select_unique.cu + */ + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/device_spmv.cuh b/fastertransformer/cuda/cub/device/device_spmv.cuh new file mode 100644 index 000000000..63b6a7e86 --- /dev/null +++ b/fastertransformer/cuda/cub/device/device_spmv.cuh @@ -0,0 +1,174 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceSpmv provides device-wide parallel operations for performing sparse-matrix * vector multiplication (SpMV). + */ + +#pragma once + +#include +#include +#include + +#include "dispatch/dispatch_spmv_orig.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief DeviceSpmv provides device-wide parallel operations for performing sparse-matrix * dense-vector multiplication (SpMV). 
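+ * \par
+ * Note that the CsrMV entry point below fixes alpha = 1 and beta = 0, i.e., it computes
+ * y = A*x; the general y = alpha*A*x + beta*y form described in the Overview is not
+ * exposed through this interface.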
+ * \ingroup SingleModule + * + * \par Overview + * The [SpMV computation](http://en.wikipedia.org/wiki/Sparse_matrix-vector_multiplication) + * performs the matrix-vector operation + * y = alpha*A*x + beta*y, + * where: + * - A is an mxn sparse matrix whose non-zero structure is specified in + * [compressed-storage-row (CSR) format](http://en.wikipedia.org/wiki/Sparse_matrix#Compressed_row_Storage_.28CRS_or_CSR.29) + * (i.e., three arrays: values, row_offsets, and column_indices) + * - x and y are dense vectors + * - alpha and beta are scalar multiplicands + * + * \par Usage Considerations + * \cdp_class{DeviceSpmv} + * + */ +struct DeviceSpmv +{ + /******************************************************************//** + * \name CSR matrix operations + *********************************************************************/ + //@{ + + /** + * \brief This function performs the matrix-vector operation y = A*x. + * + * \par Snippet + * The code snippet below illustrates SpMV upon a 9x9 CSR matrix A + * representing a 3x3 lattice (24 non-zeros). + * + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for input matrix A, input vector x, + * // and output vector y + * int num_rows = 9; + * int num_cols = 9; + * int num_nonzeros = 24; + * + * float* d_values; // e.g., [1, 1, 1, 1, 1, 1, 1, 1, + * // 1, 1, 1, 1, 1, 1, 1, 1, + * // 1, 1, 1, 1, 1, 1, 1, 1] + * + * int* d_column_indices; // e.g., [1, 3, 0, 2, 4, 1, 5, 0, + * // 4, 6, 1, 3, 5, 7, 2, 4, + * // 8, 3, 7, 4, 6, 8, 5, 7] + * + * int* d_row_offsets; // e.g., [0, 2, 5, 7, 10, 14, 17, 19, 22, 24] + * + * float* d_vector_x; // e.g., [1, 1, 1, 1, 1, 1, 1, 1, 1] + * float* d_vector_y; // e.g., [ , , , , , , , , ] + * ... + * + * // Determine temporary device storage requirements + * void* d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceSpmv::CsrMV(d_temp_storage, temp_storage_bytes, d_values, + * d_row_offsets, d_column_indices, d_vector_x, d_vector_y, + * num_rows, num_cols, num_nonzeros, alpha, beta); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run SpMV + * cub::DeviceSpmv::CsrMV(d_temp_storage, temp_storage_bytes, d_values, + * d_row_offsets, d_column_indices, d_vector_x, d_vector_y, + * num_rows, num_cols, num_nonzeros, alpha, beta); + * + * // d_vector_y <-- [2, 3, 2, 3, 4, 3, 2, 3, 2] + * + * \endcode + * + * \tparam ValueT [inferred] Matrix and vector value type (e.g., /p float, /p double, etc.) + */ + template < + typename ValueT> + CUB_RUNTIME_FUNCTION + static cudaError_t CsrMV( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + ValueT* d_values, ///< [in] Pointer to the array of \p num_nonzeros values of the corresponding nonzero elements of matrix A. + int* d_row_offsets, ///< [in] Pointer to the array of \p m + 1 offsets demarcating the start of every row in \p d_column_indices and \p d_values (with the final entry being equal to \p num_nonzeros) + int* d_column_indices, ///< [in] Pointer to the array of \p num_nonzeros column-indices of the corresponding nonzero elements of matrix A. (Indices are zero-valued.) 
+ ValueT* d_vector_x, ///< [in] Pointer to the array of \p num_cols values corresponding to the dense input vector x + ValueT* d_vector_y, ///< [out] Pointer to the array of \p num_rows values corresponding to the dense output vector y + int num_rows, ///< [in] number of rows of matrix A. + int num_cols, ///< [in] number of columns of matrix A. + int num_nonzeros, ///< [in] number of nonzero elements of matrix A. + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + SpmvParams spmv_params; + spmv_params.d_values = d_values; + spmv_params.d_row_end_offsets = d_row_offsets + 1; + spmv_params.d_column_indices = d_column_indices; + spmv_params.d_vector_x = d_vector_x; + spmv_params.d_vector_y = d_vector_y; + spmv_params.num_rows = num_rows; + spmv_params.num_cols = num_cols; + spmv_params.num_nonzeros = num_nonzeros; + spmv_params.alpha = 1.0; + spmv_params.beta = 0.0; + + return DispatchSpmv::Dispatch( + d_temp_storage, + temp_storage_bytes, + spmv_params, + stream, + debug_synchronous); + } + + //@} end member group +}; + + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/dispatch/dispatch_histogram.cuh b/fastertransformer/cuda/cub/device/dispatch/dispatch_histogram.cuh new file mode 100644 index 000000000..ab08e8ed0 --- /dev/null +++ b/fastertransformer/cuda/cub/device/dispatch/dispatch_histogram.cuh @@ -0,0 +1,1096 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * \file + * cub::DeviceHistogram provides device-wide parallel operations for constructing histogram(s) from a sequence of samples data residing within device-accessible memory. + */ + +#pragma once + +#include +#include +#include + +#include "../../agent/agent_histogram.cuh" +#include "../../util_debug.cuh" +#include "../../util_device.cuh" +#include "../../thread/thread_search.cuh" +#include "../../grid/grid_queue.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + + +/****************************************************************************** + * Histogram kernel entry points + *****************************************************************************/ + +/** + * Histogram initialization kernel entry point + */ +template < + int NUM_ACTIVE_CHANNELS, ///< Number of channels actively being histogrammed + typename CounterT, ///< Integer type for counting sample occurrences per histogram bin + typename OffsetT> ///< Signed integer type for global offsets +__global__ void DeviceHistogramInitKernel( + ArrayWrapper num_output_bins_wrapper, ///< Number of output histogram bins per channel + ArrayWrapper d_output_histograms_wrapper, ///< Histogram counter data having logical dimensions CounterT[NUM_ACTIVE_CHANNELS][num_bins.array[CHANNEL]] + GridQueue tile_queue) ///< Drain queue descriptor for dynamically mapping tile data onto thread blocks +{ + if ((threadIdx.x == 0) && (blockIdx.x == 0)) + tile_queue.ResetDrain(); + + int output_bin = (blockIdx.x * blockDim.x) + threadIdx.x; + + #pragma unroll + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + { + if (output_bin < num_output_bins_wrapper.array[CHANNEL]) + d_output_histograms_wrapper.array[CHANNEL][output_bin] = 0; + } +} + + +/** + * Histogram privatized sweep kernel entry point (multi-block). Computes privatized histograms, one per thread block. + */ +template < + typename AgentHistogramPolicyT, ///< Parameterized AgentHistogramPolicy tuning policy type + int PRIVATIZED_SMEM_BINS, ///< Maximum number of histogram bins per channel (e.g., up to 256) + int NUM_CHANNELS, ///< Number of channels interleaved in the input data (may be greater than the number of channels being actively histogrammed) + int NUM_ACTIVE_CHANNELS, ///< Number of channels actively being histogrammed + typename SampleIteratorT, ///< The input iterator type. \iterator. 
+ typename CounterT, ///< Integer type for counting sample occurrences per histogram bin + typename PrivatizedDecodeOpT, ///< The transform operator type for determining privatized counter indices from samples, one for each channel + typename OutputDecodeOpT, ///< The transform operator type for determining output bin-ids from privatized counter indices, one for each channel + typename OffsetT> ///< Signed integer type for global offsets +__launch_bounds__ (int(AgentHistogramPolicyT::BLOCK_THREADS)) +__global__ void DeviceHistogramSweepKernel( + SampleIteratorT d_samples, ///< Input data to reduce + ArrayWrapper num_output_bins_wrapper, ///< The number bins per final output histogram + ArrayWrapper num_privatized_bins_wrapper, ///< The number bins per privatized histogram + ArrayWrapper d_output_histograms_wrapper, ///< Reference to final output histograms + ArrayWrapper d_privatized_histograms_wrapper, ///< Reference to privatized histograms + ArrayWrapper output_decode_op_wrapper, ///< The transform operator for determining output bin-ids from privatized counter indices, one for each channel + ArrayWrapper privatized_decode_op_wrapper, ///< The transform operator for determining privatized counter indices from samples, one for each channel + OffsetT num_row_pixels, ///< The number of multi-channel pixels per row in the region of interest + OffsetT num_rows, ///< The number of rows in the region of interest + OffsetT row_stride_samples, ///< The number of samples between starts of consecutive rows in the region of interest + int tiles_per_row, ///< Number of image tiles per row + GridQueue tile_queue) ///< Drain queue descriptor for dynamically mapping tile data onto thread blocks +{ + // Thread block type for compositing input tiles + typedef AgentHistogram< + AgentHistogramPolicyT, + PRIVATIZED_SMEM_BINS, + NUM_CHANNELS, + NUM_ACTIVE_CHANNELS, + SampleIteratorT, + CounterT, + PrivatizedDecodeOpT, + OutputDecodeOpT, + OffsetT> + AgentHistogramT; + + // Shared memory for AgentHistogram + __shared__ typename AgentHistogramT::TempStorage temp_storage; + + AgentHistogramT agent( + temp_storage, + d_samples, + num_output_bins_wrapper.array, + num_privatized_bins_wrapper.array, + d_output_histograms_wrapper.array, + d_privatized_histograms_wrapper.array, + output_decode_op_wrapper.array, + privatized_decode_op_wrapper.array); + + // Initialize counters + agent.InitBinCounters(); + + // Consume input tiles + agent.ConsumeTiles( + num_row_pixels, + num_rows, + row_stride_samples, + tiles_per_row, + tile_queue); + + // Store output to global (if necessary) + agent.StoreOutput(); + +} + + + + + + +/****************************************************************************** + * Dispatch + ******************************************************************************/ + +/** + * Utility class for dispatching the appropriately-tuned kernels for DeviceHistogram + */ +template < + int NUM_CHANNELS, ///< Number of channels interleaved in the input data (may be greater than the number of channels being actively histogrammed) + int NUM_ACTIVE_CHANNELS, ///< Number of channels actively being histogrammed + typename SampleIteratorT, ///< Random-access input iterator type for reading input items \iterator + typename CounterT, ///< Integer type for counting sample occurrences per histogram bin + typename LevelT, ///< Type for specifying bin level boundaries + typename OffsetT> ///< Signed integer type for global offsets +struct DipatchHistogram +{ + 
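+    // Dispatch overview: a PTX-version-specific tuning policy (Policy110 through Policy500)
+    // is selected, samples are mapped to bin ids by SearchTransform (explicit level
+    // boundaries), ScaleTransform (evenly-spaced levels), or PassThruTransform (identity
+    // mapping, e.g. the 256-bin 8-bit fast path), and PrivatizedDispatch() then launches
+    // DeviceHistogramInitKernel followed by DeviceHistogramSweepKernel, which accumulates
+    // per-block privatized histograms before contributing them to the final output.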
//--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// The sample value type of the input iterator + typedef typename std::iterator_traits::value_type SampleT; + + enum + { + // Maximum number of bins per channel for which we will use a privatized smem strategy + MAX_PRIVATIZED_SMEM_BINS = 256 + }; + + + //--------------------------------------------------------------------- + // Transform functors for converting samples to bin-ids + //--------------------------------------------------------------------- + + // Searches for bin given a list of bin-boundary levels + template + struct SearchTransform + { + LevelIteratorT d_levels; // Pointer to levels array + int num_output_levels; // Number of levels in array + + // Initializer + __host__ __device__ __forceinline__ void Init( + LevelIteratorT d_levels, // Pointer to levels array + int num_output_levels) // Number of levels in array + { + this->d_levels = d_levels; + this->num_output_levels = num_output_levels; + } + + // Method for converting samples to bin-ids + template + __host__ __device__ __forceinline__ void BinSelect(_SampleT sample, int &bin, bool valid) + { + /// Level iterator wrapper type + typedef typename If::VALUE, + CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedInputIterator + LevelIteratorT>::Type // Directly use the supplied input iterator type + WrappedLevelIteratorT; + + WrappedLevelIteratorT wrapped_levels(d_levels); + + int num_bins = num_output_levels - 1; + if (valid) + { + bin = UpperBound(wrapped_levels, num_output_levels, (LevelT) sample) - 1; + if (bin >= num_bins) + bin = -1; + } + } + }; + + + // Scales samples to evenly-spaced bins + struct ScaleTransform + { + int num_bins; // Number of levels in array + LevelT max; // Max sample level (exclusive) + LevelT min; // Min sample level (inclusive) + LevelT scale; // Bin scaling factor + + // Initializer + template + __host__ __device__ __forceinline__ void Init( + int num_output_levels, // Number of levels in array + _LevelT max, // Max sample level (exclusive) + _LevelT min, // Min sample level (inclusive) + _LevelT scale) // Bin scaling factor + { + this->num_bins = num_output_levels - 1; + this->max = max; + this->min = min; + this->scale = scale; + } + + // Initializer (float specialization) + __host__ __device__ __forceinline__ void Init( + int num_output_levels, // Number of levels in array + float max, // Max sample level (exclusive) + float min, // Min sample level (inclusive) + float scale) // Bin scaling factor + { + this->num_bins = num_output_levels - 1; + this->max = max; + this->min = min; + this->scale = float(1.0) / scale; + } + + // Initializer (double specialization) + __host__ __device__ __forceinline__ void Init( + int num_output_levels, // Number of levels in array + double max, // Max sample level (exclusive) + double min, // Min sample level (inclusive) + double scale) // Bin scaling factor + { + this->num_bins = num_output_levels - 1; + this->max = max; + this->min = min; + this->scale = double(1.0) / scale; + } + + // Method for converting samples to bin-ids + template + __host__ __device__ __forceinline__ void BinSelect(_SampleT sample, int &bin, bool valid) + { + LevelT level_sample = (LevelT) sample; + + if (valid && (level_sample >= min) && (level_sample < max)) + bin = (int) ((level_sample - min) / scale); + } + + // Method for converting samples to bin-ids (float specialization) + 
template + __host__ __device__ __forceinline__ void BinSelect(float sample, int &bin, bool valid) + { + LevelT level_sample = (LevelT) sample; + + if (valid && (level_sample >= min) && (level_sample < max)) + bin = (int) ((level_sample - min) * scale); + } + + // Method for converting samples to bin-ids (double specialization) + template + __host__ __device__ __forceinline__ void BinSelect(double sample, int &bin, bool valid) + { + LevelT level_sample = (LevelT) sample; + + if (valid && (level_sample >= min) && (level_sample < max)) + bin = (int) ((level_sample - min) * scale); + } + }; + + + // Pass-through bin transform operator + struct PassThruTransform + { + // Method for converting samples to bin-ids + template + __host__ __device__ __forceinline__ void BinSelect(_SampleT sample, int &bin, bool valid) + { + if (valid) + bin = (int) sample; + } + }; + + + + //--------------------------------------------------------------------- + // Tuning policies + //--------------------------------------------------------------------- + + template + struct TScale + { + enum + { + V_SCALE = (sizeof(SampleT) + sizeof(int) - 1) / sizeof(int), + VALUE = CUB_MAX((NOMINAL_ITEMS_PER_THREAD / NUM_ACTIVE_CHANNELS / V_SCALE), 1) + }; + }; + + + /// SM11 + struct Policy110 + { + // HistogramSweepPolicy + typedef AgentHistogramPolicy< + 512, + (NUM_CHANNELS == 1) ? 8 : 2, + BLOCK_LOAD_DIRECT, + LOAD_DEFAULT, + true, + GMEM, + false> + HistogramSweepPolicy; + }; + + /// SM20 + struct Policy200 + { + // HistogramSweepPolicy + typedef AgentHistogramPolicy< + (NUM_CHANNELS == 1) ? 256 : 128, + (NUM_CHANNELS == 1) ? 8 : 3, + (NUM_CHANNELS == 1) ? BLOCK_LOAD_DIRECT : BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + true, + SMEM, + false> + HistogramSweepPolicy; + }; + + /// SM30 + struct Policy300 + { + // HistogramSweepPolicy + typedef AgentHistogramPolicy< + 512, + (NUM_CHANNELS == 1) ? 
8 : 2, + BLOCK_LOAD_DIRECT, + LOAD_DEFAULT, + true, + GMEM, + false> + HistogramSweepPolicy; + }; + + /// SM35 + struct Policy350 + { + // HistogramSweepPolicy + typedef AgentHistogramPolicy< + 128, + TScale<8>::VALUE, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + true, + BLEND, + true> + HistogramSweepPolicy; + }; + + /// SM50 + struct Policy500 + { + // HistogramSweepPolicy + typedef AgentHistogramPolicy< + 384, + TScale<16>::VALUE, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + true, + SMEM, + false> + HistogramSweepPolicy; + }; + + + + //--------------------------------------------------------------------- + // Tuning policies of current PTX compiler pass + //--------------------------------------------------------------------- + +#if (CUB_PTX_ARCH >= 500) + typedef Policy500 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 350) + typedef Policy350 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 300) + typedef Policy300 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 200) + typedef Policy200 PtxPolicy; + +#else + typedef Policy110 PtxPolicy; + +#endif + + // "Opaque" policies (whose parameterizations aren't reflected in the type signature) + struct PtxHistogramSweepPolicy : PtxPolicy::HistogramSweepPolicy {}; + + + //--------------------------------------------------------------------- + // Utilities + //--------------------------------------------------------------------- + + /** + * Initialize kernel dispatch configurations with the policies corresponding to the PTX assembly we will use + */ + template + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t InitConfigs( + int ptx_version, + KernelConfig &histogram_sweep_config) + { + #if (CUB_PTX_ARCH > 0) + + // We're on the device, so initialize the kernel dispatch configurations with the current PTX policy + return histogram_sweep_config.template Init(); + + #else + + // We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version + if (ptx_version >= 500) + { + return histogram_sweep_config.template Init(); + } + else if (ptx_version >= 350) + { + return histogram_sweep_config.template Init(); + } + else if (ptx_version >= 300) + { + return histogram_sweep_config.template Init(); + } + else if (ptx_version >= 200) + { + return histogram_sweep_config.template Init(); + } + else if (ptx_version >= 110) + { + return histogram_sweep_config.template Init(); + } + else + { + // No global atomic support + return cudaErrorNotSupported; + } + + #endif + } + + + /** + * Kernel kernel dispatch configuration + */ + struct KernelConfig + { + int block_threads; + int pixels_per_thread; + + template + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t Init() + { + block_threads = BlockPolicy::BLOCK_THREADS; + pixels_per_thread = BlockPolicy::PIXELS_PER_THREAD; + + return cudaSuccess; + } + }; + + + //--------------------------------------------------------------------- + // Dispatch entrypoints + //--------------------------------------------------------------------- + + /** + * Privatization-based dispatch routine + */ + template < + typename PrivatizedDecodeOpT, ///< The transform operator type for determining privatized counter indices from samples, one for each channel + typename OutputDecodeOpT, ///< The transform operator type for determining output bin-ids from privatized counter indices, one for each channel + typename DeviceHistogramInitKernelT, ///< Function type of cub::DeviceHistogramInitKernel + typename DeviceHistogramSweepKernelT> ///< Function type of cub::DeviceHistogramSweepKernel + CUB_RUNTIME_FUNCTION 
__forceinline__ + static cudaError_t PrivatizedDispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the input sequence of sample items. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + CounterT* d_output_histograms[NUM_ACTIVE_CHANNELS], ///< [out] The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of d_histograms[i] should be num_output_levels[i] - 1. + int num_privatized_levels[NUM_ACTIVE_CHANNELS], ///< [in] The number of bin level boundaries for delineating histogram samples in each active channel. Implies that the number of bins for channeli is num_output_levels[i] - 1. + PrivatizedDecodeOpT privatized_decode_op[NUM_ACTIVE_CHANNELS], ///< [in] Transform operators for determining bin-ids from samples, one for each channel + int num_output_levels[NUM_ACTIVE_CHANNELS], ///< [in] The number of bin level boundaries for delineating histogram samples in each active channel. Implies that the number of bins for channeli is num_output_levels[i] - 1. + OutputDecodeOpT output_decode_op[NUM_ACTIVE_CHANNELS], ///< [in] Transform operators for determining bin-ids from samples, one for each channel + int max_num_output_bins, ///< [in] Maximum number of output bins in any channel + OffsetT num_row_pixels, ///< [in] The number of multi-channel pixels per row in the region of interest + OffsetT num_rows, ///< [in] The number of rows in the region of interest + OffsetT row_stride_samples, ///< [in] The number of samples between starts of consecutive rows in the region of interest + DeviceHistogramInitKernelT histogram_init_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceHistogramInitKernel + DeviceHistogramSweepKernelT histogram_sweep_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceHistogramSweepKernel + KernelConfig histogram_sweep_config, ///< [in] Dispatch parameters that match the policy that \p histogram_sweep_kernel was compiled for + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous) ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. 
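+    // Temporary storage layout: one per-block privatized histogram array per active channel
+    // plus a single GridQueue descriptor, all aliased out of the d_temp_storage blob via
+    // AliasTemporaries(); when d_temp_storage is NULL, only the required size is computed
+    // and returned in temp_storage_bytes.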
+ { + #ifndef CUB_RUNTIME_ENABLED + + // Kernel launch not supported from this device + return CubDebug(cudaErrorNotSupported); + + #else + + cudaError error = cudaSuccess; + do + { + // Get device ordinal + int device_ordinal; + if (CubDebug(error = cudaGetDevice(&device_ordinal))) break; + + // Get SM count + int sm_count; + if (CubDebug(error = cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal))) break; + + // Get SM occupancy for histogram_sweep_kernel + int histogram_sweep_sm_occupancy; + if (CubDebug(error = MaxSmOccupancy( + histogram_sweep_sm_occupancy, + histogram_sweep_kernel, + histogram_sweep_config.block_threads))) break; + + // Get device occupancy for histogram_sweep_kernel + int histogram_sweep_occupancy = histogram_sweep_sm_occupancy * sm_count; + + if (num_row_pixels * NUM_CHANNELS == row_stride_samples) + { + // Treat as a single linear array of samples + num_row_pixels *= num_rows; + num_rows = 1; + row_stride_samples = num_row_pixels * NUM_CHANNELS; + } + + // Get grid dimensions, trying to keep total blocks ~histogram_sweep_occupancy + int pixels_per_tile = histogram_sweep_config.block_threads * histogram_sweep_config.pixels_per_thread; + int tiles_per_row = int(num_row_pixels + pixels_per_tile - 1) / pixels_per_tile; + int blocks_per_row = CUB_MIN(histogram_sweep_occupancy, tiles_per_row); + int blocks_per_col = (blocks_per_row > 0) ? + int(CUB_MIN(histogram_sweep_occupancy / blocks_per_row, num_rows)) : + 0; + int num_thread_blocks = blocks_per_row * blocks_per_col; + + dim3 sweep_grid_dims; + sweep_grid_dims.x = (unsigned int) blocks_per_row; + sweep_grid_dims.y = (unsigned int) blocks_per_col; + sweep_grid_dims.z = 1; + + // Temporary storage allocation requirements + const int NUM_ALLOCATIONS = NUM_ACTIVE_CHANNELS + 1; + void* allocations[NUM_ALLOCATIONS]; + size_t allocation_sizes[NUM_ALLOCATIONS]; + + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + allocation_sizes[CHANNEL] = size_t(num_thread_blocks) * (num_privatized_levels[CHANNEL] - 1) * sizeof(CounterT); + + allocation_sizes[NUM_ALLOCATIONS - 1] = GridQueue::AllocationSize(); + + // Alias the temporary allocations from the single storage blob (or compute the necessary size of the blob) + if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; + if (d_temp_storage == NULL) + { + // Return if the caller is simply requesting the size of the storage allocation + break; + } + + // Construct the grid queue descriptor + GridQueue tile_queue(allocations[NUM_ALLOCATIONS - 1]); + + // Setup array wrapper for histogram channel output (because we can't pass static arrays as kernel parameters) + ArrayWrapper d_output_histograms_wrapper; + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + d_output_histograms_wrapper.array[CHANNEL] = d_output_histograms[CHANNEL]; + + // Setup array wrapper for privatized per-block histogram channel output (because we can't pass static arrays as kernel parameters) + ArrayWrapper d_privatized_histograms_wrapper; + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + d_privatized_histograms_wrapper.array[CHANNEL] = (CounterT*) allocations[CHANNEL]; + + // Setup array wrapper for sweep bin transforms (because we can't pass static arrays as kernel parameters) + ArrayWrapper privatized_decode_op_wrapper; + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + privatized_decode_op_wrapper.array[CHANNEL] = privatized_decode_op[CHANNEL]; + + // Setup 
array wrapper for aggregation bin transforms (because we can't pass static arrays as kernel parameters) + ArrayWrapper output_decode_op_wrapper; + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + output_decode_op_wrapper.array[CHANNEL] = output_decode_op[CHANNEL]; + + // Setup array wrapper for num privatized bins (because we can't pass static arrays as kernel parameters) + ArrayWrapper num_privatized_bins_wrapper; + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + num_privatized_bins_wrapper.array[CHANNEL] = num_privatized_levels[CHANNEL] - 1; + + // Setup array wrapper for num output bins (because we can't pass static arrays as kernel parameters) + ArrayWrapper num_output_bins_wrapper; + for (int CHANNEL = 0; CHANNEL < NUM_ACTIVE_CHANNELS; ++CHANNEL) + num_output_bins_wrapper.array[CHANNEL] = num_output_levels[CHANNEL] - 1; + + int histogram_init_block_threads = 256; + int histogram_init_grid_dims = (max_num_output_bins + histogram_init_block_threads - 1) / histogram_init_block_threads; + + // Log DeviceHistogramInitKernel configuration + if (debug_synchronous) _CubLog("Invoking DeviceHistogramInitKernel<<<%d, %d, 0, %lld>>>()\n", + histogram_init_grid_dims, histogram_init_block_threads, (long long) stream); + + // Invoke histogram_init_kernel + histogram_init_kernel<<>>( + num_output_bins_wrapper, + d_output_histograms_wrapper, + tile_queue); + + // Return if empty problem + if ((blocks_per_row == 0) || (blocks_per_col == 0)) + break; + + // Log histogram_sweep_kernel configuration + if (debug_synchronous) _CubLog("Invoking histogram_sweep_kernel<<<{%d, %d, %d}, %d, 0, %lld>>>(), %d pixels per thread, %d SM occupancy\n", + sweep_grid_dims.x, sweep_grid_dims.y, sweep_grid_dims.z, + histogram_sweep_config.block_threads, (long long) stream, histogram_sweep_config.pixels_per_thread, histogram_sweep_sm_occupancy); + + // Invoke histogram_sweep_kernel + histogram_sweep_kernel<<>>( + d_samples, + num_output_bins_wrapper, + num_privatized_bins_wrapper, + d_output_histograms_wrapper, + d_privatized_histograms_wrapper, + output_decode_op_wrapper, + privatized_decode_op_wrapper, + num_row_pixels, + num_rows, + row_stride_samples, + tiles_per_row, + tile_queue); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + } + while (0); + + return error; + + #endif // CUB_RUNTIME_ENABLED + } + + + + /** + * Dispatch routine for HistogramRange, specialized for sample types larger than 8bit + */ + CUB_RUNTIME_FUNCTION + static cudaError_t DispatchRange( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + CounterT* d_output_histograms[NUM_ACTIVE_CHANNELS], ///< [out] The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of d_histograms[i] should be num_output_levels[i] - 1. 
+ int num_output_levels[NUM_ACTIVE_CHANNELS], ///< [in] The number of boundaries (levels) for delineating histogram samples in each active channel. Implies that the number of bins for channeli is num_output_levels[i] - 1. + LevelT *d_levels[NUM_ACTIVE_CHANNELS], ///< [in] The pointers to the arrays of boundaries (levels), one for each active channel. Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + OffsetT num_row_pixels, ///< [in] The number of multi-channel pixels per row in the region of interest + OffsetT num_rows, ///< [in] The number of rows in the region of interest + OffsetT row_stride_samples, ///< [in] The number of samples between starts of consecutive rows in the region of interest + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous, ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + Int2Type is_byte_sample) ///< [in] Marker type indicating whether or not SampleT is a 8b type + { + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version; + #if (CUB_PTX_ARCH == 0) + if (CubDebug(error = PtxVersion(ptx_version))) break; + #else + ptx_version = CUB_PTX_ARCH; + #endif + + // Get kernel dispatch configurations + KernelConfig histogram_sweep_config; + if (CubDebug(error = InitConfigs(ptx_version, histogram_sweep_config))) + break; + + // Use the search transform op for converting samples to privatized bins + typedef SearchTransform PrivatizedDecodeOpT; + + // Use the pass-thru transform op for converting privatized bins to output bins + typedef PassThruTransform OutputDecodeOpT; + + PrivatizedDecodeOpT privatized_decode_op[NUM_ACTIVE_CHANNELS]; + OutputDecodeOpT output_decode_op[NUM_ACTIVE_CHANNELS]; + int max_levels = num_output_levels[0]; + + for (int channel = 0; channel < NUM_ACTIVE_CHANNELS; ++channel) + { + privatized_decode_op[channel].Init(d_levels[channel], num_output_levels[channel]); + if (num_output_levels[channel] > max_levels) + max_levels = num_output_levels[channel]; + } + int max_num_output_bins = max_levels - 1; + + // Dispatch + if (max_num_output_bins > MAX_PRIVATIZED_SMEM_BINS) + { + // Too many bins to keep in shared memory. 
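+                // With PRIVATIZED_SMEM_BINS set to 0, the sweep kernel privatizes its
+                // histograms in the per-block global-memory allocations rather than in
+                // shared memory.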
+ const int PRIVATIZED_SMEM_BINS = 0; + + if (CubDebug(error = PrivatizedDispatch( + d_temp_storage, + temp_storage_bytes, + d_samples, + d_output_histograms, + num_output_levels, + privatized_decode_op, + num_output_levels, + output_decode_op, + max_num_output_bins, + num_row_pixels, + num_rows, + row_stride_samples, + DeviceHistogramInitKernel, + DeviceHistogramSweepKernel, + histogram_sweep_config, + stream, + debug_synchronous))) break; + } + else + { + // Dispatch shared-privatized approach + const int PRIVATIZED_SMEM_BINS = MAX_PRIVATIZED_SMEM_BINS; + + if (CubDebug(error = PrivatizedDispatch( + d_temp_storage, + temp_storage_bytes, + d_samples, + d_output_histograms, + num_output_levels, + privatized_decode_op, + num_output_levels, + output_decode_op, + max_num_output_bins, + num_row_pixels, + num_rows, + row_stride_samples, + DeviceHistogramInitKernel, + DeviceHistogramSweepKernel, + histogram_sweep_config, + stream, + debug_synchronous))) break; + } + + } while (0); + + return error; + } + + + /** + * Dispatch routine for HistogramRange, specialized for 8-bit sample types (computes 256-bin privatized histograms and then reduces to user-specified levels) + */ + CUB_RUNTIME_FUNCTION + static cudaError_t DispatchRange( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + CounterT* d_output_histograms[NUM_ACTIVE_CHANNELS], ///< [out] The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of d_histograms[i] should be num_output_levels[i] - 1. + int num_output_levels[NUM_ACTIVE_CHANNELS], ///< [in] The number of boundaries (levels) for delineating histogram samples in each active channel. Implies that the number of bins for channeli is num_output_levels[i] - 1. + LevelT *d_levels[NUM_ACTIVE_CHANNELS], ///< [in] The pointers to the arrays of boundaries (levels), one for each active channel. Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + OffsetT num_row_pixels, ///< [in] The number of multi-channel pixels per row in the region of interest + OffsetT num_rows, ///< [in] The number of rows in the region of interest + OffsetT row_stride_samples, ///< [in] The number of samples between starts of consecutive rows in the region of interest + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous, ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. 
+ Int2Type is_byte_sample) ///< [in] Marker type indicating whether or not SampleT is a 8b type + { + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version; + #if (CUB_PTX_ARCH == 0) + if (CubDebug(error = PtxVersion(ptx_version))) break; + #else + ptx_version = CUB_PTX_ARCH; + #endif + + // Get kernel dispatch configurations + KernelConfig histogram_sweep_config; + if (CubDebug(error = InitConfigs(ptx_version, histogram_sweep_config))) + break; + + // Use the pass-thru transform op for converting samples to privatized bins + typedef PassThruTransform PrivatizedDecodeOpT; + + // Use the search transform op for converting privatized bins to output bins + typedef SearchTransform OutputDecodeOpT; + + int num_privatized_levels[NUM_ACTIVE_CHANNELS]; + PrivatizedDecodeOpT privatized_decode_op[NUM_ACTIVE_CHANNELS]; + OutputDecodeOpT output_decode_op[NUM_ACTIVE_CHANNELS]; + int max_levels = num_output_levels[0]; // Maximum number of levels in any channel + + for (int channel = 0; channel < NUM_ACTIVE_CHANNELS; ++channel) + { + num_privatized_levels[channel] = 257; + output_decode_op[channel].Init(d_levels[channel], num_output_levels[channel]); + + if (num_output_levels[channel] > max_levels) + max_levels = num_output_levels[channel]; + } + int max_num_output_bins = max_levels - 1; + + const int PRIVATIZED_SMEM_BINS = 256; + + if (CubDebug(error = PrivatizedDispatch( + d_temp_storage, + temp_storage_bytes, + d_samples, + d_output_histograms, + num_privatized_levels, + privatized_decode_op, + num_output_levels, + output_decode_op, + max_num_output_bins, + num_row_pixels, + num_rows, + row_stride_samples, + DeviceHistogramInitKernel, + DeviceHistogramSweepKernel, + histogram_sweep_config, + stream, + debug_synchronous))) break; + + } while (0); + + return error; + } + + + /** + * Dispatch routine for HistogramEven, specialized for sample types larger than 8-bit + */ + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t DispatchEven( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the input sequence of sample items. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + CounterT* d_output_histograms[NUM_ACTIVE_CHANNELS], ///< [out] The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of d_histograms[i] should be num_output_levels[i] - 1. + int num_output_levels[NUM_ACTIVE_CHANNELS], ///< [in] The number of bin level boundaries for delineating histogram samples in each active channel. Implies that the number of bins for channeli is num_output_levels[i] - 1. + LevelT lower_level[NUM_ACTIVE_CHANNELS], ///< [in] The lower sample value bound (inclusive) for the lowest histogram bin in each active channel. + LevelT upper_level[NUM_ACTIVE_CHANNELS], ///< [in] The upper sample value bound (exclusive) for the highest histogram bin in each active channel. 
+ OffsetT num_row_pixels, ///< [in] The number of multi-channel pixels per row in the region of interest + OffsetT num_rows, ///< [in] The number of rows in the region of interest + OffsetT row_stride_samples, ///< [in] The number of samples between starts of consecutive rows in the region of interest + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous, ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + Int2Type is_byte_sample) ///< [in] Marker type indicating whether or not SampleT is a 8b type + { + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version; + #if (CUB_PTX_ARCH == 0) + if (CubDebug(error = PtxVersion(ptx_version))) break; + #else + ptx_version = CUB_PTX_ARCH; + #endif + + // Get kernel dispatch configurations + KernelConfig histogram_sweep_config; + if (CubDebug(error = InitConfigs(ptx_version, histogram_sweep_config))) + break; + + // Use the scale transform op for converting samples to privatized bins + typedef ScaleTransform PrivatizedDecodeOpT; + + // Use the pass-thru transform op for converting privatized bins to output bins + typedef PassThruTransform OutputDecodeOpT; + + PrivatizedDecodeOpT privatized_decode_op[NUM_ACTIVE_CHANNELS]; + OutputDecodeOpT output_decode_op[NUM_ACTIVE_CHANNELS]; + int max_levels = num_output_levels[0]; + + for (int channel = 0; channel < NUM_ACTIVE_CHANNELS; ++channel) + { + int bins = num_output_levels[channel] - 1; + LevelT scale = (upper_level[channel] - lower_level[channel]) / bins; + + privatized_decode_op[channel].Init(num_output_levels[channel], upper_level[channel], lower_level[channel], scale); + + if (num_output_levels[channel] > max_levels) + max_levels = num_output_levels[channel]; + } + int max_num_output_bins = max_levels - 1; + + if (max_num_output_bins > MAX_PRIVATIZED_SMEM_BINS) + { + // Dispatch shared-privatized approach + const int PRIVATIZED_SMEM_BINS = 0; + + if (CubDebug(error = PrivatizedDispatch( + d_temp_storage, + temp_storage_bytes, + d_samples, + d_output_histograms, + num_output_levels, + privatized_decode_op, + num_output_levels, + output_decode_op, + max_num_output_bins, + num_row_pixels, + num_rows, + row_stride_samples, + DeviceHistogramInitKernel, + DeviceHistogramSweepKernel, + histogram_sweep_config, + stream, + debug_synchronous))) break; + } + else + { + // Dispatch shared-privatized approach + const int PRIVATIZED_SMEM_BINS = MAX_PRIVATIZED_SMEM_BINS; + + if (CubDebug(error = PrivatizedDispatch( + d_temp_storage, + temp_storage_bytes, + d_samples, + d_output_histograms, + num_output_levels, + privatized_decode_op, + num_output_levels, + output_decode_op, + max_num_output_bins, + num_row_pixels, + num_rows, + row_stride_samples, + DeviceHistogramInitKernel, + DeviceHistogramSweepKernel, + histogram_sweep_config, + stream, + debug_synchronous))) break; + } + } + while (0); + + return error; + } + + + /** + * Dispatch routine for HistogramEven, specialized for 8-bit sample types (computes 256-bin privatized histograms and then reduces to user-specified levels) + */ + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t DispatchEven( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SampleIteratorT d_samples, ///< [in] The pointer to the input sequence of sample items. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + CounterT* d_output_histograms[NUM_ACTIVE_CHANNELS], ///< [out] The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of d_histograms[i] should be num_output_levels[i] - 1. + int num_output_levels[NUM_ACTIVE_CHANNELS], ///< [in] The number of bin level boundaries for delineating histogram samples in each active channel. Implies that the number of bins for channeli is num_output_levels[i] - 1. + LevelT lower_level[NUM_ACTIVE_CHANNELS], ///< [in] The lower sample value bound (inclusive) for the lowest histogram bin in each active channel. + LevelT upper_level[NUM_ACTIVE_CHANNELS], ///< [in] The upper sample value bound (exclusive) for the highest histogram bin in each active channel. + OffsetT num_row_pixels, ///< [in] The number of multi-channel pixels per row in the region of interest + OffsetT num_rows, ///< [in] The number of rows in the region of interest + OffsetT row_stride_samples, ///< [in] The number of samples between starts of consecutive rows in the region of interest + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous, ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + Int2Type is_byte_sample) ///< [in] Marker type indicating whether or not SampleT is a 8b type + { + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version; + #if (CUB_PTX_ARCH == 0) + if (CubDebug(error = PtxVersion(ptx_version))) break; + #else + ptx_version = CUB_PTX_ARCH; + #endif + + // Get kernel dispatch configurations + KernelConfig histogram_sweep_config; + if (CubDebug(error = InitConfigs(ptx_version, histogram_sweep_config))) + break; + + // Use the pass-thru transform op for converting samples to privatized bins + typedef PassThruTransform PrivatizedDecodeOpT; + + // Use the scale transform op for converting privatized bins to output bins + typedef ScaleTransform OutputDecodeOpT; + + int num_privatized_levels[NUM_ACTIVE_CHANNELS]; + PrivatizedDecodeOpT privatized_decode_op[NUM_ACTIVE_CHANNELS]; + OutputDecodeOpT output_decode_op[NUM_ACTIVE_CHANNELS]; + int max_levels = num_output_levels[0]; + + for (int channel = 0; channel < NUM_ACTIVE_CHANNELS; ++channel) + { + num_privatized_levels[channel] = 257; + + int bins = num_output_levels[channel] - 1; + LevelT scale = (upper_level[channel] - lower_level[channel]) / bins; + output_decode_op[channel].Init(num_output_levels[channel], upper_level[channel], lower_level[channel], scale); + + if (num_output_levels[channel] > max_levels) + max_levels = num_output_levels[channel]; + } + int max_num_output_bins = max_levels - 1; + + const int PRIVATIZED_SMEM_BINS = 256; + + if (CubDebug(error = PrivatizedDispatch( + d_temp_storage, + temp_storage_bytes, + d_samples, + d_output_histograms, + num_privatized_levels, + privatized_decode_op, + num_output_levels, + output_decode_op, + max_num_output_bins, + num_row_pixels, + num_rows, + row_stride_samples, + DeviceHistogramInitKernel, + DeviceHistogramSweepKernel, + histogram_sweep_config, + stream, + debug_synchronous))) break; + + 
} + while (0); + + return error; + } + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/dispatch/dispatch_radix_sort.cuh b/fastertransformer/cuda/cub/device/dispatch/dispatch_radix_sort.cuh new file mode 100644 index 000000000..d1a992d43 --- /dev/null +++ b/fastertransformer/cuda/cub/device/dispatch/dispatch_radix_sort.cuh @@ -0,0 +1,1619 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceRadixSort provides device-wide, parallel operations for computing a radix sort across a sequence of data items residing within device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "../../agent/agent_radix_sort_upsweep.cuh" +#include "../../agent/agent_radix_sort_downsweep.cuh" +#include "../../agent/agent_scan.cuh" +#include "../../block/block_radix_sort.cuh" +#include "../../grid/grid_even_share.cuh" +#include "../../util_type.cuh" +#include "../../util_debug.cuh" +#include "../../util_device.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/****************************************************************************** + * Kernel entry points + *****************************************************************************/ + +/** + * Upsweep digit-counting kernel entry point (multi-block). Computes privatized digit histograms, one per block. 
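+ *
+ * A minimal worked example of the spine layout (ascending case; the 2-bit
+ * digit and 3-block grid are illustrative values, not fixed by this file):
+ * the count of digit d produced by block b is written to
+ * d_spine[d * gridDim.x + b], so with 4 digit values and 3 blocks the spine
+ * reads
+ *
+ *   [ c(0,b0) c(0,b1) c(0,b2) | c(1,b0) c(1,b1) c(1,b2) | c(2,b0) ... c(3,b2) ]
+ *
+ * which is exactly the order the spine-scan kernel below consumes with a
+ * single exclusive prefix sum.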
+ */ +template < + typename ChainedPolicyT, ///< Chained tuning policy + bool ALT_DIGIT_BITS, ///< Whether or not to use the alternate (lower-bits) policy + bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low + typename KeyT, ///< Key type + typename OffsetT> ///< Signed integer type for global offsets +__launch_bounds__ (int((ALT_DIGIT_BITS) ? + ChainedPolicyT::ActivePolicy::AltUpsweepPolicy::BLOCK_THREADS : + ChainedPolicyT::ActivePolicy::UpsweepPolicy::BLOCK_THREADS)) +__global__ void DeviceRadixSortUpsweepKernel( + const KeyT *d_keys, ///< [in] Input keys buffer + OffsetT *d_spine, ///< [out] Privatized (per block) digit histograms (striped, i.e., 0s counts from each block, then 1s counts from each block, etc.) + OffsetT /*num_items*/, ///< [in] Total number of input data items + int current_bit, ///< [in] Bit position of current radix digit + int num_bits, ///< [in] Number of bits of current radix digit + GridEvenShare even_share) ///< [in] Even-share descriptor for mapan equal number of tiles onto each thread block +{ + enum { + TILE_ITEMS = ChainedPolicyT::ActivePolicy::AltUpsweepPolicy::BLOCK_THREADS * + ChainedPolicyT::ActivePolicy::AltUpsweepPolicy::ITEMS_PER_THREAD + }; + + // Parameterize AgentRadixSortUpsweep type for the current configuration + typedef AgentRadixSortUpsweep< + typename If<(ALT_DIGIT_BITS), + typename ChainedPolicyT::ActivePolicy::AltUpsweepPolicy, + typename ChainedPolicyT::ActivePolicy::UpsweepPolicy>::Type, + KeyT, + OffsetT> + AgentRadixSortUpsweepT; + + // Shared memory storage + __shared__ typename AgentRadixSortUpsweepT::TempStorage temp_storage; + + // Initialize GRID_MAPPING_RAKE even-share descriptor for this thread block + even_share.template BlockInit(); + + AgentRadixSortUpsweepT upsweep(temp_storage, d_keys, current_bit, num_bits); + + upsweep.ProcessRegion(even_share.block_offset, even_share.block_end); + + CTA_SYNC(); + + // Write out digit counts (striped) + upsweep.template ExtractCounts(d_spine, gridDim.x, blockIdx.x); +} + + +/** + * Spine scan kernel entry point (single-block). Computes an exclusive prefix sum over the privatized digit histograms + */ +template < + typename ChainedPolicyT, ///< Chained tuning policy + typename OffsetT> ///< Signed integer type for global offsets +__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::ScanPolicy::BLOCK_THREADS), 1) +__global__ void RadixSortScanBinsKernel( + OffsetT *d_spine, ///< [in,out] Privatized (per block) digit histograms (striped, i.e., 0s counts from each block, then 1s counts from each block, etc.) + int num_counts) ///< [in] Total number of bin-counts +{ + // Parameterize the AgentScan type for the current configuration + typedef AgentScan< + typename ChainedPolicyT::ActivePolicy::ScanPolicy, + OffsetT*, + OffsetT*, + cub::Sum, + OffsetT, + OffsetT> + AgentScanT; + + // Shared memory storage + __shared__ typename AgentScanT::TempStorage temp_storage; + + // Block scan instance + AgentScanT block_scan(temp_storage, d_spine, d_spine, cub::Sum(), OffsetT(0)) ; + + // Process full input tiles + int block_offset = 0; + BlockScanRunningPrefixOp prefix_op(0, Sum()); + while (block_offset + AgentScanT::TILE_ITEMS <= num_counts) + { + block_scan.template ConsumeTile(block_offset, prefix_op); + block_offset += AgentScanT::TILE_ITEMS; + } +} + + +/** + * Downsweep pass kernel entry point (multi-block). Scatters keys (and values) into corresponding bins for the current digit place. 
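+ *
+ * A sketch of the indexing this pass relies on (ascending case): after the
+ * spine scan, d_spine[d * gridDim.x + b] holds the exclusive prefix
+ *
+ *   base(d, b) = count of all keys with digit < d (over every block)
+ *              + count of keys with digit d in blocks preceding b
+ *
+ * i.e. the global offset at which block b may begin scattering its keys
+ * whose current digit equals d; scattering each block's keys in order from
+ * that base is what keeps every pass of the sort stable.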
+ */ +template < + typename ChainedPolicyT, ///< Chained tuning policy + bool ALT_DIGIT_BITS, ///< Whether or not to use the alternate (lower-bits) policy + bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low + typename KeyT, ///< Key type + typename ValueT, ///< Value type + typename OffsetT> ///< Signed integer type for global offsets +__launch_bounds__ (int((ALT_DIGIT_BITS) ? + ChainedPolicyT::ActivePolicy::AltDownsweepPolicy::BLOCK_THREADS : + ChainedPolicyT::ActivePolicy::DownsweepPolicy::BLOCK_THREADS)) +__global__ void DeviceRadixSortDownsweepKernel( + const KeyT *d_keys_in, ///< [in] Input keys buffer + KeyT *d_keys_out, ///< [in] Output keys buffer + const ValueT *d_values_in, ///< [in] Input values buffer + ValueT *d_values_out, ///< [in] Output values buffer + OffsetT *d_spine, ///< [in] Scan of privatized (per block) digit histograms (striped, i.e., 0s counts from each block, then 1s counts from each block, etc.) + OffsetT num_items, ///< [in] Total number of input data items + int current_bit, ///< [in] Bit position of current radix digit + int num_bits, ///< [in] Number of bits of current radix digit + GridEvenShare even_share) ///< [in] Even-share descriptor for mapan equal number of tiles onto each thread block +{ + enum { + TILE_ITEMS = ChainedPolicyT::ActivePolicy::AltUpsweepPolicy::BLOCK_THREADS * + ChainedPolicyT::ActivePolicy::AltUpsweepPolicy::ITEMS_PER_THREAD + }; + + // Parameterize AgentRadixSortDownsweep type for the current configuration + typedef AgentRadixSortDownsweep< + typename If<(ALT_DIGIT_BITS), + typename ChainedPolicyT::ActivePolicy::AltDownsweepPolicy, + typename ChainedPolicyT::ActivePolicy::DownsweepPolicy>::Type, + IS_DESCENDING, + KeyT, + ValueT, + OffsetT> + AgentRadixSortDownsweepT; + + // Shared memory storage + __shared__ typename AgentRadixSortDownsweepT::TempStorage temp_storage; + + // Initialize even-share descriptor for this thread block + even_share.template BlockInit(); + + // Process input tiles + AgentRadixSortDownsweepT(temp_storage, num_items, d_spine, d_keys_in, d_keys_out, d_values_in, d_values_out, current_bit, num_bits).ProcessRegion( + even_share.block_offset, + even_share.block_end); +} + + +/** + * Single pass kernel entry point (single-block). Fully sorts a tile of input. 
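+ *
+ * This kernel is only launched when the entire input fits in a single tile,
+ * i.e. when num_items <= BLOCK_THREADS * ITEMS_PER_THREAD for the active
+ * SingleTilePolicy (see DispatchRadixSort::Invoke later in this file). As an
+ * illustration only (the real per-architecture values come from the tuning
+ * policies below), a 256-thread block handling 16 keys per thread would cover
+ * inputs of up to 4096 items with one block-wide sort and no spine traffic.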
+ */ +template < + typename ChainedPolicyT, ///< Chained tuning policy + bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low + typename KeyT, ///< Key type + typename ValueT, ///< Value type + typename OffsetT> ///< Signed integer type for global offsets +__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THREADS), 1) +__global__ void DeviceRadixSortSingleTileKernel( + const KeyT *d_keys_in, ///< [in] Input keys buffer + KeyT *d_keys_out, ///< [in] Output keys buffer + const ValueT *d_values_in, ///< [in] Input values buffer + ValueT *d_values_out, ///< [in] Output values buffer + OffsetT num_items, ///< [in] Total number of input data items + int current_bit, ///< [in] Bit position of current radix digit + int end_bit) ///< [in] The past-the-end (most-significant) bit index needed for key comparison +{ + // Constants + enum + { + BLOCK_THREADS = ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THREADS, + ITEMS_PER_THREAD = ChainedPolicyT::ActivePolicy::SingleTilePolicy::ITEMS_PER_THREAD, + KEYS_ONLY = Equals::VALUE, + }; + + // BlockRadixSort type + typedef BlockRadixSort< + KeyT, + BLOCK_THREADS, + ITEMS_PER_THREAD, + ValueT, + ChainedPolicyT::ActivePolicy::SingleTilePolicy::RADIX_BITS, + (ChainedPolicyT::ActivePolicy::SingleTilePolicy::RANK_ALGORITHM == RADIX_RANK_MEMOIZE), + ChainedPolicyT::ActivePolicy::SingleTilePolicy::SCAN_ALGORITHM> + BlockRadixSortT; + + // BlockLoad type (keys) + typedef BlockLoad< + KeyT, + BLOCK_THREADS, + ITEMS_PER_THREAD, + ChainedPolicyT::ActivePolicy::SingleTilePolicy::LOAD_ALGORITHM> BlockLoadKeys; + + // BlockLoad type (values) + typedef BlockLoad< + ValueT, + BLOCK_THREADS, + ITEMS_PER_THREAD, + ChainedPolicyT::ActivePolicy::SingleTilePolicy::LOAD_ALGORITHM> BlockLoadValues; + + // Unsigned word for key bits + typedef typename Traits::UnsignedBits UnsignedBitsT; + + // Shared memory storage + __shared__ union TempStorage + { + typename BlockRadixSortT::TempStorage sort; + typename BlockLoadKeys::TempStorage load_keys; + typename BlockLoadValues::TempStorage load_values; + + } temp_storage; + + // Keys and values for the block + KeyT keys[ITEMS_PER_THREAD]; + ValueT values[ITEMS_PER_THREAD]; + + // Get default (min/max) value for out-of-bounds keys + UnsignedBitsT default_key_bits = (IS_DESCENDING) ? 
Traits::LOWEST_KEY : Traits::MAX_KEY; + KeyT default_key = reinterpret_cast(default_key_bits); + + // Load keys + BlockLoadKeys(temp_storage.load_keys).Load(d_keys_in, keys, num_items, default_key); + + CTA_SYNC(); + + // Load values + if (!KEYS_ONLY) + { + // Register pressure work-around: moving num_items through shfl prevents compiler + // from reusing guards/addressing from prior guarded loads + num_items = ShuffleIndex(num_items, 0, 0xffffffff); + + BlockLoadValues(temp_storage.load_values).Load(d_values_in, values, num_items); + + CTA_SYNC(); + } + + // Sort tile + BlockRadixSortT(temp_storage.sort).SortBlockedToStriped( + keys, + values, + current_bit, + end_bit, + Int2Type(), + Int2Type()); + + // Store keys and values + #pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + int item_offset = ITEM * BLOCK_THREADS + threadIdx.x; + if (item_offset < num_items) + { + d_keys_out[item_offset] = keys[ITEM]; + if (!KEYS_ONLY) + d_values_out[item_offset] = values[ITEM]; + } + } +} + + +/** + * Segmented radix sorting pass (one block per segment) + */ +template < + typename ChainedPolicyT, ///< Chained tuning policy + bool ALT_DIGIT_BITS, ///< Whether or not to use the alternate (lower-bits) policy + bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low + typename KeyT, ///< Key type + typename ValueT, ///< Value type + typename OffsetIteratorT, ///< Random-access input iterator type for reading segment offsets \iterator + typename OffsetT> ///< Signed integer type for global offsets +__launch_bounds__ (int((ALT_DIGIT_BITS) ? + ChainedPolicyT::ActivePolicy::AltSegmentedPolicy::BLOCK_THREADS : + ChainedPolicyT::ActivePolicy::SegmentedPolicy::BLOCK_THREADS)) +__global__ void DeviceSegmentedRadixSortKernel( + const KeyT *d_keys_in, ///< [in] Input keys buffer + KeyT *d_keys_out, ///< [in] Output keys buffer + const ValueT *d_values_in, ///< [in] Input values buffer + ValueT *d_values_out, ///< [in] Output values buffer + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. 
+ int /*num_segments*/, ///< [in] The number of segments that comprise the sorting data + int current_bit, ///< [in] Bit position of current radix digit + int pass_bits) ///< [in] Number of bits of current radix digit +{ + // + // Constants + // + + typedef typename If<(ALT_DIGIT_BITS), + typename ChainedPolicyT::ActivePolicy::AltSegmentedPolicy, + typename ChainedPolicyT::ActivePolicy::SegmentedPolicy>::Type SegmentedPolicyT; + + enum + { + BLOCK_THREADS = SegmentedPolicyT::BLOCK_THREADS, + ITEMS_PER_THREAD = SegmentedPolicyT::ITEMS_PER_THREAD, + RADIX_BITS = SegmentedPolicyT::RADIX_BITS, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + RADIX_DIGITS = 1 << RADIX_BITS, + KEYS_ONLY = Equals::VALUE, + }; + + // Upsweep type + typedef AgentRadixSortUpsweep< + AgentRadixSortUpsweepPolicy, + KeyT, + OffsetT> + BlockUpsweepT; + + // Digit-scan type + typedef BlockScan DigitScanT; + + // Downsweep type + typedef AgentRadixSortDownsweep BlockDownsweepT; + + enum + { + /// Number of bin-starting offsets tracked per thread + BINS_TRACKED_PER_THREAD = BlockDownsweepT::BINS_TRACKED_PER_THREAD + }; + + // + // Process input tiles + // + + // Shared memory storage + __shared__ union + { + typename BlockUpsweepT::TempStorage upsweep; + typename BlockDownsweepT::TempStorage downsweep; + struct + { + volatile OffsetT reverse_counts_in[RADIX_DIGITS]; + volatile OffsetT reverse_counts_out[RADIX_DIGITS]; + typename DigitScanT::TempStorage scan; + }; + + } temp_storage; + + OffsetT segment_begin = d_begin_offsets[blockIdx.x]; + OffsetT segment_end = d_end_offsets[blockIdx.x]; + OffsetT num_items = segment_end - segment_begin; + + // Check if empty segment + if (num_items <= 0) + return; + + // Upsweep + BlockUpsweepT upsweep(temp_storage.upsweep, d_keys_in, current_bit, pass_bits); + upsweep.ProcessRegion(segment_begin, segment_end); + + CTA_SYNC(); + + // The count of each digit value in this pass (valid in the first RADIX_DIGITS threads) + OffsetT bin_count[BINS_TRACKED_PER_THREAD]; + upsweep.ExtractCounts(bin_count); + + CTA_SYNC(); + + if (IS_DESCENDING) + { + // Reverse bin counts + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; + + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + temp_storage.reverse_counts_in[bin_idx] = bin_count[track]; + } + + CTA_SYNC(); + + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; + + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + bin_count[track] = temp_storage.reverse_counts_in[RADIX_DIGITS - bin_idx - 1]; + } + } + + // Scan + OffsetT bin_offset[BINS_TRACKED_PER_THREAD]; // The global scatter base offset for each digit value in this pass (valid in the first RADIX_DIGITS threads) + DigitScanT(temp_storage.scan).ExclusiveSum(bin_count, bin_offset); + + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + bin_offset[track] += segment_begin; + } + + if (IS_DESCENDING) + { + // Reverse bin offsets + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; + + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + temp_storage.reverse_counts_out[threadIdx.x] = bin_offset[track]; + } + + CTA_SYNC(); + + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = 
(threadIdx.x * BINS_TRACKED_PER_THREAD) + track; + + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + bin_offset[track] = temp_storage.reverse_counts_out[RADIX_DIGITS - bin_idx - 1]; + } + } + + CTA_SYNC(); + + // Downsweep + BlockDownsweepT downsweep(temp_storage.downsweep, bin_offset, num_items, d_keys_in, d_keys_out, d_values_in, d_values_out, current_bit, pass_bits); + downsweep.ProcessRegion(segment_begin, segment_end); +} + + + +/****************************************************************************** + * Policy + ******************************************************************************/ + +/** + * Tuning policy for kernel specialization + */ +template < + typename KeyT, ///< Key type + typename ValueT, ///< Value type + typename OffsetT> ///< Signed integer type for global offsets +struct DeviceRadixSortPolicy +{ + //------------------------------------------------------------------------------ + // Constants + //------------------------------------------------------------------------------ + + enum + { + // Whether this is a keys-only (or key-value) sort + KEYS_ONLY = (Equals::VALUE), + }; + + // Dominant-sized key/value type + typedef typename If<(sizeof(ValueT) > 4) && (sizeof(KeyT) < sizeof(ValueT)), ValueT, KeyT>::Type DominantT; + + //------------------------------------------------------------------------------ + // Architecture-specific tuning policies + //------------------------------------------------------------------------------ + + /// SM20 + struct Policy200 : ChainedPolicy<200, Policy200, Policy200> + { + enum { + PRIMARY_RADIX_BITS = 5, + ALT_RADIX_BITS = PRIMARY_RADIX_BITS - 1, + + // Relative size of KeyT type to a 4-byte word + SCALE_FACTOR_4B = (CUB_MAX(sizeof(KeyT), sizeof(ValueT)) + 3) / 4, + }; + + // Keys-only upsweep policies + typedef AgentRadixSortUpsweepPolicy <64, CUB_MAX(1, 18 / SCALE_FACTOR_4B), LOAD_DEFAULT, PRIMARY_RADIX_BITS> UpsweepPolicyKeys; + typedef AgentRadixSortUpsweepPolicy <64, CUB_MAX(1, 18 / SCALE_FACTOR_4B), LOAD_DEFAULT, ALT_RADIX_BITS> AltUpsweepPolicyKeys; + + // Key-value pairs upsweep policies + typedef AgentRadixSortUpsweepPolicy <128, CUB_MAX(1, 13 / SCALE_FACTOR_4B), LOAD_DEFAULT, PRIMARY_RADIX_BITS> UpsweepPolicyPairs; + typedef AgentRadixSortUpsweepPolicy <128, CUB_MAX(1, 13 / SCALE_FACTOR_4B), LOAD_DEFAULT, ALT_RADIX_BITS> AltUpsweepPolicyPairs; + + // Upsweep policies + typedef typename If::Type UpsweepPolicy; + typedef typename If::Type AltUpsweepPolicy; + + // Scan policy + typedef AgentScanPolicy <512, 4, BLOCK_LOAD_VECTORIZE, LOAD_DEFAULT, BLOCK_STORE_VECTORIZE, BLOCK_SCAN_RAKING_MEMOIZE> ScanPolicy; + + // Keys-only downsweep policies + typedef AgentRadixSortDownsweepPolicy <64, CUB_MAX(1, 18 / SCALE_FACTOR_4B), BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, PRIMARY_RADIX_BITS> DownsweepPolicyKeys; + typedef AgentRadixSortDownsweepPolicy <64, CUB_MAX(1, 18 / SCALE_FACTOR_4B), BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, ALT_RADIX_BITS> AltDownsweepPolicyKeys; + + // Key-value pairs downsweep policies + typedef AgentRadixSortDownsweepPolicy <128, CUB_MAX(1, 13 / SCALE_FACTOR_4B), BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, PRIMARY_RADIX_BITS> DownsweepPolicyPairs; + typedef AgentRadixSortDownsweepPolicy <128, CUB_MAX(1, 13 / SCALE_FACTOR_4B), BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, ALT_RADIX_BITS> AltDownsweepPolicyPairs; + + // Downsweep 
policies + typedef typename If::Type DownsweepPolicy; + typedef typename If::Type AltDownsweepPolicy; + + // Single-tile policy + typedef DownsweepPolicy SingleTilePolicy; + + // Segmented policies + typedef DownsweepPolicy SegmentedPolicy; + typedef AltDownsweepPolicy AltSegmentedPolicy; + }; + + /// SM30 + struct Policy300 : ChainedPolicy<300, Policy300, Policy200> + { + enum { + PRIMARY_RADIX_BITS = 5, + ALT_RADIX_BITS = PRIMARY_RADIX_BITS - 1, + + // Relative size of KeyT type to a 4-byte word + SCALE_FACTOR_4B = (CUB_MAX(sizeof(KeyT), sizeof(ValueT)) + 3) / 4, + }; + + // Keys-only upsweep policies + typedef AgentRadixSortUpsweepPolicy <256, CUB_MAX(1, 7 / SCALE_FACTOR_4B), LOAD_DEFAULT, PRIMARY_RADIX_BITS> UpsweepPolicyKeys; + typedef AgentRadixSortUpsweepPolicy <256, CUB_MAX(1, 7 / SCALE_FACTOR_4B), LOAD_DEFAULT, ALT_RADIX_BITS> AltUpsweepPolicyKeys; + + // Key-value pairs upsweep policies + typedef AgentRadixSortUpsweepPolicy <256, CUB_MAX(1, 5 / SCALE_FACTOR_4B), LOAD_DEFAULT, PRIMARY_RADIX_BITS> UpsweepPolicyPairs; + typedef AgentRadixSortUpsweepPolicy <256, CUB_MAX(1, 5 / SCALE_FACTOR_4B), LOAD_DEFAULT, ALT_RADIX_BITS> AltUpsweepPolicyPairs; + + // Upsweep policies + typedef typename If::Type UpsweepPolicy; + typedef typename If::Type AltUpsweepPolicy; + + // Scan policy + typedef AgentScanPolicy <1024, 4, BLOCK_LOAD_VECTORIZE, LOAD_DEFAULT, BLOCK_STORE_VECTORIZE, BLOCK_SCAN_WARP_SCANS> ScanPolicy; + + // Keys-only downsweep policies + typedef AgentRadixSortDownsweepPolicy <128, CUB_MAX(1, 14 / SCALE_FACTOR_4B), BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, PRIMARY_RADIX_BITS> DownsweepPolicyKeys; + typedef AgentRadixSortDownsweepPolicy <128, CUB_MAX(1, 14 / SCALE_FACTOR_4B), BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, ALT_RADIX_BITS> AltDownsweepPolicyKeys; + + // Key-value pairs downsweep policies + typedef AgentRadixSortDownsweepPolicy <128, CUB_MAX(1, 10 / SCALE_FACTOR_4B), BLOCK_LOAD_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, PRIMARY_RADIX_BITS> DownsweepPolicyPairs; + typedef AgentRadixSortDownsweepPolicy <128, CUB_MAX(1, 10 / SCALE_FACTOR_4B), BLOCK_LOAD_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, ALT_RADIX_BITS> AltDownsweepPolicyPairs; + + // Downsweep policies + typedef typename If::Type DownsweepPolicy; + typedef typename If::Type AltDownsweepPolicy; + + // Single-tile policy + typedef DownsweepPolicy SingleTilePolicy; + + // Segmented policies + typedef DownsweepPolicy SegmentedPolicy; + typedef AltDownsweepPolicy AltSegmentedPolicy; + }; + + + /// SM35 + struct Policy350 : ChainedPolicy<350, Policy350, Policy300> + { + enum { + PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 
6 : 5, // 1.72B 32b keys/s, 1.17B 32b pairs/s, 1.55B 32b segmented keys/s (K40m) + }; + + // Scan policy + typedef AgentScanPolicy <1024, 4, BLOCK_LOAD_VECTORIZE, LOAD_DEFAULT, BLOCK_STORE_VECTORIZE, BLOCK_SCAN_WARP_SCANS> ScanPolicy; + + // Keys-only downsweep policies + typedef AgentRadixSortDownsweepPolicy DownsweepPolicyKeys; + typedef AgentRadixSortDownsweepPolicy AltDownsweepPolicyKeys; + + // Key-value pairs downsweep policies + typedef DownsweepPolicyKeys DownsweepPolicyPairs; + typedef AgentRadixSortDownsweepPolicy AltDownsweepPolicyPairs; + + // Downsweep policies + typedef typename If::Type DownsweepPolicy; + typedef typename If::Type AltDownsweepPolicy; + + // Upsweep policies + typedef DownsweepPolicy UpsweepPolicy; + typedef AltDownsweepPolicy AltUpsweepPolicy; + + // Single-tile policy + typedef DownsweepPolicy SingleTilePolicy; + + // Segmented policies + typedef DownsweepPolicy SegmentedPolicy; + typedef AltDownsweepPolicy AltSegmentedPolicy; + + + }; + + + /// SM50 + struct Policy500 : ChainedPolicy<500, Policy500, Policy350> + { + enum { + PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5, // 3.5B 32b keys/s, 1.92B 32b pairs/s (TitanX) + SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5, + SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5, // 3.1B 32b segmented keys/s (TitanX) + }; + + // ScanPolicy + typedef AgentScanPolicy <512, 23, BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, BLOCK_STORE_WARP_TRANSPOSE, BLOCK_SCAN_RAKING_MEMOIZE> ScanPolicy; + + // Downsweep policies + typedef AgentRadixSortDownsweepPolicy DownsweepPolicy; + typedef AgentRadixSortDownsweepPolicy AltDownsweepPolicy; + + // Upsweep policies + typedef DownsweepPolicy UpsweepPolicy; + typedef AltDownsweepPolicy AltUpsweepPolicy; + + // Single-tile policy + typedef AgentRadixSortDownsweepPolicy SingleTilePolicy; + + // Segmented policies + typedef AgentRadixSortDownsweepPolicy SegmentedPolicy; + typedef AgentRadixSortDownsweepPolicy AltSegmentedPolicy; + }; + + + /// SM60 (GP100) + struct Policy600 : ChainedPolicy<600, Policy600, Policy500> + { + enum { + PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5, // 6.9B 32b keys/s (Quadro P100) + SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5, + SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5, // 5.9B 32b segmented keys/s (Quadro P100) + }; + + // ScanPolicy + typedef AgentScanPolicy <512, 23, BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, BLOCK_STORE_WARP_TRANSPOSE, BLOCK_SCAN_RAKING_MEMOIZE> ScanPolicy; + + // Downsweep policies + typedef AgentRadixSortDownsweepPolicy DownsweepPolicy; + typedef AgentRadixSortDownsweepPolicy AltDownsweepPolicy; + + // Upsweep policies + typedef DownsweepPolicy UpsweepPolicy; + typedef AltDownsweepPolicy AltUpsweepPolicy; + + // Single-tile policy + typedef AgentRadixSortDownsweepPolicy SingleTilePolicy; + + // Segmented policies + typedef AgentRadixSortDownsweepPolicy SegmentedPolicy; + typedef AgentRadixSortDownsweepPolicy AltSegmentedPolicy; + + }; + + + /// SM61 (GP104) + struct Policy610 : ChainedPolicy<610, Policy610, Policy600> + { + enum { + PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5, // 3.4B 32b keys/s, 1.83B 32b pairs/s (1080) + SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5, + SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 
6 : 5, // 3.3B 32b segmented keys/s (1080) + }; + + // ScanPolicy + typedef AgentScanPolicy <512, 23, BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, BLOCK_STORE_WARP_TRANSPOSE, BLOCK_SCAN_RAKING_MEMOIZE> ScanPolicy; + + // Downsweep policies + typedef AgentRadixSortDownsweepPolicy DownsweepPolicy; + typedef AgentRadixSortDownsweepPolicy AltDownsweepPolicy; + + // Upsweep policies + typedef AgentRadixSortUpsweepPolicy UpsweepPolicy; + typedef AgentRadixSortUpsweepPolicy AltUpsweepPolicy; + + // Single-tile policy + typedef AgentRadixSortDownsweepPolicy SingleTilePolicy; + + // Segmented policies + typedef AgentRadixSortDownsweepPolicy SegmentedPolicy; + typedef AgentRadixSortDownsweepPolicy AltSegmentedPolicy; + }; + + + /// SM62 (Tegra, less RF) + struct Policy620 : ChainedPolicy<620, Policy620, Policy610> + { + enum { + PRIMARY_RADIX_BITS = 5, + ALT_RADIX_BITS = PRIMARY_RADIX_BITS - 1, + }; + + // ScanPolicy + typedef AgentScanPolicy <512, 23, BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, BLOCK_STORE_WARP_TRANSPOSE, BLOCK_SCAN_RAKING_MEMOIZE> ScanPolicy; + + // Downsweep policies + typedef AgentRadixSortDownsweepPolicy DownsweepPolicy; + typedef AgentRadixSortDownsweepPolicy AltDownsweepPolicy; + + // Upsweep policies + typedef DownsweepPolicy UpsweepPolicy; + typedef AltDownsweepPolicy AltUpsweepPolicy; + + // Single-tile policy + typedef AgentRadixSortDownsweepPolicy SingleTilePolicy; + + // Segmented policies + typedef DownsweepPolicy SegmentedPolicy; + typedef AltDownsweepPolicy AltSegmentedPolicy; + }; + + + /// SM70 (GV100) + struct Policy700 : ChainedPolicy<700, Policy700, Policy620> + { + enum { + PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5, // 7.62B 32b keys/s (GV100) + SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5, + SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 
6 : 5, // 8.7B 32b segmented keys/s (GV100) + }; + + // ScanPolicy + typedef AgentScanPolicy <512, 23, BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, BLOCK_STORE_WARP_TRANSPOSE, BLOCK_SCAN_RAKING_MEMOIZE> ScanPolicy; + + // Downsweep policies + typedef AgentRadixSortDownsweepPolicy DownsweepPolicy; + typedef AgentRadixSortDownsweepPolicy AltDownsweepPolicy; + + // Upsweep policies + typedef DownsweepPolicy UpsweepPolicy; + typedef AltDownsweepPolicy AltUpsweepPolicy; + + // Single-tile policy + typedef AgentRadixSortDownsweepPolicy SingleTilePolicy; + + // Segmented policies + typedef AgentRadixSortDownsweepPolicy SegmentedPolicy; + typedef AgentRadixSortDownsweepPolicy AltSegmentedPolicy; + }; + + + /// MaxPolicy + typedef Policy700 MaxPolicy; + + +}; + + + +/****************************************************************************** + * Single-problem dispatch + ******************************************************************************/ + +/** + * Utility class for dispatching the appropriately-tuned kernels for device-wide radix sort + */ +template < + bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low + typename KeyT, ///< Key type + typename ValueT, ///< Value type + typename OffsetT> ///< Signed integer type for global offsets +struct DispatchRadixSort : + DeviceRadixSortPolicy +{ + //------------------------------------------------------------------------------ + // Constants + //------------------------------------------------------------------------------ + + enum + { + // Whether this is a keys-only (or key-value) sort + KEYS_ONLY = (Equals::VALUE), + }; + + + //------------------------------------------------------------------------------ + // Problem state + //------------------------------------------------------------------------------ + + void *d_temp_storage; ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + DoubleBuffer &d_keys; ///< [in,out] Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + DoubleBuffer &d_values; ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + OffsetT num_items; ///< [in] Number of items to sort + int begin_bit; ///< [in] The beginning (least-significant) bit index needed for key comparison + int end_bit; ///< [in] The past-the-end (most-significant) bit index needed for key comparison + cudaStream_t stream; ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous; ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
+ int ptx_version; ///< [in] PTX version + bool is_overwrite_okay; ///< [in] Whether is okay to overwrite source buffers + + + //------------------------------------------------------------------------------ + // Constructor + //------------------------------------------------------------------------------ + + /// Constructor + CUB_RUNTIME_FUNCTION __forceinline__ + DispatchRadixSort( + void* d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + OffsetT num_items, + int begin_bit, + int end_bit, + bool is_overwrite_okay, + cudaStream_t stream, + bool debug_synchronous, + int ptx_version) + : + d_temp_storage(d_temp_storage), + temp_storage_bytes(temp_storage_bytes), + d_keys(d_keys), + d_values(d_values), + num_items(num_items), + begin_bit(begin_bit), + end_bit(end_bit), + stream(stream), + debug_synchronous(debug_synchronous), + ptx_version(ptx_version), + is_overwrite_okay(is_overwrite_okay) + {} + + + //------------------------------------------------------------------------------ + // Small-problem (single tile) invocation + //------------------------------------------------------------------------------ + + /// Invoke a single block to sort in-core + template < + typename ActivePolicyT, ///< Umbrella policy active for the target device + typename SingleTileKernelT> ///< Function type of cub::DeviceRadixSortSingleTileKernel + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t InvokeSingleTile( + SingleTileKernelT single_tile_kernel) ///< [in] Kernel function pointer to parameterization of cub::DeviceRadixSortSingleTileKernel + { +#ifndef CUB_RUNTIME_ENABLED + (void)single_tile_kernel; + // Kernel launch not supported from this device + return CubDebug(cudaErrorNotSupported ); +#else + cudaError error = cudaSuccess; + do + { + // Return if the caller is simply requesting the size of the storage allocation + if (d_temp_storage == NULL) + { + temp_storage_bytes = 1; + break; + } + + // Return if empty problem + if (num_items == 0) + break; + + // Log single_tile_kernel configuration + if (debug_synchronous) + _CubLog("Invoking single_tile_kernel<<<%d, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy, current bit %d, bit_grain %d\n", + 1, ActivePolicyT::SingleTilePolicy::BLOCK_THREADS, (long long) stream, + ActivePolicyT::SingleTilePolicy::ITEMS_PER_THREAD, 1, begin_bit, ActivePolicyT::SingleTilePolicy::RADIX_BITS); + + // Invoke upsweep_kernel with same grid size as downsweep_kernel + single_tile_kernel<<<1, ActivePolicyT::SingleTilePolicy::BLOCK_THREADS, 0, stream>>>( + d_keys.Current(), + d_keys.Alternate(), + d_values.Current(), + d_values.Alternate(), + num_items, + begin_bit, + end_bit); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + // Update selector + d_keys.selector ^= 1; + d_values.selector ^= 1; + } + while (0); + + return error; + +#endif // CUB_RUNTIME_ENABLED + } + + + //------------------------------------------------------------------------------ + // Normal problem size invocation + //------------------------------------------------------------------------------ + + /** + * Invoke a three-kernel sorting pass at the current bit. 
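+ *
+ * Each pass runs upsweep, spine scan, then downsweep over
+ * pass_bits = CUB_MIN(radix_bits, end_bit - current_bit) bits. A worked
+ * example of the pass planning performed by InvokePasses below (the 7-bit
+ * primary and 6-bit alternate digit widths are illustrative large-key
+ * defaults): sorting 32-bit keys gives
+ *
+ *   num_passes     = (32 + 7 - 1) / 7  = 5
+ *   max_alt_passes = (5 * 7) - 32      = 3
+ *   alt_end_bit    = 0 + (3 * 6)       = 18
+ *
+ * so bits [0,18) are consumed by three 6-bit passes and bits [18,32) by two
+ * 7-bit passes, covering the full key in the minimum number of passes.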
+ */ + template + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t InvokePass( + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + OffsetT *d_spine, + int spine_length, + int ¤t_bit, + PassConfigT &pass_config) + { + cudaError error = cudaSuccess; + do + { + int pass_bits = CUB_MIN(pass_config.radix_bits, (end_bit - current_bit)); + + // Log upsweep_kernel configuration + if (debug_synchronous) + _CubLog("Invoking upsweep_kernel<<<%d, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy, current bit %d, bit_grain %d\n", + pass_config.even_share.grid_size, pass_config.upsweep_config.block_threads, (long long) stream, + pass_config.upsweep_config.items_per_thread, pass_config.upsweep_config.sm_occupancy, current_bit, pass_bits); + + // Invoke upsweep_kernel with same grid size as downsweep_kernel + pass_config.upsweep_kernel<<>>( + d_keys_in, + d_spine, + num_items, + current_bit, + pass_bits, + pass_config.even_share); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + // Log scan_kernel configuration + if (debug_synchronous) _CubLog("Invoking scan_kernel<<<%d, %d, 0, %lld>>>(), %d items per thread\n", + 1, pass_config.scan_config.block_threads, (long long) stream, pass_config.scan_config.items_per_thread); + + // Invoke scan_kernel + pass_config.scan_kernel<<<1, pass_config.scan_config.block_threads, 0, stream>>>( + d_spine, + spine_length); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + // Log downsweep_kernel configuration + if (debug_synchronous) _CubLog("Invoking downsweep_kernel<<<%d, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy\n", + pass_config.even_share.grid_size, pass_config.downsweep_config.block_threads, (long long) stream, + pass_config.downsweep_config.items_per_thread, pass_config.downsweep_config.sm_occupancy); + + // Invoke downsweep_kernel + pass_config.downsweep_kernel<<>>( + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + d_spine, + num_items, + current_bit, + pass_bits, + pass_config.even_share); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + // Update current bit + current_bit += pass_bits; + } + while (0); + + return error; + } + + + + /// Pass configuration structure + template < + typename UpsweepKernelT, + typename ScanKernelT, + typename DownsweepKernelT> + struct PassConfig + { + UpsweepKernelT upsweep_kernel; + KernelConfig upsweep_config; + ScanKernelT scan_kernel; + KernelConfig scan_config; + DownsweepKernelT downsweep_kernel; + KernelConfig downsweep_config; + int radix_bits; + int radix_digits; + int max_downsweep_grid_size; + GridEvenShare even_share; + + /// Initialize pass configuration + template < + typename UpsweepPolicyT, + typename ScanPolicyT, + typename DownsweepPolicyT> + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t InitPassConfig( + UpsweepKernelT upsweep_kernel, + ScanKernelT scan_kernel, + DownsweepKernelT downsweep_kernel, + int ptx_version, + int sm_count, + int num_items) + { + cudaError error = cudaSuccess; + do + { + 
this->upsweep_kernel = upsweep_kernel; + this->scan_kernel = scan_kernel; + this->downsweep_kernel = downsweep_kernel; + radix_bits = DownsweepPolicyT::RADIX_BITS; + radix_digits = 1 << radix_bits; + + if (CubDebug(error = upsweep_config.Init(upsweep_kernel))) break; + if (CubDebug(error = scan_config.Init(scan_kernel))) break; + if (CubDebug(error = downsweep_config.Init(downsweep_kernel))) break; + + max_downsweep_grid_size = (downsweep_config.sm_occupancy * sm_count) * CUB_SUBSCRIPTION_FACTOR(ptx_version); + + even_share.DispatchInit( + num_items, + max_downsweep_grid_size, + CUB_MAX(downsweep_config.tile_size, upsweep_config.tile_size)); + + } + while (0); + return error; + } + + }; + + + /// Invocation (run multiple digit passes) + template < + typename ActivePolicyT, ///< Umbrella policy active for the target device + typename UpsweepKernelT, ///< Function type of cub::DeviceRadixSortUpsweepKernel + typename ScanKernelT, ///< Function type of cub::SpineScanKernel + typename DownsweepKernelT> ///< Function type of cub::DeviceRadixSortDownsweepKernel + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t InvokePasses( + UpsweepKernelT upsweep_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceRadixSortUpsweepKernel + UpsweepKernelT alt_upsweep_kernel, ///< [in] Alternate kernel function pointer to parameterization of cub::DeviceRadixSortUpsweepKernel + ScanKernelT scan_kernel, ///< [in] Kernel function pointer to parameterization of cub::SpineScanKernel + DownsweepKernelT downsweep_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceRadixSortDownsweepKernel + DownsweepKernelT alt_downsweep_kernel) ///< [in] Alternate kernel function pointer to parameterization of cub::DeviceRadixSortDownsweepKernel + { +#ifndef CUB_RUNTIME_ENABLED + (void)upsweep_kernel; + (void)alt_upsweep_kernel; + (void)scan_kernel; + (void)downsweep_kernel; + (void)alt_downsweep_kernel; + + // Kernel launch not supported from this device + return CubDebug(cudaErrorNotSupported ); +#else + + cudaError error = cudaSuccess; + do + { + // Get device ordinal + int device_ordinal; + if (CubDebug(error = cudaGetDevice(&device_ordinal))) break; + + // Get SM count + int sm_count; + if (CubDebug(error = cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal))) break; + + // Init regular and alternate-digit kernel configurations + PassConfig pass_config, alt_pass_config; + if ((error = pass_config.template InitPassConfig< + typename ActivePolicyT::UpsweepPolicy, + typename ActivePolicyT::ScanPolicy, + typename ActivePolicyT::DownsweepPolicy>( + upsweep_kernel, scan_kernel, downsweep_kernel, ptx_version, sm_count, num_items))) break; + + if ((error = alt_pass_config.template InitPassConfig< + typename ActivePolicyT::AltUpsweepPolicy, + typename ActivePolicyT::ScanPolicy, + typename ActivePolicyT::AltDownsweepPolicy>( + alt_upsweep_kernel, scan_kernel, alt_downsweep_kernel, ptx_version, sm_count, num_items))) break; + + // Get maximum spine length + int max_grid_size = CUB_MAX(pass_config.max_downsweep_grid_size, alt_pass_config.max_downsweep_grid_size); + int spine_length = (max_grid_size * pass_config.radix_digits) + pass_config.scan_config.tile_size; + + // Temporary storage allocation requirements + void* allocations[3]; + size_t allocation_sizes[3] = + { + spine_length * sizeof(OffsetT), // bytes needed for privatized block digit histograms + (is_overwrite_okay) ? 
0 : num_items * sizeof(KeyT), // bytes needed for 3rd keys buffer + (is_overwrite_okay || (KEYS_ONLY)) ? 0 : num_items * sizeof(ValueT), // bytes needed for 3rd values buffer + }; + + // Alias the temporary allocations from the single storage blob (or compute the necessary size of the blob) + if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; + + // Return if the caller is simply requesting the size of the storage allocation + if (d_temp_storage == NULL) + return cudaSuccess; + + // Pass planning. Run passes of the alternate digit-size configuration until we have an even multiple of our preferred digit size + int num_bits = end_bit - begin_bit; + int num_passes = (num_bits + pass_config.radix_bits - 1) / pass_config.radix_bits; + bool is_num_passes_odd = num_passes & 1; + int max_alt_passes = (num_passes * pass_config.radix_bits) - num_bits; + int alt_end_bit = CUB_MIN(end_bit, begin_bit + (max_alt_passes * alt_pass_config.radix_bits)); + + // Alias the temporary storage allocations + OffsetT *d_spine = static_cast(allocations[0]); + + DoubleBuffer d_keys_remaining_passes( + (is_overwrite_okay || is_num_passes_odd) ? d_keys.Alternate() : static_cast(allocations[1]), + (is_overwrite_okay) ? d_keys.Current() : (is_num_passes_odd) ? static_cast(allocations[1]) : d_keys.Alternate()); + + DoubleBuffer d_values_remaining_passes( + (is_overwrite_okay || is_num_passes_odd) ? d_values.Alternate() : static_cast(allocations[2]), + (is_overwrite_okay) ? d_values.Current() : (is_num_passes_odd) ? static_cast(allocations[2]) : d_values.Alternate()); + + // Run first pass, consuming from the input's current buffers + int current_bit = begin_bit; + if (CubDebug(error = InvokePass( + d_keys.Current(), d_keys_remaining_passes.Current(), + d_values.Current(), d_values_remaining_passes.Current(), + d_spine, spine_length, current_bit, + (current_bit < alt_end_bit) ? alt_pass_config : pass_config))) break; + + // Run remaining passes + while (current_bit < end_bit) + { + if (CubDebug(error = InvokePass( + d_keys_remaining_passes.d_buffers[d_keys_remaining_passes.selector], d_keys_remaining_passes.d_buffers[d_keys_remaining_passes.selector ^ 1], + d_values_remaining_passes.d_buffers[d_keys_remaining_passes.selector], d_values_remaining_passes.d_buffers[d_keys_remaining_passes.selector ^ 1], + d_spine, spine_length, current_bit, + (current_bit < alt_end_bit) ? 
alt_pass_config : pass_config))) break;; + + // Invert selectors + d_keys_remaining_passes.selector ^= 1; + d_values_remaining_passes.selector ^= 1; + } + + // Update selector + if (!is_overwrite_okay) { + num_passes = 1; // Sorted data always ends up in the other vector + } + + d_keys.selector = (d_keys.selector + num_passes) & 1; + d_values.selector = (d_values.selector + num_passes) & 1; + } + while (0); + + return error; + +#endif // CUB_RUNTIME_ENABLED + } + + + //------------------------------------------------------------------------------ + // Chained policy invocation + //------------------------------------------------------------------------------ + + /// Invocation + template + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t Invoke() + { + typedef typename DispatchRadixSort::MaxPolicy MaxPolicyT; + typedef typename ActivePolicyT::SingleTilePolicy SingleTilePolicyT; + + // Force kernel code-generation in all compiler passes + if (num_items <= (SingleTilePolicyT::BLOCK_THREADS * SingleTilePolicyT::ITEMS_PER_THREAD)) + { + // Small, single tile size + return InvokeSingleTile( + DeviceRadixSortSingleTileKernel); + } + else + { + // Regular size + return InvokePasses( + DeviceRadixSortUpsweepKernel< MaxPolicyT, false, IS_DESCENDING, KeyT, OffsetT>, + DeviceRadixSortUpsweepKernel< MaxPolicyT, true, IS_DESCENDING, KeyT, OffsetT>, + RadixSortScanBinsKernel< MaxPolicyT, OffsetT>, + DeviceRadixSortDownsweepKernel< MaxPolicyT, false, IS_DESCENDING, KeyT, ValueT, OffsetT>, + DeviceRadixSortDownsweepKernel< MaxPolicyT, true, IS_DESCENDING, KeyT, ValueT, OffsetT>); + } + } + + + //------------------------------------------------------------------------------ + // Dispatch entrypoints + //------------------------------------------------------------------------------ + + /** + * Internal dispatch routine + */ + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + DoubleBuffer &d_keys, ///< [in,out] Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + DoubleBuffer &d_values, ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + OffsetT num_items, ///< [in] Number of items to sort + int begin_bit, ///< [in] The beginning (least-significant) bit index needed for key comparison + int end_bit, ///< [in] The past-the-end (most-significant) bit index needed for key comparison + bool is_overwrite_okay, ///< [in] Whether is okay to overwrite source buffers + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous) ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
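+ // A minimal caller-side sketch of this entry point (local variable names are
+ // illustrative; the two calls follow the d_temp_storage == NULL sizing
+ // convention documented above):
+ //
+ //   DoubleBuffer<KeyT>   d_keys(d_key_buf, d_key_alt_buf);
+ //   DoubleBuffer<ValueT> d_values(d_value_buf, d_value_alt_buf);
+ //   void   *d_temp_storage     = NULL;
+ //   size_t  temp_storage_bytes = 0;
+ //
+ //   // First call only computes temp_storage_bytes; second call sorts
+ //   DispatchRadixSort<false, KeyT, ValueT, int>::Dispatch(
+ //       d_temp_storage, temp_storage_bytes, d_keys, d_values,
+ //       num_items, 0, sizeof(KeyT) * 8, true, 0, false);
+ //   cudaMalloc(&d_temp_storage, temp_storage_bytes);
+ //   DispatchRadixSort<false, KeyT, ValueT, int>::Dispatch(
+ //       d_temp_storage, temp_storage_bytes, d_keys, d_values,
+ //       num_items, 0, sizeof(KeyT) * 8, true, 0, false);
+ //
+ //   KeyT *d_sorted_keys = d_keys.Current();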
+ { + typedef typename DispatchRadixSort::MaxPolicy MaxPolicyT; + + cudaError_t error; + do { + // Get PTX version + int ptx_version; + if (CubDebug(error = PtxVersion(ptx_version))) break; + + // Create dispatch functor + DispatchRadixSort dispatch( + d_temp_storage, temp_storage_bytes, + d_keys, d_values, + num_items, begin_bit, end_bit, is_overwrite_okay, + stream, debug_synchronous, ptx_version); + + // Dispatch to chained policy + if (CubDebug(error = MaxPolicyT::Invoke(ptx_version, dispatch))) break; + + } while (0); + + return error; + } +}; + + + + +/****************************************************************************** + * Segmented dispatch + ******************************************************************************/ + +/** + * Utility class for dispatching the appropriately-tuned kernels for segmented device-wide radix sort + */ +template < + bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low + typename KeyT, ///< Key type + typename ValueT, ///< Value type + typename OffsetIteratorT, ///< Random-access input iterator type for reading segment offsets \iterator + typename OffsetT> ///< Signed integer type for global offsets +struct DispatchSegmentedRadixSort : + DeviceRadixSortPolicy +{ + //------------------------------------------------------------------------------ + // Constants + //------------------------------------------------------------------------------ + + enum + { + // Whether this is a keys-only (or key-value) sort + KEYS_ONLY = (Equals::VALUE), + }; + + + //------------------------------------------------------------------------------ + // Parameter members + //------------------------------------------------------------------------------ + + void *d_temp_storage; ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + DoubleBuffer &d_keys; ///< [in,out] Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + DoubleBuffer &d_values; ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + OffsetT num_items; ///< [in] Number of items to sort + OffsetT num_segments; ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets; ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets; ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + int begin_bit; ///< [in] The beginning (least-significant) bit index needed for key comparison + int end_bit; ///< [in] The past-the-end (most-significant) bit index needed for key comparison + cudaStream_t stream; ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous; ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
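+ // Offset-layout note (illustrative values): for two segments covering items
+ // [0,100) and [100,250), the expected inputs are
+ //
+ //   d_begin_offsets = { 0, 100 }
+ //   d_end_offsets   = { 100, 250 }
+ //
+ // so a single device array of num_segments + 1 offsets can typically serve
+ // as both sequences, passed as (d_offsets, d_offsets + 1).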
+ int ptx_version; ///< [in] PTX version + bool is_overwrite_okay; ///< [in] Whether is okay to overwrite source buffers + + + //------------------------------------------------------------------------------ + // Constructors + //------------------------------------------------------------------------------ + + /// Constructor + CUB_RUNTIME_FUNCTION __forceinline__ + DispatchSegmentedRadixSort( + void* d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + OffsetT num_items, + OffsetT num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + int begin_bit, + int end_bit, + bool is_overwrite_okay, + cudaStream_t stream, + bool debug_synchronous, + int ptx_version) + : + d_temp_storage(d_temp_storage), + temp_storage_bytes(temp_storage_bytes), + d_keys(d_keys), + d_values(d_values), + num_items(num_items), + num_segments(num_segments), + d_begin_offsets(d_begin_offsets), + d_end_offsets(d_end_offsets), + begin_bit(begin_bit), + end_bit(end_bit), + is_overwrite_okay(is_overwrite_okay), + stream(stream), + debug_synchronous(debug_synchronous), + ptx_version(ptx_version) + {} + + + //------------------------------------------------------------------------------ + // Multi-segment invocation + //------------------------------------------------------------------------------ + + /// Invoke a three-kernel sorting pass at the current bit. + template + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t InvokePass( + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + int ¤t_bit, + PassConfigT &pass_config) + { + cudaError error = cudaSuccess; + do + { + int pass_bits = CUB_MIN(pass_config.radix_bits, (end_bit - current_bit)); + + // Log kernel configuration + if (debug_synchronous) + _CubLog("Invoking segmented_kernels<<<%d, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy, current bit %d, bit_grain %d\n", + num_segments, pass_config.segmented_config.block_threads, (long long) stream, + pass_config.segmented_config.items_per_thread, pass_config.segmented_config.sm_occupancy, current_bit, pass_bits); + + pass_config.segmented_kernel<<>>( + d_keys_in, d_keys_out, + d_values_in, d_values_out, + d_begin_offsets, d_end_offsets, num_segments, + current_bit, pass_bits); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + // Update current bit + current_bit += pass_bits; + } + while (0); + + return error; + } + + + /// PassConfig data structure + template + struct PassConfig + { + SegmentedKernelT segmented_kernel; + KernelConfig segmented_config; + int radix_bits; + int radix_digits; + + /// Initialize pass configuration + template + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t InitPassConfig(SegmentedKernelT segmented_kernel) + { + this->segmented_kernel = segmented_kernel; + this->radix_bits = SegmentedPolicyT::RADIX_BITS; + this->radix_digits = 1 << radix_bits; + + return CubDebug(segmented_config.Init(segmented_kernel)); + } + }; + + + /// Invocation (run multiple digit passes) + template < + typename ActivePolicyT, ///< Umbrella policy active for the target device + typename SegmentedKernelT> ///< Function type of cub::DeviceSegmentedRadixSortKernel + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t InvokePasses( + SegmentedKernelT segmented_kernel, ///< [in] Kernel function pointer to parameterization 
of cub::DeviceSegmentedRadixSortKernel + SegmentedKernelT alt_segmented_kernel) ///< [in] Alternate kernel function pointer to parameterization of cub::DeviceSegmentedRadixSortKernel + { +#ifndef CUB_RUNTIME_ENABLED + (void)segmented_kernel; + (void)alt_segmented_kernel; + + // Kernel launch not supported from this device + return CubDebug(cudaErrorNotSupported ); +#else + + cudaError error = cudaSuccess; + do + { + // Init regular and alternate kernel configurations + PassConfig pass_config, alt_pass_config; + if ((error = pass_config.template InitPassConfig(segmented_kernel))) break; + if ((error = alt_pass_config.template InitPassConfig(alt_segmented_kernel))) break; + + // Temporary storage allocation requirements + void* allocations[2]; + size_t allocation_sizes[2] = + { + (is_overwrite_okay) ? 0 : num_items * sizeof(KeyT), // bytes needed for 3rd keys buffer + (is_overwrite_okay || (KEYS_ONLY)) ? 0 : num_items * sizeof(ValueT), // bytes needed for 3rd values buffer + }; + + // Alias the temporary allocations from the single storage blob (or compute the necessary size of the blob) + if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; + + // Return if the caller is simply requesting the size of the storage allocation + if (d_temp_storage == NULL) + { + if (temp_storage_bytes == 0) + temp_storage_bytes = 1; + return cudaSuccess; + } + + // Pass planning. Run passes of the alternate digit-size configuration until we have an even multiple of our preferred digit size + int radix_bits = ActivePolicyT::SegmentedPolicy::RADIX_BITS; + int alt_radix_bits = ActivePolicyT::AltSegmentedPolicy::RADIX_BITS; + int num_bits = end_bit - begin_bit; + int num_passes = (num_bits + radix_bits - 1) / radix_bits; + bool is_num_passes_odd = num_passes & 1; + int max_alt_passes = (num_passes * radix_bits) - num_bits; + int alt_end_bit = CUB_MIN(end_bit, begin_bit + (max_alt_passes * alt_radix_bits)); + + DoubleBuffer d_keys_remaining_passes( + (is_overwrite_okay || is_num_passes_odd) ? d_keys.Alternate() : static_cast(allocations[0]), + (is_overwrite_okay) ? d_keys.Current() : (is_num_passes_odd) ? static_cast(allocations[0]) : d_keys.Alternate()); + + DoubleBuffer d_values_remaining_passes( + (is_overwrite_okay || is_num_passes_odd) ? d_values.Alternate() : static_cast(allocations[1]), + (is_overwrite_okay) ? d_values.Current() : (is_num_passes_odd) ? static_cast(allocations[1]) : d_values.Alternate()); + + // Run first pass, consuming from the input's current buffers + int current_bit = begin_bit; + + if (CubDebug(error = InvokePass( + d_keys.Current(), d_keys_remaining_passes.Current(), + d_values.Current(), d_values_remaining_passes.Current(), + current_bit, + (current_bit < alt_end_bit) ? alt_pass_config : pass_config))) break; + + // Run remaining passes + while (current_bit < end_bit) + { + if (CubDebug(error = InvokePass( + d_keys_remaining_passes.d_buffers[d_keys_remaining_passes.selector], d_keys_remaining_passes.d_buffers[d_keys_remaining_passes.selector ^ 1], + d_values_remaining_passes.d_buffers[d_keys_remaining_passes.selector], d_values_remaining_passes.d_buffers[d_keys_remaining_passes.selector ^ 1], + current_bit, + (current_bit < alt_end_bit) ? 
alt_pass_config : pass_config))) break; + + // Invert selectors and update current bit + d_keys_remaining_passes.selector ^= 1; + d_values_remaining_passes.selector ^= 1; + } + + // Update selector + if (!is_overwrite_okay) { + num_passes = 1; // Sorted data always ends up in the other vector + } + + d_keys.selector = (d_keys.selector + num_passes) & 1; + d_values.selector = (d_values.selector + num_passes) & 1; + } + while (0); + + return error; + +#endif // CUB_RUNTIME_ENABLED + } + + + //------------------------------------------------------------------------------ + // Chained policy invocation + //------------------------------------------------------------------------------ + + /// Invocation + template + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t Invoke() + { + typedef typename DispatchSegmentedRadixSort::MaxPolicy MaxPolicyT; + + // Force kernel code-generation in all compiler passes + return InvokePasses( + DeviceSegmentedRadixSortKernel, + DeviceSegmentedRadixSortKernel); + } + + + //------------------------------------------------------------------------------ + // Dispatch entrypoints + //------------------------------------------------------------------------------ + + + /// Internal dispatch routine + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + DoubleBuffer &d_keys, ///< [in,out] Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + DoubleBuffer &d_values, ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + int num_items, ///< [in] Number of items to sort + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + int begin_bit, ///< [in] The beginning (least-significant) bit index needed for key comparison + int end_bit, ///< [in] The past-the-end (most-significant) bit index needed for key comparison + bool is_overwrite_okay, ///< [in] Whether is okay to overwrite source buffers + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous) ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
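// This Dispatch() routine is normally reached through the public cub::DeviceSegmentedRadixSort
// front-end; its pointer-based overloads wrap the key/value pointers in DoubleBuffers and forward
// here with is_overwrite_okay = false. A minimal usage sketch of that front-end follows (buffer
// and variable names are illustrative only, not part of this diff):

#include <cub/cub.cuh>

// Sort num_items (key, value) pairs partitioned into num_segments segments, described by
// num_segments + 1 offsets in d_offsets, using CUB's usual two-phase temporary-storage idiom.
void segmented_sort_pairs(const int *d_keys_in, int *d_keys_out,
                          const float *d_values_in, float *d_values_out,
                          int num_items, int num_segments, const int *d_offsets,
                          cudaStream_t stream)
{
    void   *d_temp_storage     = NULL;
    size_t  temp_storage_bytes = 0;

    // First call: d_temp_storage == NULL, so only the required size is written to temp_storage_bytes.
    cub::DeviceSegmentedRadixSort::SortPairs(
        d_temp_storage, temp_storage_bytes,
        d_keys_in, d_keys_out, d_values_in, d_values_out,
        num_items, num_segments, d_offsets, d_offsets + 1,
        0, int(sizeof(int) * 8), stream);

    cudaMalloc(&d_temp_storage, temp_storage_bytes);

    // Second call: performs the segmented sort.
    cub::DeviceSegmentedRadixSort::SortPairs(
        d_temp_storage, temp_storage_bytes,
        d_keys_in, d_keys_out, d_values_in, d_values_out,
        num_items, num_segments, d_offsets, d_offsets + 1,
        0, int(sizeof(int) * 8), stream);

    cudaFree(d_temp_storage);
}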
+ { + typedef typename DispatchSegmentedRadixSort::MaxPolicy MaxPolicyT; + + cudaError_t error; + do { + // Get PTX version + int ptx_version; + if (CubDebug(error = PtxVersion(ptx_version))) break; + + // Create dispatch functor + DispatchSegmentedRadixSort dispatch( + d_temp_storage, temp_storage_bytes, + d_keys, d_values, + num_items, num_segments, d_begin_offsets, d_end_offsets, + begin_bit, end_bit, is_overwrite_okay, + stream, debug_synchronous, ptx_version); + + // Dispatch to chained policy + if (CubDebug(error = MaxPolicyT::Invoke(ptx_version, dispatch))) break; + + } while (0); + + return error; + } +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/dispatch/dispatch_reduce.cuh b/fastertransformer/cuda/cub/device/dispatch/dispatch_reduce.cuh new file mode 100644 index 000000000..e9d1b7ac1 --- /dev/null +++ b/fastertransformer/cuda/cub/device/dispatch/dispatch_reduce.cuh @@ -0,0 +1,882 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceReduce provides device-wide, parallel operations for computing a reduction across a sequence of data items residing within device-accessible memory. 
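 *
 * \par
 * (A minimal usage sketch of the cub::DeviceReduce front-end that this dispatch layer serves;
 * variable names are illustrative.)  As elsewhere in CUB, the reduction is invoked twice: once
 * with d_temp_storage == NULL to query the temporary-storage size, and once to do the work:
 * \code
 * void  *d_temp_storage     = NULL;
 * size_t temp_storage_bytes = 0;
 * cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_sum, num_items);   // size query only
 * cudaMalloc(&d_temp_storage, temp_storage_bytes);
 * cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_sum, num_items);   // run the reduction
 * \endcode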
+ */
+
+#pragma once
+
+#include <stdio.h>
+#include <iterator>
+
+#include "../../agent/agent_reduce.cuh"
+#include "../../iterator/arg_index_input_iterator.cuh"
+#include "../../thread/thread_operators.cuh"
+#include "../../grid/grid_even_share.cuh"
+#include "../../iterator/arg_index_input_iterator.cuh"
+#include "../../util_debug.cuh"
+#include "../../util_device.cuh"
+#include "../../util_namespace.cuh"
+
+/// Optional outer namespace(s)
+CUB_NS_PREFIX
+
+/// CUB namespace
+namespace cub {
+
+/******************************************************************************
+ * Kernel entry points
+ *****************************************************************************/
+
+/**
+ * Reduce region kernel entry point (multi-block).  Computes privatized reductions, one per thread block.
+ */
+template <
+    typename                ChainedPolicyT,             ///< Chained tuning policy
+    typename                InputIteratorT,             ///< Random-access input iterator type for reading input items \iterator
+    typename                OutputIteratorT,            ///< Output iterator type for recording the reduced aggregate \iterator
+    typename                OffsetT,                    ///< Signed integer type for global offsets
+    typename                ReductionOpT>               ///< Binary reduction functor type having member T operator()(const T &a, const T &b)
+__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::ReducePolicy::BLOCK_THREADS))
+__global__ void DeviceReduceKernel(
+    InputIteratorT          d_in,                       ///< [in] Pointer to the input sequence of data items
+    OutputIteratorT         d_out,                      ///< [out] Pointer to the output aggregate
+    OffsetT                 num_items,                  ///< [in] Total number of input data items
+    GridEvenShare<OffsetT>  even_share,                 ///< [in] Even-share descriptor for mapping an equal number of tiles onto each thread block
+    ReductionOpT            reduction_op)               ///< [in] Binary reduction functor
+{
+    // The output value type
+    typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE),  // OutputT = (if output iterator's value type is void) ?
+        typename std::iterator_traits<InputIteratorT>::value_type,                                          // ... then the input iterator's value type,
+        typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT;                          // ... else the output iterator's value type
+
+    // Thread block type for reducing input tiles
+    typedef AgentReduce<
+            typename ChainedPolicyT::ActivePolicy::ReducePolicy,
+            InputIteratorT,
+            OutputIteratorT,
+            OffsetT,
+            ReductionOpT>
+        AgentReduceT;
+
+    // Shared memory storage
+    __shared__ typename AgentReduceT::TempStorage temp_storage;
+
+    // Consume input tiles
+    OutputT block_aggregate = AgentReduceT(temp_storage, d_in, reduction_op).ConsumeTiles(even_share);
+
+    // Output result
+    if (threadIdx.x == 0)
+        d_out[blockIdx.x] = block_aggregate;
+}
+
+
+/**
+ * Reduce a single tile kernel entry point (single-block).  Can be used to aggregate privatized thread block reductions from a previous multi-block reduction pass.
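 *
 * In the two-pass scheme implemented by DispatchReduce below, DeviceReduceKernel first writes one
 * partial aggregate per thread block into a temporary array; this kernel is then launched with a
 * single block over those partials (num_items == reduce_grid_size) to fold them, together with
 * \p init, into the final result.  Schematically (template arguments omitted):
 * \code
 * DeviceReduceKernel<<<reduce_grid_size, BLOCK_THREADS, 0, stream>>>(d_in, d_block_reductions, num_items, even_share, reduction_op);
 * DeviceReduceSingleTileKernel<<<1, BLOCK_THREADS, 0, stream>>>(d_block_reductions, d_out, reduce_grid_size, reduction_op, init);
 * \endcode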
+ */ +template < + typename ChainedPolicyT, ///< Chained tuning policy + typename InputIteratorT, ///< Random-access input iterator type for reading input items \iterator + typename OutputIteratorT, ///< Output iterator type for recording the reduced aggregate \iterator + typename OffsetT, ///< Signed integer type for global offsets + typename ReductionOpT, ///< Binary reduction functor type having member T operator()(const T &a, const T &b) + typename OuputT> ///< Data element type that is convertible to the \p value type of \p OutputIteratorT +__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THREADS), 1) +__global__ void DeviceReduceSingleTileKernel( + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + OffsetT num_items, ///< [in] Total number of input data items + ReductionOpT reduction_op, ///< [in] Binary reduction functor + OuputT init) ///< [in] The initial value of the reduction +{ + // Thread block type for reducing input tiles + typedef AgentReduce< + typename ChainedPolicyT::ActivePolicy::SingleTilePolicy, + InputIteratorT, + OutputIteratorT, + OffsetT, + ReductionOpT> + AgentReduceT; + + // Shared memory storage + __shared__ typename AgentReduceT::TempStorage temp_storage; + + // Check if empty problem + if (num_items == 0) + { + if (threadIdx.x == 0) + *d_out = init; + return; + } + + // Consume input tiles + OuputT block_aggregate = AgentReduceT(temp_storage, d_in, reduction_op).ConsumeRange( + OffsetT(0), + num_items); + + // Output result + if (threadIdx.x == 0) + *d_out = reduction_op(init, block_aggregate); +} + + +/// Normalize input iterator to segment offset +template +__device__ __forceinline__ +void NormalizeReductionOutput( + T &/*val*/, + OffsetT /*base_offset*/, + IteratorT /*itr*/) +{} + + +/// Normalize input iterator to segment offset (specialized for arg-index) +template +__device__ __forceinline__ +void NormalizeReductionOutput( + KeyValuePairT &val, + OffsetT base_offset, + ArgIndexInputIterator /*itr*/) +{ + val.key -= base_offset; +} + + +/** + * Segmented reduction (one block per segment) + */ +template < + typename ChainedPolicyT, ///< Chained tuning policy + typename InputIteratorT, ///< Random-access input iterator type for reading input items \iterator + typename OutputIteratorT, ///< Output iterator type for recording the reduced aggregate \iterator + typename OffsetIteratorT, ///< Random-access input iterator type for reading segment offsets \iterator + typename OffsetT, ///< Signed integer type for global offsets + typename ReductionOpT, ///< Binary reduction functor type having member T operator()(const T &a, const T &b) + typename OutputT> ///< Data element type that is convertible to the \p value type of \p OutputIteratorT +__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::ReducePolicy::BLOCK_THREADS)) +__global__ void DeviceSegmentedReduceKernel( + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. 
If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + int /*num_segments*/, ///< [in] The number of segments that comprise the sorting data + ReductionOpT reduction_op, ///< [in] Binary reduction functor + OutputT init) ///< [in] The initial value of the reduction +{ + // Thread block type for reducing input tiles + typedef AgentReduce< + typename ChainedPolicyT::ActivePolicy::ReducePolicy, + InputIteratorT, + OutputIteratorT, + OffsetT, + ReductionOpT> + AgentReduceT; + + // Shared memory storage + __shared__ typename AgentReduceT::TempStorage temp_storage; + + OffsetT segment_begin = d_begin_offsets[blockIdx.x]; + OffsetT segment_end = d_end_offsets[blockIdx.x]; + + // Check if empty problem + if (segment_begin == segment_end) + { + if (threadIdx.x == 0) + d_out[blockIdx.x] = init; + return; + } + + // Consume input tiles + OutputT block_aggregate = AgentReduceT(temp_storage, d_in, reduction_op).ConsumeRange( + segment_begin, + segment_end); + + // Normalize as needed + NormalizeReductionOutput(block_aggregate, segment_begin, d_in); + + if (threadIdx.x == 0) + d_out[blockIdx.x] = reduction_op(init, block_aggregate);; +} + + + + +/****************************************************************************** + * Policy + ******************************************************************************/ + +template < + typename OuputT, ///< Data type + typename OffsetT, ///< Signed integer type for global offsets + typename ReductionOpT> ///< Binary reduction functor type having member T operator()(const T &a, const T &b) +struct DeviceReducePolicy +{ + //------------------------------------------------------------------------------ + // Architecture-specific tuning policies + //------------------------------------------------------------------------------ + + /// SM13 + struct Policy130 : ChainedPolicy<130, Policy130, Policy130> + { + // ReducePolicy + typedef AgentReducePolicy< + CUB_SCALED_GRANULARITIES(128, 8, OuputT), ///< Threads per block, items per thread + 2, ///< Number of items per vectorized load + BLOCK_REDUCE_RAKING, ///< Cooperative block-wide reduction algorithm to use + LOAD_DEFAULT> ///< Cache load modifier + ReducePolicy; + + // SingleTilePolicy + typedef ReducePolicy SingleTilePolicy; + + // SegmentedReducePolicy + typedef ReducePolicy SegmentedReducePolicy; + }; + + + /// SM20 + struct Policy200 : ChainedPolicy<200, Policy200, Policy130> + { + // ReducePolicy (GTX 580: 178.9 GB/s @ 48M 4B items, 158.1 GB/s @ 192M 1B items) + typedef AgentReducePolicy< + CUB_SCALED_GRANULARITIES(128, 8, OuputT), ///< Threads per block, items per thread + 4, ///< Number of items per vectorized load + BLOCK_REDUCE_RAKING, ///< Cooperative block-wide reduction algorithm to use + LOAD_DEFAULT> ///< Cache load modifier + ReducePolicy; + + // SingleTilePolicy + typedef ReducePolicy SingleTilePolicy; + + // SegmentedReducePolicy + typedef ReducePolicy SegmentedReducePolicy; + }; + + + /// SM30 + struct Policy300 : ChainedPolicy<300, Policy300, Policy200> + { + // ReducePolicy (GTX670: 154.0 @ 48M 4B items) + typedef AgentReducePolicy< + CUB_SCALED_GRANULARITIES(256, 20, OuputT), ///< Threads per block, items per thread + 2, ///< Number of items per vectorized load + BLOCK_REDUCE_WARP_REDUCTIONS, ///< Cooperative block-wide reduction algorithm to use + LOAD_DEFAULT> ///< Cache load modifier + ReducePolicy; + + // SingleTilePolicy + typedef ReducePolicy SingleTilePolicy; + + // SegmentedReducePolicy + typedef ReducePolicy SegmentedReducePolicy; + }; + + + /// SM35 + struct 
Policy350 : ChainedPolicy<350, Policy350, Policy300> + { + // ReducePolicy (GTX Titan: 255.1 GB/s @ 48M 4B items; 228.7 GB/s @ 192M 1B items) + typedef AgentReducePolicy< + CUB_SCALED_GRANULARITIES(256, 20, OuputT), ///< Threads per block, items per thread + 4, ///< Number of items per vectorized load + BLOCK_REDUCE_WARP_REDUCTIONS, ///< Cooperative block-wide reduction algorithm to use + LOAD_LDG> ///< Cache load modifier + ReducePolicy; + + // SingleTilePolicy + typedef ReducePolicy SingleTilePolicy; + + // SegmentedReducePolicy + typedef ReducePolicy SegmentedReducePolicy; + }; + + /// SM60 + struct Policy600 : ChainedPolicy<600, Policy600, Policy350> + { + // ReducePolicy (P100: 591 GB/s @ 64M 4B items; 583 GB/s @ 256M 1B items) + typedef AgentReducePolicy< + CUB_SCALED_GRANULARITIES(256, 16, OuputT), ///< Threads per block, items per thread + 4, ///< Number of items per vectorized load + BLOCK_REDUCE_WARP_REDUCTIONS, ///< Cooperative block-wide reduction algorithm to use + LOAD_LDG> ///< Cache load modifier + ReducePolicy; + + // SingleTilePolicy + typedef ReducePolicy SingleTilePolicy; + + // SegmentedReducePolicy + typedef ReducePolicy SegmentedReducePolicy; + }; + + + /// MaxPolicy + typedef Policy600 MaxPolicy; + +}; + + + +/****************************************************************************** + * Single-problem dispatch + ******************************************************************************/ + +/** + * Utility class for dispatching the appropriately-tuned kernels for device-wide reduction + */ +template < + typename InputIteratorT, ///< Random-access input iterator type for reading input items \iterator + typename OutputIteratorT, ///< Output iterator type for recording the reduced aggregate \iterator + typename OffsetT, ///< Signed integer type for global offsets + typename ReductionOpT> ///< Binary reduction functor type having member T operator()(const T &a, const T &b) +struct DispatchReduce : + DeviceReducePolicy< + typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type, // ... else the output iterator's value type + OffsetT, + ReductionOpT> +{ + //------------------------------------------------------------------------------ + // Constants + //------------------------------------------------------------------------------ + + // Data type of output iterator + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type OutputT; // ... else the output iterator's value type + + + //------------------------------------------------------------------------------ + // Problem state + //------------------------------------------------------------------------------ + + void *d_temp_storage; ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
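// For example (illustrative): with d_in of type int* and d_out of type float*, the OutputT
// typedef above resolves to float; if the output iterator's value_type is void (as for some
// discard/"fancy" output iterators), OutputT falls back to the input iterator's value type, int.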
+ size_t &temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in; ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out; ///< [out] Pointer to the output aggregate + OffsetT num_items; ///< [in] Total number of input items (i.e., length of \p d_in) + ReductionOpT reduction_op; ///< [in] Binary reduction functor + OutputT init; ///< [in] The initial value of the reduction + cudaStream_t stream; ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous; ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + int ptx_version; ///< [in] PTX version + + //------------------------------------------------------------------------------ + // Constructor + //------------------------------------------------------------------------------ + + /// Constructor + CUB_RUNTIME_FUNCTION __forceinline__ + DispatchReduce( + void* d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + OffsetT num_items, + ReductionOpT reduction_op, + OutputT init, + cudaStream_t stream, + bool debug_synchronous, + int ptx_version) + : + d_temp_storage(d_temp_storage), + temp_storage_bytes(temp_storage_bytes), + d_in(d_in), + d_out(d_out), + num_items(num_items), + reduction_op(reduction_op), + init(init), + stream(stream), + debug_synchronous(debug_synchronous), + ptx_version(ptx_version) + {} + + + //------------------------------------------------------------------------------ + // Small-problem (single tile) invocation + //------------------------------------------------------------------------------ + + /// Invoke a single block block to reduce in-core + template < + typename ActivePolicyT, ///< Umbrella policy active for the target device + typename SingleTileKernelT> ///< Function type of cub::DeviceReduceSingleTileKernel + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t InvokeSingleTile( + SingleTileKernelT single_tile_kernel) ///< [in] Kernel function pointer to parameterization of cub::DeviceReduceSingleTileKernel + { +#ifndef CUB_RUNTIME_ENABLED + (void)single_tile_kernel; + + // Kernel launch not supported from this device + return CubDebug(cudaErrorNotSupported ); +#else + cudaError error = cudaSuccess; + do + { + // Return if the caller is simply requesting the size of the storage allocation + if (d_temp_storage == NULL) + { + temp_storage_bytes = 1; + break; + } + + // Log single_reduce_sweep_kernel configuration + if (debug_synchronous) _CubLog("Invoking DeviceReduceSingleTileKernel<<<1, %d, 0, %lld>>>(), %d items per thread\n", + ActivePolicyT::SingleTilePolicy::BLOCK_THREADS, + (long long) stream, + ActivePolicyT::SingleTilePolicy::ITEMS_PER_THREAD); + + // Invoke single_reduce_sweep_kernel + single_tile_kernel<<<1, ActivePolicyT::SingleTilePolicy::BLOCK_THREADS, 0, stream>>>( + d_in, + d_out, + num_items, + reduction_op, + init); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + } + while (0); + + return error; + +#endif // CUB_RUNTIME_ENABLED + } + + + //------------------------------------------------------------------------------ + // Normal problem size invocation (two-pass) + 
//------------------------------------------------------------------------------ + + /// Invoke two-passes to reduce + template < + typename ActivePolicyT, ///< Umbrella policy active for the target device + typename ReduceKernelT, ///< Function type of cub::DeviceReduceKernel + typename SingleTileKernelT> ///< Function type of cub::DeviceReduceSingleTileKernel + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t InvokePasses( + ReduceKernelT reduce_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceReduceKernel + SingleTileKernelT single_tile_kernel) ///< [in] Kernel function pointer to parameterization of cub::DeviceReduceSingleTileKernel + { +#ifndef CUB_RUNTIME_ENABLED + (void) reduce_kernel; + (void) single_tile_kernel; + + // Kernel launch not supported from this device + return CubDebug(cudaErrorNotSupported ); +#else + + cudaError error = cudaSuccess; + do + { + // Get device ordinal + int device_ordinal; + if (CubDebug(error = cudaGetDevice(&device_ordinal))) break; + + // Get SM count + int sm_count; + if (CubDebug(error = cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal))) break; + + // Init regular kernel configuration + KernelConfig reduce_config; + if (CubDebug(error = reduce_config.Init(reduce_kernel))) break; + int reduce_device_occupancy = reduce_config.sm_occupancy * sm_count; + + // Even-share work distribution + int max_blocks = reduce_device_occupancy * CUB_SUBSCRIPTION_FACTOR(ptx_version); + GridEvenShare even_share; + even_share.DispatchInit(num_items, max_blocks, reduce_config.tile_size); + + // Temporary storage allocation requirements + void* allocations[1]; + size_t allocation_sizes[1] = + { + max_blocks * sizeof(OutputT) // bytes needed for privatized block reductions + }; + + // Alias the temporary allocations from the single storage blob (or compute the necessary size of the blob) + if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; + if (d_temp_storage == NULL) + { + // Return if the caller is simply requesting the size of the storage allocation + return cudaSuccess; + } + + // Alias the allocation for the privatized per-block reductions + OutputT *d_block_reductions = (OutputT*) allocations[0]; + + // Get grid size for device_reduce_sweep_kernel + int reduce_grid_size = even_share.grid_size; + + // Log device_reduce_sweep_kernel configuration + if (debug_synchronous) _CubLog("Invoking DeviceReduceKernel<<<%d, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy\n", + reduce_grid_size, + ActivePolicyT::ReducePolicy::BLOCK_THREADS, + (long long) stream, + ActivePolicyT::ReducePolicy::ITEMS_PER_THREAD, + reduce_config.sm_occupancy); + + // Invoke DeviceReduceKernel + reduce_kernel<<>>( + d_in, + d_block_reductions, + num_items, + even_share, + reduction_op); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + // Log single_reduce_sweep_kernel configuration + if (debug_synchronous) _CubLog("Invoking DeviceReduceSingleTileKernel<<<1, %d, 0, %lld>>>(), %d items per thread\n", + ActivePolicyT::SingleTilePolicy::BLOCK_THREADS, + (long long) stream, + ActivePolicyT::SingleTilePolicy::ITEMS_PER_THREAD); + + // Invoke DeviceReduceSingleTileKernel + single_tile_kernel<<<1, ActivePolicyT::SingleTilePolicy::BLOCK_THREADS, 0, stream>>>( + d_block_reductions, + d_out, + 
reduce_grid_size, + reduction_op, + init); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + } + while (0); + + return error; + +#endif // CUB_RUNTIME_ENABLED + + } + + + //------------------------------------------------------------------------------ + // Chained policy invocation + //------------------------------------------------------------------------------ + + /// Invocation + template + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t Invoke() + { + typedef typename ActivePolicyT::SingleTilePolicy SingleTilePolicyT; + typedef typename DispatchReduce::MaxPolicy MaxPolicyT; + + // Force kernel code-generation in all compiler passes + if (num_items <= (SingleTilePolicyT::BLOCK_THREADS * SingleTilePolicyT::ITEMS_PER_THREAD)) + { + // Small, single tile size + return InvokeSingleTile( + DeviceReduceSingleTileKernel); + } + else + { + // Regular size + return InvokePasses( + DeviceReduceKernel, + DeviceReduceSingleTileKernel); + } + } + + + //------------------------------------------------------------------------------ + // Dispatch entrypoints + //------------------------------------------------------------------------------ + + /** + * Internal dispatch routine for computing a device-wide reduction + */ + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + ReductionOpT reduction_op, ///< [in] Binary reduction functor + OutputT init, ///< [in] The initial value of the reduction + cudaStream_t stream, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
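// This static Dispatch() is what the public cub::DeviceReduce entry points ultimately call; it can
// also be invoked directly. A minimal sketch mapping onto the parameter list above (types and
// variable names are illustrative, not part of this diff):

#include <cub/cub.cuh>

// Sum-reduce num_items ints from d_in into *d_out (OffsetT = int, init = 0).
void sum_via_dispatch(int *d_in, int *d_out, int num_items, cudaStream_t stream)
{
    void   *d_temp_storage     = NULL;
    size_t  temp_storage_bytes = 0;

    // Size query: d_temp_storage == NULL, so only temp_storage_bytes is written.
    cub::DispatchReduce<int*, int*, int, cub::Sum>::Dispatch(
        d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, cub::Sum(), 0, stream, false);

    cudaMalloc(&d_temp_storage, temp_storage_bytes);

    // Second call performs the reduction (single-tile or two-pass, chosen by Invoke() above).
    cub::DispatchReduce<int*, int*, int, cub::Sum>::Dispatch(
        d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, cub::Sum(), 0, stream, false);

    cudaFree(d_temp_storage);
}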
+ { + typedef typename DispatchReduce::MaxPolicy MaxPolicyT; + + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version; + if (CubDebug(error = PtxVersion(ptx_version))) break; + + // Create dispatch functor + DispatchReduce dispatch( + d_temp_storage, temp_storage_bytes, + d_in, d_out, num_items, reduction_op, init, + stream, debug_synchronous, ptx_version); + + // Dispatch to chained policy + if (CubDebug(error = MaxPolicyT::Invoke(ptx_version, dispatch))) break; + } + while (0); + + return error; + } +}; + + + +/****************************************************************************** + * Segmented dispatch + ******************************************************************************/ + +/** + * Utility class for dispatching the appropriately-tuned kernels for device-wide reduction + */ +template < + typename InputIteratorT, ///< Random-access input iterator type for reading input items \iterator + typename OutputIteratorT, ///< Output iterator type for recording the reduced aggregate \iterator + typename OffsetIteratorT, ///< Random-access input iterator type for reading segment offsets \iterator + typename OffsetT, ///< Signed integer type for global offsets + typename ReductionOpT> ///< Binary reduction functor type having member T operator()(const T &a, const T &b) +struct DispatchSegmentedReduce : + DeviceReducePolicy< + typename std::iterator_traits::value_type, + OffsetT, + ReductionOpT> +{ + //------------------------------------------------------------------------------ + // Constants + //------------------------------------------------------------------------------ + + /// The output value type + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type OutputT; // ... else the output iterator's value type + + + //------------------------------------------------------------------------------ + // Problem state + //------------------------------------------------------------------------------ + + void *d_temp_storage; ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in; ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out; ///< [out] Pointer to the output aggregate + OffsetT num_segments; ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets; ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets; ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + ReductionOpT reduction_op; ///< [in] Binary reduction functor + OutputT init; ///< [in] The initial value of the reduction + cudaStream_t stream; ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous; ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. 
Also causes launch configurations to be printed to the console. Default is \p false. + int ptx_version; ///< [in] PTX version + + //------------------------------------------------------------------------------ + // Constructor + //------------------------------------------------------------------------------ + + /// Constructor + CUB_RUNTIME_FUNCTION __forceinline__ + DispatchSegmentedReduce( + void* d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + OffsetT num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + ReductionOpT reduction_op, + OutputT init, + cudaStream_t stream, + bool debug_synchronous, + int ptx_version) + : + d_temp_storage(d_temp_storage), + temp_storage_bytes(temp_storage_bytes), + d_in(d_in), + d_out(d_out), + num_segments(num_segments), + d_begin_offsets(d_begin_offsets), + d_end_offsets(d_end_offsets), + reduction_op(reduction_op), + init(init), + stream(stream), + debug_synchronous(debug_synchronous), + ptx_version(ptx_version) + {} + + + + //------------------------------------------------------------------------------ + // Chained policy invocation + //------------------------------------------------------------------------------ + + /// Invocation + template < + typename ActivePolicyT, ///< Umbrella policy active for the target device + typename DeviceSegmentedReduceKernelT> ///< Function type of cub::DeviceSegmentedReduceKernel + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t InvokePasses( + DeviceSegmentedReduceKernelT segmented_reduce_kernel) ///< [in] Kernel function pointer to parameterization of cub::DeviceSegmentedReduceKernel + { +#ifndef CUB_RUNTIME_ENABLED + (void)segmented_reduce_kernel; + // Kernel launch not supported from this device + return CubDebug(cudaErrorNotSupported ); +#else + cudaError error = cudaSuccess; + do + { + // Return if the caller is simply requesting the size of the storage allocation + if (d_temp_storage == NULL) + { + temp_storage_bytes = 1; + return cudaSuccess; + } + + // Init kernel configuration + KernelConfig segmented_reduce_config; + if (CubDebug(error = segmented_reduce_config.Init(segmented_reduce_kernel))) break; + + // Log device_reduce_sweep_kernel configuration + if (debug_synchronous) _CubLog("Invoking SegmentedDeviceReduceKernel<<<%d, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy\n", + num_segments, + ActivePolicyT::SegmentedReducePolicy::BLOCK_THREADS, + (long long) stream, + ActivePolicyT::SegmentedReducePolicy::ITEMS_PER_THREAD, + segmented_reduce_config.sm_occupancy); + + // Invoke DeviceReduceKernel + segmented_reduce_kernel<<>>( + d_in, + d_out, + d_begin_offsets, + d_end_offsets, + num_segments, + reduction_op, + init); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + } + while (0); + + return error; + +#endif // CUB_RUNTIME_ENABLED + + } + + + /// Invocation + template + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t Invoke() + { + typedef typename DispatchSegmentedReduce::MaxPolicy MaxPolicyT; + + // Force kernel code-generation in all compiler passes + return InvokePasses( + DeviceSegmentedReduceKernel); + } + + + //------------------------------------------------------------------------------ + // Dispatch entrypoints + //------------------------------------------------------------------------------ + + /** + * Internal dispatch routine 
for computing a device-wide reduction + */ + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output aggregate + int num_segments, ///< [in] The number of segments that comprise the sorting data + OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. + ReductionOpT reduction_op, ///< [in] Binary reduction functor + OutputT init, ///< [in] The initial value of the reduction + cudaStream_t stream, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + typedef typename DispatchSegmentedReduce::MaxPolicy MaxPolicyT; + + if (num_segments <= 0) + return cudaSuccess; + + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version; + if (CubDebug(error = PtxVersion(ptx_version))) break; + + // Create dispatch functor + DispatchSegmentedReduce dispatch( + d_temp_storage, temp_storage_bytes, + d_in, d_out, + num_segments, d_begin_offsets, d_end_offsets, + reduction_op, init, + stream, debug_synchronous, ptx_version); + + // Dispatch to chained policy + if (CubDebug(error = MaxPolicyT::Invoke(ptx_version, dispatch))) break; + } + while (0); + + return error; + } +}; + + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/dispatch/dispatch_reduce_by_key.cuh b/fastertransformer/cuda/cub/device/dispatch/dispatch_reduce_by_key.cuh new file mode 100644 index 000000000..6f4837b7f --- /dev/null +++ b/fastertransformer/cuda/cub/device/dispatch/dispatch_reduce_by_key.cuh @@ -0,0 +1,554 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceReduceByKey provides device-wide, parallel operations for reducing segments of values residing within device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "dispatch_scan.cuh" +#include "../../agent/agent_reduce_by_key.cuh" +#include "../../thread/thread_operators.cuh" +#include "../../grid/grid_queue.cuh" +#include "../../util_device.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/****************************************************************************** + * Kernel entry points + *****************************************************************************/ + +/** + * Multi-block reduce-by-key sweep kernel entry point + */ +template < + typename AgentReduceByKeyPolicyT, ///< Parameterized AgentReduceByKeyPolicyT tuning policy type + typename KeysInputIteratorT, ///< Random-access input iterator type for keys + typename UniqueOutputIteratorT, ///< Random-access output iterator type for keys + typename ValuesInputIteratorT, ///< Random-access input iterator type for values + typename AggregatesOutputIteratorT, ///< Random-access output iterator type for values + typename NumRunsOutputIteratorT, ///< Output iterator type for recording number of segments encountered + typename ScanTileStateT, ///< Tile status interface type + typename EqualityOpT, ///< KeyT equality operator type + typename ReductionOpT, ///< ValueT reduction operator type + typename OffsetT> ///< Signed integer type for global offsets +__launch_bounds__ (int(AgentReduceByKeyPolicyT::BLOCK_THREADS)) +__global__ void DeviceReduceByKeyKernel( + KeysInputIteratorT d_keys_in, ///< Pointer to the input sequence of keys + UniqueOutputIteratorT d_unique_out, ///< Pointer to the output sequence of unique keys (one key per run) + ValuesInputIteratorT d_values_in, ///< Pointer to the input sequence of corresponding values + AggregatesOutputIteratorT d_aggregates_out, ///< Pointer to the output sequence of value aggregates (one aggregate per run) + NumRunsOutputIteratorT d_num_runs_out, ///< Pointer to total number of runs encountered (i.e., the length of d_unique_out) + ScanTileStateT tile_state, ///< Tile status interface + int start_tile, ///< The starting tile for the current grid + EqualityOpT equality_op, ///< KeyT equality operator + ReductionOpT reduction_op, ///< ValueT reduction operator + OffsetT num_items) ///< Total number of items to select from +{ + // Thread block type for reducing tiles of value segments + typedef AgentReduceByKey< + AgentReduceByKeyPolicyT, + KeysInputIteratorT, + UniqueOutputIteratorT, + 
ValuesInputIteratorT, + AggregatesOutputIteratorT, + NumRunsOutputIteratorT, + EqualityOpT, + ReductionOpT, + OffsetT> + AgentReduceByKeyT; + + // Shared memory for AgentReduceByKey + __shared__ typename AgentReduceByKeyT::TempStorage temp_storage; + + // Process tiles + AgentReduceByKeyT(temp_storage, d_keys_in, d_unique_out, d_values_in, d_aggregates_out, d_num_runs_out, equality_op, reduction_op).ConsumeRange( + num_items, + tile_state, + start_tile); +} + + + + +/****************************************************************************** + * Dispatch + ******************************************************************************/ + +/** + * Utility class for dispatching the appropriately-tuned kernels for DeviceReduceByKey + */ +template < + typename KeysInputIteratorT, ///< Random-access input iterator type for keys + typename UniqueOutputIteratorT, ///< Random-access output iterator type for keys + typename ValuesInputIteratorT, ///< Random-access input iterator type for values + typename AggregatesOutputIteratorT, ///< Random-access output iterator type for values + typename NumRunsOutputIteratorT, ///< Output iterator type for recording number of segments encountered + typename EqualityOpT, ///< KeyT equality operator type + typename ReductionOpT, ///< ValueT reduction operator type + typename OffsetT> ///< Signed integer type for global offsets +struct DispatchReduceByKey +{ + //------------------------------------------------------------------------- + // Types and constants + //------------------------------------------------------------------------- + + // The input keys type + typedef typename std::iterator_traits::value_type KeyInputT; + + // The output keys type + typedef typename If<(Equals::value_type, void>::VALUE), // KeyOutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type KeyOutputT; // ... else the output iterator's value type + + // The input values type + typedef typename std::iterator_traits::value_type ValueInputT; + + // The output values type + typedef typename If<(Equals::value_type, void>::VALUE), // ValueOutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type ValueOutputT; // ... else the output iterator's value type + + enum + { + INIT_KERNEL_THREADS = 128, + MAX_INPUT_BYTES = CUB_MAX(sizeof(KeyOutputT), sizeof(ValueOutputT)), + COMBINED_INPUT_BYTES = sizeof(KeyOutputT) + sizeof(ValueOutputT), + }; + + // Tile status descriptor interface type + typedef ReduceByKeyScanTileState ScanTileStateT; + + + //------------------------------------------------------------------------- + // Tuning policies + //------------------------------------------------------------------------- + + /// SM35 + struct Policy350 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 6, + ITEMS_PER_THREAD = (MAX_INPUT_BYTES <= 8) ? 
6 : CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, ((NOMINAL_4B_ITEMS_PER_THREAD * 8) + COMBINED_INPUT_BYTES - 1) / COMBINED_INPUT_BYTES)), + }; + + typedef AgentReduceByKeyPolicy< + 128, + ITEMS_PER_THREAD, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + BLOCK_SCAN_WARP_SCANS> + ReduceByKeyPolicyT; + }; + + /// SM30 + struct Policy300 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 6, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, ((NOMINAL_4B_ITEMS_PER_THREAD * 8) + COMBINED_INPUT_BYTES - 1) / COMBINED_INPUT_BYTES)), + }; + + typedef AgentReduceByKeyPolicy< + 128, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_SCAN_WARP_SCANS> + ReduceByKeyPolicyT; + }; + + /// SM20 + struct Policy200 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 11, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, ((NOMINAL_4B_ITEMS_PER_THREAD * 8) + COMBINED_INPUT_BYTES - 1) / COMBINED_INPUT_BYTES)), + }; + + typedef AgentReduceByKeyPolicy< + 128, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_SCAN_WARP_SCANS> + ReduceByKeyPolicyT; + }; + + /// SM13 + struct Policy130 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 7, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, ((NOMINAL_4B_ITEMS_PER_THREAD * 8) + COMBINED_INPUT_BYTES - 1) / COMBINED_INPUT_BYTES)), + }; + + typedef AgentReduceByKeyPolicy< + 128, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_SCAN_WARP_SCANS> + ReduceByKeyPolicyT; + }; + + /// SM11 + struct Policy110 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 5, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 8) / COMBINED_INPUT_BYTES)), + }; + + typedef AgentReduceByKeyPolicy< + 64, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_SCAN_RAKING> + ReduceByKeyPolicyT; + }; + + + /****************************************************************************** + * Tuning policies of current PTX compiler pass + ******************************************************************************/ + +#if (CUB_PTX_ARCH >= 350) + typedef Policy350 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 300) + typedef Policy300 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 200) + typedef Policy200 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 130) + typedef Policy130 PtxPolicy; + +#else + typedef Policy110 PtxPolicy; + +#endif + + // "Opaque" policies (whose parameterizations aren't reflected in the type signature) + struct PtxReduceByKeyPolicy : PtxPolicy::ReduceByKeyPolicyT {}; + + + /****************************************************************************** + * Utilities + ******************************************************************************/ + + /** + * Initialize kernel dispatch configurations with the policies corresponding to the PTX assembly we will use + */ + template + CUB_RUNTIME_FUNCTION __forceinline__ + static void InitConfigs( + int ptx_version, + KernelConfig &reduce_by_key_config) + { + #if (CUB_PTX_ARCH > 0) + (void)ptx_version; + + // We're on the device, so initialize the kernel dispatch configurations with the current PTX policy + reduce_by_key_config.template Init(); + + #else + + // We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version + if (ptx_version >= 350) + { + reduce_by_key_config.template Init(); + } + else if (ptx_version >= 300) + { + reduce_by_key_config.template Init(); + } + else if (ptx_version >= 200) + { + reduce_by_key_config.template Init(); 
+ } + else if (ptx_version >= 130) + { + reduce_by_key_config.template Init(); + } + else + { + reduce_by_key_config.template Init(); + } + + #endif + } + + + /** + * Kernel kernel dispatch configuration. + */ + struct KernelConfig + { + int block_threads; + int items_per_thread; + int tile_items; + + template + CUB_RUNTIME_FUNCTION __forceinline__ + void Init() + { + block_threads = PolicyT::BLOCK_THREADS; + items_per_thread = PolicyT::ITEMS_PER_THREAD; + tile_items = block_threads * items_per_thread; + } + }; + + + //--------------------------------------------------------------------- + // Dispatch entrypoints + //--------------------------------------------------------------------- + + /** + * Internal dispatch routine for computing a device-wide reduce-by-key using the + * specified kernel functions. + */ + template < + typename ScanInitKernelT, ///< Function type of cub::DeviceScanInitKernel + typename ReduceByKeyKernelT> ///< Function type of cub::DeviceReduceByKeyKernelT + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + KeysInputIteratorT d_keys_in, ///< [in] Pointer to the input sequence of keys + UniqueOutputIteratorT d_unique_out, ///< [out] Pointer to the output sequence of unique keys (one key per run) + ValuesInputIteratorT d_values_in, ///< [in] Pointer to the input sequence of corresponding values + AggregatesOutputIteratorT d_aggregates_out, ///< [out] Pointer to the output sequence of value aggregates (one aggregate per run) + NumRunsOutputIteratorT d_num_runs_out, ///< [out] Pointer to total number of runs encountered (i.e., the length of d_unique_out) + EqualityOpT equality_op, ///< [in] KeyT equality operator + ReductionOpT reduction_op, ///< [in] ValueT reduction operator + OffsetT num_items, ///< [in] Total number of items to select from + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous, ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
+ int /*ptx_version*/, ///< [in] PTX version of dispatch kernels + ScanInitKernelT init_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceScanInitKernel + ReduceByKeyKernelT reduce_by_key_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceReduceByKeyKernel + KernelConfig reduce_by_key_config) ///< [in] Dispatch parameters that match the policy that \p reduce_by_key_kernel was compiled for + { + +#ifndef CUB_RUNTIME_ENABLED + (void)d_temp_storage; + (void)temp_storage_bytes; + (void)d_keys_in; + (void)d_unique_out; + (void)d_values_in; + (void)d_aggregates_out; + (void)d_num_runs_out; + (void)equality_op; + (void)reduction_op; + (void)num_items; + (void)stream; + (void)debug_synchronous; + (void)init_kernel; + (void)reduce_by_key_kernel; + (void)reduce_by_key_config; + + // Kernel launch not supported from this device + return CubDebug(cudaErrorNotSupported); + +#else + + cudaError error = cudaSuccess; + do + { + // Get device ordinal + int device_ordinal; + if (CubDebug(error = cudaGetDevice(&device_ordinal))) break; + + // Get SM count + int sm_count; + if (CubDebug(error = cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal))) break; + + // Number of input tiles + int tile_size = reduce_by_key_config.block_threads * reduce_by_key_config.items_per_thread; + int num_tiles = (num_items + tile_size - 1) / tile_size; + + // Specify temporary storage allocation requirements + size_t allocation_sizes[1]; + if (CubDebug(error = ScanTileStateT::AllocationSize(num_tiles, allocation_sizes[0]))) break; // bytes needed for tile status descriptors + + // Compute allocation pointers into the single storage blob (or compute the necessary size of the blob) + void* allocations[1]; + if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; + if (d_temp_storage == NULL) + { + // Return if the caller is simply requesting the size of the storage allocation + break; + } + + // Construct the tile status interface + ScanTileStateT tile_state; + if (CubDebug(error = tile_state.Init(num_tiles, allocations[0], allocation_sizes[0]))) break; + + // Log init_kernel configuration + int init_grid_size = CUB_MAX(1, (num_tiles + INIT_KERNEL_THREADS - 1) / INIT_KERNEL_THREADS); + if (debug_synchronous) _CubLog("Invoking init_kernel<<<%d, %d, 0, %lld>>>()\n", init_grid_size, INIT_KERNEL_THREADS, (long long) stream); + + // Invoke init_kernel to initialize tile descriptors + init_kernel<<>>( + tile_state, + num_tiles, + d_num_runs_out); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + // Return if empty problem + if (num_items == 0) + break; + + // Get SM occupancy for reduce_by_key_kernel + int reduce_by_key_sm_occupancy; + if (CubDebug(error = MaxSmOccupancy( + reduce_by_key_sm_occupancy, // out + reduce_by_key_kernel, + reduce_by_key_config.block_threads))) break; + + // Get max x-dimension of grid + int max_dim_x; + if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, cudaDevAttrMaxGridDimX, device_ordinal))) break;; + + // Run grids in epochs (in case number of tiles exceeds max x-dimension + int scan_grid_size = CUB_MIN(num_tiles, max_dim_x); + for (int start_tile = 0; start_tile < num_tiles; start_tile += scan_grid_size) + { + // Log reduce_by_key_kernel configuration + if (debug_synchronous) 
_CubLog("Invoking %d reduce_by_key_kernel<<<%d, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy\n", + start_tile, scan_grid_size, reduce_by_key_config.block_threads, (long long) stream, reduce_by_key_config.items_per_thread, reduce_by_key_sm_occupancy); + + // Invoke reduce_by_key_kernel + reduce_by_key_kernel<<>>( + d_keys_in, + d_unique_out, + d_values_in, + d_aggregates_out, + d_num_runs_out, + tile_state, + start_tile, + equality_op, + reduction_op, + num_items); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + } + } + while (0); + + return error; + +#endif // CUB_RUNTIME_ENABLED + } + + + /** + * Internal dispatch routine + */ + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + KeysInputIteratorT d_keys_in, ///< [in] Pointer to the input sequence of keys + UniqueOutputIteratorT d_unique_out, ///< [out] Pointer to the output sequence of unique keys (one key per run) + ValuesInputIteratorT d_values_in, ///< [in] Pointer to the input sequence of corresponding values + AggregatesOutputIteratorT d_aggregates_out, ///< [out] Pointer to the output sequence of value aggregates (one aggregate per run) + NumRunsOutputIteratorT d_num_runs_out, ///< [out] Pointer to total number of runs encountered (i.e., the length of d_unique_out) + EqualityOpT equality_op, ///< [in] KeyT equality operator + ReductionOpT reduction_op, ///< [in] ValueT reduction operator + OffsetT num_items, ///< [in] Total number of items to select from + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous) ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version; + #if (CUB_PTX_ARCH == 0) + if (CubDebug(error = PtxVersion(ptx_version))) break; + #else + ptx_version = CUB_PTX_ARCH; + #endif + + // Get kernel kernel dispatch configurations + KernelConfig reduce_by_key_config; + InitConfigs(ptx_version, reduce_by_key_config); + + // Dispatch + if (CubDebug(error = Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_unique_out, + d_values_in, + d_aggregates_out, + d_num_runs_out, + equality_op, + reduction_op, + num_items, + stream, + debug_synchronous, + ptx_version, + DeviceCompactInitKernel, + DeviceReduceByKeyKernel, + reduce_by_key_config))) break; + } + while (0); + + return error; + } +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/dispatch/dispatch_rle.cuh b/fastertransformer/cuda/cub/device/dispatch/dispatch_rle.cuh new file mode 100644 index 000000000..98c3681f0 --- /dev/null +++ b/fastertransformer/cuda/cub/device/dispatch/dispatch_rle.cuh @@ -0,0 +1,538 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceRle provides device-wide, parallel operations for run-length-encoding sequences of data items residing within device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "dispatch_scan.cuh" +#include "../../agent/agent_rle.cuh" +#include "../../thread/thread_operators.cuh" +#include "../../grid/grid_queue.cuh" +#include "../../util_device.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Kernel entry points + *****************************************************************************/ + +/** + * Select kernel entry point (multi-block) + * + * Performs functor-based selection if SelectOp functor type != NullType + * Otherwise performs flag-based selection if FlagIterator's value type != NullType + * Otherwise performs discontinuity selection (keep unique) + */ +template < + typename AgentRlePolicyT, ///< Parameterized AgentRlePolicyT tuning policy type + typename InputIteratorT, ///< Random-access input iterator type for reading input items \iterator + typename OffsetsOutputIteratorT, ///< Random-access output iterator type for writing run-offset values \iterator + typename LengthsOutputIteratorT, ///< Random-access output iterator type for writing run-length values \iterator + typename NumRunsOutputIteratorT, ///< Output iterator type for recording the number of runs encountered \iterator + typename ScanTileStateT, ///< Tile status interface type + typename EqualityOpT, ///< T equality operator type + typename OffsetT> ///< Signed integer type for global offsets +__launch_bounds__ (int(AgentRlePolicyT::BLOCK_THREADS)) +__global__ void DeviceRleSweepKernel( + InputIteratorT d_in, ///< [in] Pointer to input sequence of data items + OffsetsOutputIteratorT d_offsets_out, ///< [out] Pointer to output sequence of run-offsets + 
LengthsOutputIteratorT d_lengths_out, ///< [out] Pointer to output sequence of run-lengths + NumRunsOutputIteratorT d_num_runs_out, ///< [out] Pointer to total number of runs (i.e., length of \p d_offsets_out) + ScanTileStateT tile_status, ///< [in] Tile status interface + EqualityOpT equality_op, ///< [in] Equality operator for input items + OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + int num_tiles) ///< [in] Total number of tiles for the entire problem +{ + // Thread block type for selecting data from input tiles + typedef AgentRle< + AgentRlePolicyT, + InputIteratorT, + OffsetsOutputIteratorT, + LengthsOutputIteratorT, + EqualityOpT, + OffsetT> AgentRleT; + + // Shared memory for AgentRle + __shared__ typename AgentRleT::TempStorage temp_storage; + + // Process tiles + AgentRleT(temp_storage, d_in, d_offsets_out, d_lengths_out, equality_op, num_items).ConsumeRange( + num_tiles, + tile_status, + d_num_runs_out); +} + + + + +/****************************************************************************** + * Dispatch + ******************************************************************************/ + +/** + * Utility class for dispatching the appropriately-tuned kernels for DeviceRle + */ +template < + typename InputIteratorT, ///< Random-access input iterator type for reading input items \iterator + typename OffsetsOutputIteratorT, ///< Random-access output iterator type for writing run-offset values \iterator + typename LengthsOutputIteratorT, ///< Random-access output iterator type for writing run-length values \iterator + typename NumRunsOutputIteratorT, ///< Output iterator type for recording the number of runs encountered \iterator + typename EqualityOpT, ///< T equality operator type + typename OffsetT> ///< Signed integer type for global offsets +struct DeviceRleDispatch +{ + /****************************************************************************** + * Types and constants + ******************************************************************************/ + + // The input value type + typedef typename std::iterator_traits::value_type T; + + // The lengths output value type + typedef typename If<(Equals::value_type, void>::VALUE), // LengthT = (if output iterator's value type is void) ? + OffsetT, // ... then the OffsetT type, + typename std::iterator_traits::value_type>::Type LengthT; // ... 
else the output iterator's value type + + enum + { + INIT_KERNEL_THREADS = 128, + }; + + // Tile status descriptor interface type + typedef ReduceByKeyScanTileState ScanTileStateT; + + + /****************************************************************************** + * Tuning policies + ******************************************************************************/ + + /// SM35 + struct Policy350 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 15, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(T)))), + }; + + typedef AgentRlePolicy< + 96, + ITEMS_PER_THREAD, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + true, + BLOCK_SCAN_WARP_SCANS> + RleSweepPolicy; + }; + + /// SM30 + struct Policy300 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 5, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(T)))), + }; + + typedef AgentRlePolicy< + 256, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + true, + BLOCK_SCAN_RAKING_MEMOIZE> + RleSweepPolicy; + }; + + /// SM20 + struct Policy200 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 15, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(T)))), + }; + + typedef AgentRlePolicy< + 128, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + false, + BLOCK_SCAN_WARP_SCANS> + RleSweepPolicy; + }; + + /// SM13 + struct Policy130 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 9, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(T)))), + }; + + typedef AgentRlePolicy< + 64, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + true, + BLOCK_SCAN_RAKING_MEMOIZE> + RleSweepPolicy; + }; + + /// SM10 + struct Policy100 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 9, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(T)))), + }; + + typedef AgentRlePolicy< + 256, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + true, + BLOCK_SCAN_RAKING_MEMOIZE> + RleSweepPolicy; + }; + + + /****************************************************************************** + * Tuning policies of current PTX compiler pass + ******************************************************************************/ + +#if (CUB_PTX_ARCH >= 350) + typedef Policy350 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 300) + typedef Policy300 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 200) + typedef Policy200 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 130) + typedef Policy130 PtxPolicy; + +#else + typedef Policy100 PtxPolicy; + +#endif + + // "Opaque" policies (whose parameterizations aren't reflected in the type signature) + struct PtxRleSweepPolicy : PtxPolicy::RleSweepPolicy {}; + + + /****************************************************************************** + * Utilities + ******************************************************************************/ + + /** + * Initialize kernel dispatch configurations with the policies corresponding to the PTX assembly we will use + */ + template + CUB_RUNTIME_FUNCTION __forceinline__ + static void InitConfigs( + int ptx_version, + KernelConfig& device_rle_config) + { + #if (CUB_PTX_ARCH > 0) + + // We're on the device, so initialize the kernel dispatch configurations with the current PTX policy + device_rle_config.template Init(); + + #else + + // We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that 
match the device's PTX version + if (ptx_version >= 350) + { + device_rle_config.template Init(); + } + else if (ptx_version >= 300) + { + device_rle_config.template Init(); + } + else if (ptx_version >= 200) + { + device_rle_config.template Init(); + } + else if (ptx_version >= 130) + { + device_rle_config.template Init(); + } + else + { + device_rle_config.template Init(); + } + + #endif + } + + + /** + * Kernel kernel dispatch configuration. Mirrors the constants within AgentRlePolicyT. + */ + struct KernelConfig + { + int block_threads; + int items_per_thread; + BlockLoadAlgorithm load_policy; + bool store_warp_time_slicing; + BlockScanAlgorithm scan_algorithm; + + template + CUB_RUNTIME_FUNCTION __forceinline__ + void Init() + { + block_threads = AgentRlePolicyT::BLOCK_THREADS; + items_per_thread = AgentRlePolicyT::ITEMS_PER_THREAD; + load_policy = AgentRlePolicyT::LOAD_ALGORITHM; + store_warp_time_slicing = AgentRlePolicyT::STORE_WARP_TIME_SLICING; + scan_algorithm = AgentRlePolicyT::SCAN_ALGORITHM; + } + + CUB_RUNTIME_FUNCTION __forceinline__ + void Print() + { + printf("%d, %d, %d, %d, %d", + block_threads, + items_per_thread, + load_policy, + store_warp_time_slicing, + scan_algorithm); + } + }; + + + /****************************************************************************** + * Dispatch entrypoints + ******************************************************************************/ + + /** + * Internal dispatch routine for computing a device-wide run-length-encode using the + * specified kernel functions. + */ + template < + typename DeviceScanInitKernelPtr, ///< Function type of cub::DeviceScanInitKernel + typename DeviceRleSweepKernelPtr> ///< Function type of cub::DeviceRleSweepKernelPtr + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OffsetsOutputIteratorT d_offsets_out, ///< [out] Pointer to the output sequence of run-offsets + LengthsOutputIteratorT d_lengths_out, ///< [out] Pointer to the output sequence of run-lengths + NumRunsOutputIteratorT d_num_runs_out, ///< [out] Pointer to the total number of runs encountered (i.e., length of \p d_offsets_out) + EqualityOpT equality_op, ///< [in] Equality operator for input items + OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous, ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
+ int ptx_version, ///< [in] PTX version of dispatch kernels + DeviceScanInitKernelPtr device_scan_init_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceScanInitKernel + DeviceRleSweepKernelPtr device_rle_sweep_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceRleSweepKernel + KernelConfig device_rle_config) ///< [in] Dispatch parameters that match the policy that \p device_rle_sweep_kernel was compiled for + { + +#ifndef CUB_RUNTIME_ENABLED + + // Kernel launch not supported from this device + return CubDebug(cudaErrorNotSupported); + +#else + + cudaError error = cudaSuccess; + do + { + // Get device ordinal + int device_ordinal; + if (CubDebug(error = cudaGetDevice(&device_ordinal))) break; + + // Get SM count + int sm_count; + if (CubDebug(error = cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal))) break; + + // Number of input tiles + int tile_size = device_rle_config.block_threads * device_rle_config.items_per_thread; + int num_tiles = (num_items + tile_size - 1) / tile_size; + + // Specify temporary storage allocation requirements + size_t allocation_sizes[1]; + if (CubDebug(error = ScanTileStateT::AllocationSize(num_tiles, allocation_sizes[0]))) break; // bytes needed for tile status descriptors + + // Compute allocation pointers into the single storage blob (or compute the necessary size of the blob) + void* allocations[1]; + if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; + if (d_temp_storage == NULL) + { + // Return if the caller is simply requesting the size of the storage allocation + break; + } + + // Construct the tile status interface + ScanTileStateT tile_status; + if (CubDebug(error = tile_status.Init(num_tiles, allocations[0], allocation_sizes[0]))) break; + + // Log device_scan_init_kernel configuration + int init_grid_size = CUB_MAX(1, (num_tiles + INIT_KERNEL_THREADS - 1) / INIT_KERNEL_THREADS); + if (debug_synchronous) _CubLog("Invoking device_scan_init_kernel<<<%d, %d, 0, %lld>>>()\n", init_grid_size, INIT_KERNEL_THREADS, (long long) stream); + + // Invoke device_scan_init_kernel to initialize tile descriptors and queue descriptors + device_scan_init_kernel<<>>( + tile_status, + num_tiles, + d_num_runs_out); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + // Return if empty problem + if (num_items == 0) + break; + + // Get SM occupancy for device_rle_sweep_kernel + int device_rle_kernel_sm_occupancy; + if (CubDebug(error = MaxSmOccupancy( + device_rle_kernel_sm_occupancy, // out + device_rle_sweep_kernel, + device_rle_config.block_threads))) break; + + // Get max x-dimension of grid + int max_dim_x; + if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, cudaDevAttrMaxGridDimX, device_ordinal))) break;; + + // Get grid size for scanning tiles + dim3 scan_grid_size; + scan_grid_size.z = 1; + scan_grid_size.y = ((unsigned int) num_tiles + max_dim_x - 1) / max_dim_x; + scan_grid_size.x = CUB_MIN(num_tiles, max_dim_x); + + // Log device_rle_sweep_kernel configuration + if (debug_synchronous) _CubLog("Invoking device_rle_sweep_kernel<<<{%d,%d,%d}, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy\n", + scan_grid_size.x, scan_grid_size.y, scan_grid_size.z, device_rle_config.block_threads, (long long) stream, 
device_rle_config.items_per_thread, device_rle_kernel_sm_occupancy); + + // Invoke device_rle_sweep_kernel + device_rle_sweep_kernel<<>>( + d_in, + d_offsets_out, + d_lengths_out, + d_num_runs_out, + tile_status, + equality_op, + num_items, + num_tiles); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + } + while (0); + + return error; + +#endif // CUB_RUNTIME_ENABLED + } + + + /** + * Internal dispatch routine + */ + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to input sequence of data items + OffsetsOutputIteratorT d_offsets_out, ///< [out] Pointer to output sequence of run-offsets + LengthsOutputIteratorT d_lengths_out, ///< [out] Pointer to output sequence of run-lengths + NumRunsOutputIteratorT d_num_runs_out, ///< [out] Pointer to total number of runs (i.e., length of \p d_offsets_out) + EqualityOpT equality_op, ///< [in] Equality operator for input items + OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + cudaStream_t stream, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version; + #if (CUB_PTX_ARCH == 0) + if (CubDebug(error = PtxVersion(ptx_version))) break; + #else + ptx_version = CUB_PTX_ARCH; + #endif + + // Get kernel kernel dispatch configurations + KernelConfig device_rle_config; + InitConfigs(ptx_version, device_rle_config); + + // Dispatch + if (CubDebug(error = Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_offsets_out, + d_lengths_out, + d_num_runs_out, + equality_op, + num_items, + stream, + debug_synchronous, + ptx_version, + DeviceCompactInitKernel, + DeviceRleSweepKernel, + device_rle_config))) break; + } + while (0); + + return error; + } +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/dispatch/dispatch_scan.cuh b/fastertransformer/cuda/cub/device/dispatch/dispatch_scan.cuh new file mode 100644 index 000000000..3ef720a44 --- /dev/null +++ b/fastertransformer/cuda/cub/device/dispatch/dispatch_scan.cuh @@ -0,0 +1,563 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceScan provides device-wide, parallel operations for computing a prefix scan across a sequence of data items residing within device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "../../agent/agent_scan.cuh" +#include "../../thread/thread_operators.cuh" +#include "../../grid/grid_queue.cuh" +#include "../../util_arch.cuh" +#include "../../util_debug.cuh" +#include "../../util_device.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Kernel entry points + *****************************************************************************/ + +/** + * Initialization kernel for tile status initialization (multi-block) + */ +template < + typename ScanTileStateT> ///< Tile status interface type +__global__ void DeviceScanInitKernel( + ScanTileStateT tile_state, ///< [in] Tile status interface + int num_tiles) ///< [in] Number of tiles +{ + // Initialize tile status + tile_state.InitializeStatus(num_tiles); +} + +/** + * Initialization kernel for tile status initialization (multi-block) + */ +template < + typename ScanTileStateT, ///< Tile status interface type + typename NumSelectedIteratorT> ///< Output iterator type for recording the number of items selected +__global__ void DeviceCompactInitKernel( + ScanTileStateT tile_state, ///< [in] Tile status interface + int num_tiles, ///< [in] Number of tiles + NumSelectedIteratorT d_num_selected_out) ///< [out] Pointer to the total number of items selected (i.e., length of \p d_selected_out) +{ + // Initialize tile status + tile_state.InitializeStatus(num_tiles); + + // Initialize d_num_selected_out + if ((blockIdx.x == 0) && (threadIdx.x == 0)) + *d_num_selected_out = 0; +} + + +/** + * Scan kernel entry point (multi-block) + */ +template < + typename ScanPolicyT, ///< Parameterized ScanPolicyT tuning policy type + typename InputIteratorT, ///< Random-access input iterator type for reading scan inputs \iterator + typename OutputIteratorT, ///< Random-access output iterator type for writing scan outputs \iterator + typename ScanTileStateT, ///< Tile status interface type 
+ typename ScanOpT, ///< Binary scan functor type having member T operator()(const T &a, const T &b) + typename InitValueT, ///< Initial value to seed the exclusive scan (cub::NullType for inclusive scans) + typename OffsetT> ///< Signed integer type for global offsets +__launch_bounds__ (int(ScanPolicyT::BLOCK_THREADS)) +__global__ void DeviceScanKernel( + InputIteratorT d_in, ///< Input data + OutputIteratorT d_out, ///< Output data + ScanTileStateT tile_state, ///< Tile status interface + int start_tile, ///< The starting tile for the current grid + ScanOpT scan_op, ///< Binary scan functor + InitValueT init_value, ///< Initial value to seed the exclusive scan + OffsetT num_items) ///< Total number of scan items for the entire problem +{ + // Thread block type for scanning input tiles + typedef AgentScan< + ScanPolicyT, + InputIteratorT, + OutputIteratorT, + ScanOpT, + InitValueT, + OffsetT> AgentScanT; + + // Shared memory for AgentScan + __shared__ typename AgentScanT::TempStorage temp_storage; + + // Process tiles + AgentScanT(temp_storage, d_in, d_out, scan_op, init_value).ConsumeRange( + num_items, + tile_state, + start_tile); +} + + + + +/****************************************************************************** + * Dispatch + ******************************************************************************/ + + +/** + * Utility class for dispatching the appropriately-tuned kernels for DeviceScan + */ +template < + typename InputIteratorT, ///< Random-access input iterator type for reading scan inputs \iterator + typename OutputIteratorT, ///< Random-access output iterator type for writing scan outputs \iterator + typename ScanOpT, ///< Binary scan functor type having member T operator()(const T &a, const T &b) + typename InitValueT, ///< The init_value element type for ScanOpT (cub::NullType for inclusive scans) + typename OffsetT> ///< Signed integer type for global offsets +struct DispatchScan +{ + //--------------------------------------------------------------------- + // Constants and Types + //--------------------------------------------------------------------- + + enum + { + INIT_KERNEL_THREADS = 128 + }; + + // The output value type + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type OutputT; // ... 
else the output iterator's value type + + // Tile status descriptor interface type + typedef ScanTileState ScanTileStateT; + + + //--------------------------------------------------------------------- + // Tuning policies + //--------------------------------------------------------------------- + + /// SM600 + struct Policy600 + { + typedef AgentScanPolicy< + CUB_SCALED_GRANULARITIES(128, 15, OutputT), ///< Threads per block, items per thread + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_TRANSPOSE, + BLOCK_SCAN_WARP_SCANS> + ScanPolicyT; + }; + + + /// SM520 + struct Policy520 + { + // Titan X: 32.47B items/s @ 48M 32-bit T + typedef AgentScanPolicy< + CUB_SCALED_GRANULARITIES(128, 12, OutputT), ///< Threads per block, items per thread + BLOCK_LOAD_DIRECT, + LOAD_LDG, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_WARP_SCANS> + ScanPolicyT; + }; + + + /// SM35 + struct Policy350 + { + // GTX Titan: 29.5B items/s (232.4 GB/s) @ 48M 32-bit T + typedef AgentScanPolicy< + CUB_SCALED_GRANULARITIES(128, 12, OutputT), ///< Threads per block, items per thread + BLOCK_LOAD_DIRECT, + LOAD_LDG, + BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED, + BLOCK_SCAN_RAKING> + ScanPolicyT; + }; + + /// SM30 + struct Policy300 + { + typedef AgentScanPolicy< + CUB_SCALED_GRANULARITIES(256, 9, OutputT), ///< Threads per block, items per thread + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_WARP_SCANS> + ScanPolicyT; + }; + + /// SM20 + struct Policy200 + { + // GTX 580: 20.3B items/s (162.3 GB/s) @ 48M 32-bit T + typedef AgentScanPolicy< + CUB_SCALED_GRANULARITIES(128, 12, OutputT), ///< Threads per block, items per thread + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_WARP_SCANS> + ScanPolicyT; + }; + + /// SM13 + struct Policy130 + { + typedef AgentScanPolicy< + CUB_SCALED_GRANULARITIES(96, 21, OutputT), ///< Threads per block, items per thread + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_RAKING_MEMOIZE> + ScanPolicyT; + }; + + /// SM10 + struct Policy100 + { + typedef AgentScanPolicy< + CUB_SCALED_GRANULARITIES(64, 9, OutputT), ///< Threads per block, items per thread + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_WARP_SCANS> + ScanPolicyT; + }; + + + //--------------------------------------------------------------------- + // Tuning policies of current PTX compiler pass + //--------------------------------------------------------------------- + +#if (CUB_PTX_ARCH >= 600) + typedef Policy600 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 520) + typedef Policy520 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 350) + typedef Policy350 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 300) + typedef Policy300 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 200) + typedef Policy200 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 130) + typedef Policy130 PtxPolicy; + +#else + typedef Policy100 PtxPolicy; + +#endif + + // "Opaque" policies (whose parameterizations aren't reflected in the type signature) + struct PtxAgentScanPolicy : PtxPolicy::ScanPolicyT {}; + + + //--------------------------------------------------------------------- + // Utilities + //--------------------------------------------------------------------- + + /** + * Initialize kernel dispatch configurations with the policies corresponding to the PTX assembly we will use + */ + template + CUB_RUNTIME_FUNCTION __forceinline__ + static void InitConfigs( + int ptx_version, + KernelConfig &scan_kernel_config) + { + #if (CUB_PTX_ARCH > 0) + 
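+    /*
+     * [Editor's note] Illustrative sketch only; not part of the CUB sources. InitConfigs picks the
+     * tuning policy from the current PTX pass on the device, or from the device's PTX version on
+     * the host. Callers normally never touch this dispatch layer directly; they go through the
+     * public cub::DeviceScan wrappers using the two-phase temporary-storage pattern documented on
+     * Dispatch() below (NULL size query, allocate, then launch). Assuming device pointers d_in,
+     * d_out and an item count num_items supplied by the caller:
+     *
+     *   size_t temp_storage_bytes = 0;
+     *   void*  d_temp_storage     = NULL;
+     *   cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);  // size query
+     *   cudaMalloc(&d_temp_storage, temp_storage_bytes);
+     *   cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);  // launch
+     *   cudaFree(d_temp_storage);
+     */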
(void)ptx_version; + + // We're on the device, so initialize the kernel dispatch configurations with the current PTX policy + scan_kernel_config.template Init(); + + #else + + // We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version + if (ptx_version >= 600) + { + scan_kernel_config.template Init(); + } + else if (ptx_version >= 520) + { + scan_kernel_config.template Init(); + } + else if (ptx_version >= 350) + { + scan_kernel_config.template Init(); + } + else if (ptx_version >= 300) + { + scan_kernel_config.template Init(); + } + else if (ptx_version >= 200) + { + scan_kernel_config.template Init(); + } + else if (ptx_version >= 130) + { + scan_kernel_config.template Init(); + } + else + { + scan_kernel_config.template Init(); + } + + #endif + } + + + /** + * Kernel kernel dispatch configuration. + */ + struct KernelConfig + { + int block_threads; + int items_per_thread; + int tile_items; + + template + CUB_RUNTIME_FUNCTION __forceinline__ + void Init() + { + block_threads = PolicyT::BLOCK_THREADS; + items_per_thread = PolicyT::ITEMS_PER_THREAD; + tile_items = block_threads * items_per_thread; + } + }; + + + //--------------------------------------------------------------------- + // Dispatch entrypoints + //--------------------------------------------------------------------- + + /** + * Internal dispatch routine for computing a device-wide prefix scan using the + * specified kernel functions. + */ + template < + typename ScanInitKernelPtrT, ///< Function type of cub::DeviceScanInitKernel + typename ScanSweepKernelPtrT> ///< Function type of cub::DeviceScanKernelPtrT + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of data items + ScanOpT scan_op, ///< [in] Binary scan functor + InitValueT init_value, ///< [in] Initial value to seed the exclusive scan + OffsetT num_items, ///< [in] Total number of input items (i.e., the length of \p d_in) + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous, ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
+ int /*ptx_version*/, ///< [in] PTX version of dispatch kernels + ScanInitKernelPtrT init_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceScanInitKernel + ScanSweepKernelPtrT scan_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceScanKernel + KernelConfig scan_kernel_config) ///< [in] Dispatch parameters that match the policy that \p scan_kernel was compiled for + { + +#ifndef CUB_RUNTIME_ENABLED + (void)d_temp_storage; + (void)temp_storage_bytes; + (void)d_in; + (void)d_out; + (void)scan_op; + (void)init_value; + (void)num_items; + (void)stream; + (void)debug_synchronous; + (void)init_kernel; + (void)scan_kernel; + (void)scan_kernel_config; + + // Kernel launch not supported from this device + return CubDebug(cudaErrorNotSupported); + +#else + cudaError error = cudaSuccess; + do + { + // Get device ordinal + int device_ordinal; + if (CubDebug(error = cudaGetDevice(&device_ordinal))) break; + + // Get SM count + int sm_count; + if (CubDebug(error = cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal))) break; + + // Number of input tiles + int tile_size = scan_kernel_config.block_threads * scan_kernel_config.items_per_thread; + int num_tiles = (num_items + tile_size - 1) / tile_size; + + // Specify temporary storage allocation requirements + size_t allocation_sizes[1]; + if (CubDebug(error = ScanTileStateT::AllocationSize(num_tiles, allocation_sizes[0]))) break; // bytes needed for tile status descriptors + + // Compute allocation pointers into the single storage blob (or compute the necessary size of the blob) + void* allocations[1]; + if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; + if (d_temp_storage == NULL) + { + // Return if the caller is simply requesting the size of the storage allocation + break; + } + + // Return if empty problem + if (num_items == 0) + break; + + // Construct the tile status interface + ScanTileStateT tile_state; + if (CubDebug(error = tile_state.Init(num_tiles, allocations[0], allocation_sizes[0]))) break; + + // Log init_kernel configuration + int init_grid_size = (num_tiles + INIT_KERNEL_THREADS - 1) / INIT_KERNEL_THREADS; + if (debug_synchronous) _CubLog("Invoking init_kernel<<<%d, %d, 0, %lld>>>()\n", init_grid_size, INIT_KERNEL_THREADS, (long long) stream); + + // Invoke init_kernel to initialize tile descriptors + init_kernel<<>>( + tile_state, + num_tiles); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + // Get SM occupancy for scan_kernel + int scan_sm_occupancy; + if (CubDebug(error = MaxSmOccupancy( + scan_sm_occupancy, // out + scan_kernel, + scan_kernel_config.block_threads))) break; + + // Get max x-dimension of grid + int max_dim_x; + if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, cudaDevAttrMaxGridDimX, device_ordinal))) break;; + + // Run grids in epochs (in case number of tiles exceeds max x-dimension + int scan_grid_size = CUB_MIN(num_tiles, max_dim_x); + for (int start_tile = 0; start_tile < num_tiles; start_tile += scan_grid_size) + { + // Log scan_kernel configuration + if (debug_synchronous) _CubLog("Invoking %d scan_kernel<<<%d, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy\n", + start_tile, scan_grid_size, scan_kernel_config.block_threads, (long long) stream, scan_kernel_config.items_per_thread, 
scan_sm_occupancy); + + // Invoke scan_kernel + scan_kernel<<>>( + d_in, + d_out, + tile_state, + start_tile, + scan_op, + init_value, + num_items); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + } + } + while (0); + + return error; + +#endif // CUB_RUNTIME_ENABLED + } + + + /** + * Internal dispatch routine + */ + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of data items + ScanOpT scan_op, ///< [in] Binary scan functor + InitValueT init_value, ///< [in] Initial value to seed the exclusive scan + OffsetT num_items, ///< [in] Total number of input items (i.e., the length of \p d_in) + cudaStream_t stream, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version; + if (CubDebug(error = PtxVersion(ptx_version))) break; + + // Get kernel kernel dispatch configurations + KernelConfig scan_kernel_config; + InitConfigs(ptx_version, scan_kernel_config); + + // Dispatch + if (CubDebug(error = Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + scan_op, + init_value, + num_items, + stream, + debug_synchronous, + ptx_version, + DeviceScanInitKernel, + DeviceScanKernel, + scan_kernel_config))) break; + } + while (0); + + return error; + } +}; + + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/dispatch/dispatch_select_if.cuh b/fastertransformer/cuda/cub/device/dispatch/dispatch_select_if.cuh new file mode 100644 index 000000000..60b331338 --- /dev/null +++ b/fastertransformer/cuda/cub/device/dispatch/dispatch_select_if.cuh @@ -0,0 +1,542 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceSelect provides device-wide, parallel operations for selecting items from sequences of data items residing within device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "dispatch_scan.cuh" +#include "../../agent/agent_select_if.cuh" +#include "../../thread/thread_operators.cuh" +#include "../../grid/grid_queue.cuh" +#include "../../util_device.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/****************************************************************************** + * Kernel entry points + *****************************************************************************/ + +/** + * Select kernel entry point (multi-block) + * + * Performs functor-based selection if SelectOpT functor type != NullType + * Otherwise performs flag-based selection if FlagsInputIterator's value type != NullType + * Otherwise performs discontinuity selection (keep unique) + */ +template < + typename AgentSelectIfPolicyT, ///< Parameterized AgentSelectIfPolicyT tuning policy type + typename InputIteratorT, ///< Random-access input iterator type for reading input items + typename FlagsInputIteratorT, ///< Random-access input iterator type for reading selection flags (NullType* if a selection functor or discontinuity flagging is to be used for selection) + typename SelectedOutputIteratorT, ///< Random-access output iterator type for writing selected items + typename NumSelectedIteratorT, ///< Output iterator type for recording the number of items selected + typename ScanTileStateT, ///< Tile status interface type + typename SelectOpT, ///< Selection operator type (NullType if selection flags or discontinuity flagging is to be used for selection) + typename EqualityOpT, ///< Equality operator type (NullType if selection functor or selection flags is to be used for selection) + typename OffsetT, ///< Signed integer type for global offsets + bool KEEP_REJECTS> ///< Whether or not we push rejected items to the back of the output +__launch_bounds__ (int(AgentSelectIfPolicyT::BLOCK_THREADS)) +__global__ void DeviceSelectSweepKernel( + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + FlagsInputIteratorT d_flags, ///< [in] Pointer to the input sequence of selection flags (if applicable) + SelectedOutputIteratorT d_selected_out, ///< [out] Pointer to the output sequence of selected data items + NumSelectedIteratorT d_num_selected_out, ///< [out] Pointer to the total number of items selected (i.e., length of \p d_selected_out) + ScanTileStateT tile_status, ///< [in] Tile status interface + SelectOpT 
select_op, ///< [in] Selection operator + EqualityOpT equality_op, ///< [in] Equality operator + OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + int num_tiles) ///< [in] Total number of tiles for the entire problem +{ + // Thread block type for selecting data from input tiles + typedef AgentSelectIf< + AgentSelectIfPolicyT, + InputIteratorT, + FlagsInputIteratorT, + SelectedOutputIteratorT, + SelectOpT, + EqualityOpT, + OffsetT, + KEEP_REJECTS> AgentSelectIfT; + + // Shared memory for AgentSelectIf + __shared__ typename AgentSelectIfT::TempStorage temp_storage; + + // Process tiles + AgentSelectIfT(temp_storage, d_in, d_flags, d_selected_out, select_op, equality_op, num_items).ConsumeRange( + num_tiles, + tile_status, + d_num_selected_out); +} + + + + +/****************************************************************************** + * Dispatch + ******************************************************************************/ + +/** + * Utility class for dispatching the appropriately-tuned kernels for DeviceSelect + */ +template < + typename InputIteratorT, ///< Random-access input iterator type for reading input items + typename FlagsInputIteratorT, ///< Random-access input iterator type for reading selection flags (NullType* if a selection functor or discontinuity flagging is to be used for selection) + typename SelectedOutputIteratorT, ///< Random-access output iterator type for writing selected items + typename NumSelectedIteratorT, ///< Output iterator type for recording the number of items selected + typename SelectOpT, ///< Selection operator type (NullType if selection flags or discontinuity flagging is to be used for selection) + typename EqualityOpT, ///< Equality operator type (NullType if selection functor or selection flags is to be used for selection) + typename OffsetT, ///< Signed integer type for global offsets + bool KEEP_REJECTS> ///< Whether or not we push rejected items to the back of the output +struct DispatchSelectIf +{ + /****************************************************************************** + * Types and constants + ******************************************************************************/ + + // The output value type + typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? + typename std::iterator_traits::value_type, // ... then the input iterator's value type, + typename std::iterator_traits::value_type>::Type OutputT; // ... 
else the output iterator's value type + + // The flag value type + typedef typename std::iterator_traits::value_type FlagT; + + enum + { + INIT_KERNEL_THREADS = 128, + }; + + // Tile status descriptor interface type + typedef ScanTileState ScanTileStateT; + + + /****************************************************************************** + * Tuning policies + ******************************************************************************/ + + /// SM35 + struct Policy350 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 10, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(OutputT)))), + }; + + typedef AgentSelectIfPolicy< + 128, + ITEMS_PER_THREAD, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + BLOCK_SCAN_WARP_SCANS> + SelectIfPolicyT; + }; + + /// SM30 + struct Policy300 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 7, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(3, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(OutputT)))), + }; + + typedef AgentSelectIfPolicy< + 128, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_SCAN_WARP_SCANS> + SelectIfPolicyT; + }; + + /// SM20 + struct Policy200 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = (KEEP_REJECTS) ? 7 : 15, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(OutputT)))), + }; + + typedef AgentSelectIfPolicy< + 128, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_SCAN_WARP_SCANS> + SelectIfPolicyT; + }; + + /// SM13 + struct Policy130 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 9, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(OutputT)))), + }; + + typedef AgentSelectIfPolicy< + 64, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_SCAN_RAKING_MEMOIZE> + SelectIfPolicyT; + }; + + /// SM10 + struct Policy100 + { + enum { + NOMINAL_4B_ITEMS_PER_THREAD = 9, + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(OutputT)))), + }; + + typedef AgentSelectIfPolicy< + 64, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_SCAN_RAKING> + SelectIfPolicyT; + }; + + + /****************************************************************************** + * Tuning policies of current PTX compiler pass + ******************************************************************************/ + +#if (CUB_PTX_ARCH >= 350) + typedef Policy350 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 300) + typedef Policy300 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 200) + typedef Policy200 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 130) + typedef Policy130 PtxPolicy; + +#else + typedef Policy100 PtxPolicy; + +#endif + + // "Opaque" policies (whose parameterizations aren't reflected in the type signature) + struct PtxSelectIfPolicyT : PtxPolicy::SelectIfPolicyT {}; + + + /****************************************************************************** + * Utilities + ******************************************************************************/ + + /** + * Initialize kernel dispatch configurations with the policies corresponding to the PTX assembly we will use + */ + template + CUB_RUNTIME_FUNCTION __forceinline__ + static void InitConfigs( + int ptx_version, + KernelConfig &select_if_config) + { + #if (CUB_PTX_ARCH > 0) + (void)ptx_version; + + // We're on the device, so initialize the kernel dispatch configurations with the current PTX policy + select_if_config.template 
Init(); + + #else + + // We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version + if (ptx_version >= 350) + { + select_if_config.template Init(); + } + else if (ptx_version >= 300) + { + select_if_config.template Init(); + } + else if (ptx_version >= 200) + { + select_if_config.template Init(); + } + else if (ptx_version >= 130) + { + select_if_config.template Init(); + } + else + { + select_if_config.template Init(); + } + + #endif + } + + + /** + * Kernel kernel dispatch configuration. + */ + struct KernelConfig + { + int block_threads; + int items_per_thread; + int tile_items; + + template + CUB_RUNTIME_FUNCTION __forceinline__ + void Init() + { + block_threads = PolicyT::BLOCK_THREADS; + items_per_thread = PolicyT::ITEMS_PER_THREAD; + tile_items = block_threads * items_per_thread; + } + }; + + + /****************************************************************************** + * Dispatch entrypoints + ******************************************************************************/ + + /** + * Internal dispatch routine for computing a device-wide selection using the + * specified kernel functions. + */ + template < + typename ScanInitKernelPtrT, ///< Function type of cub::DeviceScanInitKernel + typename SelectIfKernelPtrT> ///< Function type of cub::SelectIfKernelPtrT + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + FlagsInputIteratorT d_flags, ///< [in] Pointer to the input sequence of selection flags (if applicable) + SelectedOutputIteratorT d_selected_out, ///< [in] Pointer to the output sequence of selected data items + NumSelectedIteratorT d_num_selected_out, ///< [in] Pointer to the total number of items selected (i.e., length of \p d_selected_out) + SelectOpT select_op, ///< [in] Selection operator + EqualityOpT equality_op, ///< [in] Equality operator + OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous, ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
+ int /*ptx_version*/, ///< [in] PTX version of dispatch kernels + ScanInitKernelPtrT scan_init_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceScanInitKernel + SelectIfKernelPtrT select_if_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceSelectSweepKernel + KernelConfig select_if_config) ///< [in] Dispatch parameters that match the policy that \p select_if_kernel was compiled for + { + +#ifndef CUB_RUNTIME_ENABLED + (void)d_temp_storage; + (void)temp_storage_bytes; + (void)d_in; + (void)d_flags; + (void)d_selected_out; + (void)d_num_selected_out; + (void)select_op; + (void)equality_op; + (void)num_items; + (void)stream; + (void)debug_synchronous; + (void)scan_init_kernel; + (void)select_if_kernel; + (void)select_if_config; + + // Kernel launch not supported from this device + return CubDebug(cudaErrorNotSupported); + +#else + + cudaError error = cudaSuccess; + do + { + // Get device ordinal + int device_ordinal; + if (CubDebug(error = cudaGetDevice(&device_ordinal))) break; + + // Get SM count + int sm_count; + if (CubDebug(error = cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal))) break; + + // Number of input tiles + int tile_size = select_if_config.block_threads * select_if_config.items_per_thread; + int num_tiles = (num_items + tile_size - 1) / tile_size; + + // Specify temporary storage allocation requirements + size_t allocation_sizes[1]; + if (CubDebug(error = ScanTileStateT::AllocationSize(num_tiles, allocation_sizes[0]))) break; // bytes needed for tile status descriptors + + // Compute allocation pointers into the single storage blob (or compute the necessary size of the blob) + void* allocations[1]; + if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; + if (d_temp_storage == NULL) + { + // Return if the caller is simply requesting the size of the storage allocation + break; + } + + // Construct the tile status interface + ScanTileStateT tile_status; + if (CubDebug(error = tile_status.Init(num_tiles, allocations[0], allocation_sizes[0]))) break; + + // Log scan_init_kernel configuration + int init_grid_size = CUB_MAX(1, (num_tiles + INIT_KERNEL_THREADS - 1) / INIT_KERNEL_THREADS); + if (debug_synchronous) _CubLog("Invoking scan_init_kernel<<<%d, %d, 0, %lld>>>()\n", init_grid_size, INIT_KERNEL_THREADS, (long long) stream); + + // Invoke scan_init_kernel to initialize tile descriptors + scan_init_kernel<<>>( + tile_status, + num_tiles, + d_num_selected_out); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + // Return if empty problem + if (num_items == 0) + break; + + // Get SM occupancy for select_if_kernel + int range_select_sm_occupancy; + if (CubDebug(error = MaxSmOccupancy( + range_select_sm_occupancy, // out + select_if_kernel, + select_if_config.block_threads))) break; + + // Get max x-dimension of grid + int max_dim_x; + if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, cudaDevAttrMaxGridDimX, device_ordinal))) break;; + + // Get grid size for scanning tiles + dim3 scan_grid_size; + scan_grid_size.z = 1; + scan_grid_size.y = ((unsigned int) num_tiles + max_dim_x - 1) / max_dim_x; + scan_grid_size.x = CUB_MIN(num_tiles, max_dim_x); + + // Log select_if_kernel configuration + if (debug_synchronous) _CubLog("Invoking 
select_if_kernel<<<{%d,%d,%d}, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy\n", + scan_grid_size.x, scan_grid_size.y, scan_grid_size.z, select_if_config.block_threads, (long long) stream, select_if_config.items_per_thread, range_select_sm_occupancy); + + // Invoke select_if_kernel + select_if_kernel<<>>( + d_in, + d_flags, + d_selected_out, + d_num_selected_out, + tile_status, + select_op, + equality_op, + num_items, + num_tiles); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + } + while (0); + + return error; + +#endif // CUB_RUNTIME_ENABLED + } + + + /** + * Internal dispatch routine + */ + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + FlagsInputIteratorT d_flags, ///< [in] Pointer to the input sequence of selection flags (if applicable) + SelectedOutputIteratorT d_selected_out, ///< [in] Pointer to the output sequence of selected data items + NumSelectedIteratorT d_num_selected_out, ///< [in] Pointer to the total number of items selected (i.e., length of \p d_selected_out) + SelectOpT select_op, ///< [in] Selection operator + EqualityOpT equality_op, ///< [in] Equality operator + OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_in) + cudaStream_t stream, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version; + #if (CUB_PTX_ARCH == 0) + if (CubDebug(error = PtxVersion(ptx_version))) break; + #else + ptx_version = CUB_PTX_ARCH; + #endif + + // Get kernel kernel dispatch configurations + KernelConfig select_if_config; + InitConfigs(ptx_version, select_if_config); + + // Dispatch + if (CubDebug(error = Dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_flags, + d_selected_out, + d_num_selected_out, + select_op, + equality_op, + num_items, + stream, + debug_synchronous, + ptx_version, + DeviceCompactInitKernel, + DeviceSelectSweepKernel, + select_if_config))) break; + } + while (0); + + return error; + } +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/device/dispatch/dispatch_spmv_orig.cuh b/fastertransformer/cuda/cub/device/dispatch/dispatch_spmv_orig.cuh new file mode 100644 index 000000000..ab9c5346d --- /dev/null +++ b/fastertransformer/cuda/cub/device/dispatch/dispatch_spmv_orig.cuh @@ -0,0 +1,834 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceSpmv provides device-wide parallel operations for performing sparse-matrix * vector multiplication (SpMV). + */ + +#pragma once + +#include +#include + +#include "../../agent/single_pass_scan_operators.cuh" +#include "../../agent/agent_segment_fixup.cuh" +#include "../../agent/agent_spmv_orig.cuh" +#include "../../util_type.cuh" +#include "../../util_debug.cuh" +#include "../../util_device.cuh" +#include "../../thread/thread_search.cuh" +#include "../../grid/grid_queue.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * SpMV kernel entry points + *****************************************************************************/ + +/** + * Spmv search kernel. Identifies merge path starting coordinates for each tile. + */ +template < + typename AgentSpmvPolicyT, ///< Parameterized SpmvPolicy tuning policy type + typename ValueT, ///< Matrix and vector value type + typename OffsetT> ///< Signed integer type for sequence offsets +__global__ void DeviceSpmv1ColKernel( + SpmvParams spmv_params) ///< [in] SpMV input parameter bundle +{ + typedef CacheModifiedInputIterator< + AgentSpmvPolicyT::VECTOR_VALUES_LOAD_MODIFIER, + ValueT, + OffsetT> + VectorValueIteratorT; + + VectorValueIteratorT wrapped_vector_x(spmv_params.d_vector_x); + + int row_idx = (blockIdx.x * blockDim.x) + threadIdx.x; + if (row_idx < spmv_params.num_rows) + { + OffsetT end_nonzero_idx = spmv_params.d_row_end_offsets[row_idx]; + OffsetT nonzero_idx = spmv_params.d_row_end_offsets[row_idx - 1]; + + ValueT value = 0.0; + if (end_nonzero_idx != nonzero_idx) + { + value = spmv_params.d_values[nonzero_idx] * wrapped_vector_x[spmv_params.d_column_indices[nonzero_idx]]; + } + + spmv_params.d_vector_y[row_idx] = value; + } +} + + +/** + * Spmv search kernel. 
Identifies merge path starting coordinates for each tile. + */ +template < + typename SpmvPolicyT, ///< Parameterized SpmvPolicy tuning policy type + typename OffsetT, ///< Signed integer type for sequence offsets + typename CoordinateT, ///< Merge path coordinate type + typename SpmvParamsT> ///< SpmvParams type +__global__ void DeviceSpmvSearchKernel( + int num_merge_tiles, ///< [in] Number of SpMV merge tiles (spmv grid size) + CoordinateT* d_tile_coordinates, ///< [out] Pointer to the temporary array of tile starting coordinates + SpmvParamsT spmv_params) ///< [in] SpMV input parameter bundle +{ + /// Constants + enum + { + BLOCK_THREADS = SpmvPolicyT::BLOCK_THREADS, + ITEMS_PER_THREAD = SpmvPolicyT::ITEMS_PER_THREAD, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + }; + + typedef CacheModifiedInputIterator< + SpmvPolicyT::ROW_OFFSETS_SEARCH_LOAD_MODIFIER, + OffsetT, + OffsetT> + RowOffsetsSearchIteratorT; + + // Find the starting coordinate for all tiles (plus the end coordinate of the last one) + int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x; + if (tile_idx < num_merge_tiles + 1) + { + OffsetT diagonal = (tile_idx * TILE_ITEMS); + CoordinateT tile_coordinate; + CountingInputIterator nonzero_indices(0); + + // Search the merge path + MergePathSearch( + diagonal, + RowOffsetsSearchIteratorT(spmv_params.d_row_end_offsets), + nonzero_indices, + spmv_params.num_rows, + spmv_params.num_nonzeros, + tile_coordinate); + + // Output starting offset + d_tile_coordinates[tile_idx] = tile_coordinate; + } +} + + +/** + * Spmv agent entry point + */ +template < + typename SpmvPolicyT, ///< Parameterized SpmvPolicy tuning policy type + typename ScanTileStateT, ///< Tile status interface type + typename ValueT, ///< Matrix and vector value type + typename OffsetT, ///< Signed integer type for sequence offsets + typename CoordinateT, ///< Merge path coordinate type + bool HAS_ALPHA, ///< Whether the input parameter Alpha is 1 + bool HAS_BETA> ///< Whether the input parameter Beta is 0 +__launch_bounds__ (int(SpmvPolicyT::BLOCK_THREADS)) +__global__ void DeviceSpmvKernel( + SpmvParams spmv_params, ///< [in] SpMV input parameter bundle + CoordinateT* d_tile_coordinates, ///< [in] Pointer to the temporary array of tile starting coordinates + KeyValuePair* d_tile_carry_pairs, ///< [out] Pointer to the temporary array carry-out dot product row-ids, one per block + int num_tiles, ///< [in] Number of merge tiles + ScanTileStateT tile_state, ///< [in] Tile status interface for fixup reduce-by-key kernel + int num_segment_fixup_tiles) ///< [in] Number of reduce-by-key tiles (fixup grid size) +{ + // Spmv agent type specialization + typedef AgentSpmv< + SpmvPolicyT, + ValueT, + OffsetT, + HAS_ALPHA, + HAS_BETA> + AgentSpmvT; + + // Shared memory for AgentSpmv + __shared__ typename AgentSpmvT::TempStorage temp_storage; + + AgentSpmvT(temp_storage, spmv_params).ConsumeTile( + d_tile_coordinates, + d_tile_carry_pairs, + num_tiles); + + // Initialize fixup tile status + tile_state.InitializeStatus(num_segment_fixup_tiles); + +} + + +/** + * Multi-block reduce-by-key sweep kernel entry point + */ +template < + typename AgentSegmentFixupPolicyT, ///< Parameterized AgentSegmentFixupPolicy tuning policy type + typename PairsInputIteratorT, ///< Random-access input iterator type for keys + typename AggregatesOutputIteratorT, ///< Random-access output iterator type for values + typename OffsetT, ///< Signed integer type for global offsets + typename ScanTileStateT> ///< Tile status interface type 
+__launch_bounds__ (int(AgentSegmentFixupPolicyT::BLOCK_THREADS)) +__global__ void DeviceSegmentFixupKernel( + PairsInputIteratorT d_pairs_in, ///< [in] Pointer to the array carry-out dot product row-ids, one per spmv block + AggregatesOutputIteratorT d_aggregates_out, ///< [in,out] Output value aggregates + OffsetT num_items, ///< [in] Total number of items to select from + int num_tiles, ///< [in] Total number of tiles for the entire problem + ScanTileStateT tile_state) ///< [in] Tile status interface +{ + // Thread block type for reducing tiles of value segments + typedef AgentSegmentFixup< + AgentSegmentFixupPolicyT, + PairsInputIteratorT, + AggregatesOutputIteratorT, + cub::Equality, + cub::Sum, + OffsetT> + AgentSegmentFixupT; + + // Shared memory for AgentSegmentFixup + __shared__ typename AgentSegmentFixupT::TempStorage temp_storage; + + // Process tiles + AgentSegmentFixupT(temp_storage, d_pairs_in, d_aggregates_out, cub::Equality(), cub::Sum()).ConsumeRange( + num_items, + num_tiles, + tile_state); +} + + +/****************************************************************************** + * Dispatch + ******************************************************************************/ + +/** + * Utility class for dispatching the appropriately-tuned kernels for DeviceSpmv + */ +template < + typename ValueT, ///< Matrix and vector value type + typename OffsetT> ///< Signed integer type for global offsets +struct DispatchSpmv +{ + //--------------------------------------------------------------------- + // Constants and Types + //--------------------------------------------------------------------- + + enum + { + INIT_KERNEL_THREADS = 128 + }; + + // SpmvParams bundle type + typedef SpmvParams SpmvParamsT; + + // 2D merge path coordinate type + typedef typename CubVector::Type CoordinateT; + + // Tile status descriptor interface type + typedef ReduceByKeyScanTileState ScanTileStateT; + + // Tuple type for scanning (pairs accumulated segment-value with segment-index) + typedef KeyValuePair KeyValuePairT; + + + //--------------------------------------------------------------------- + // Tuning policies + //--------------------------------------------------------------------- + + /// SM11 + struct Policy110 + { + typedef AgentSpmvPolicy< + 128, + 1, + LOAD_DEFAULT, + LOAD_DEFAULT, + LOAD_DEFAULT, + LOAD_DEFAULT, + LOAD_DEFAULT, + false, + BLOCK_SCAN_WARP_SCANS> + SpmvPolicyT; + + typedef AgentSegmentFixupPolicy< + 128, + 4, + BLOCK_LOAD_VECTORIZE, + LOAD_DEFAULT, + BLOCK_SCAN_WARP_SCANS> + SegmentFixupPolicyT; + }; + + /// SM20 + struct Policy200 + { + typedef AgentSpmvPolicy< + 96, + 18, + LOAD_DEFAULT, + LOAD_DEFAULT, + LOAD_DEFAULT, + LOAD_DEFAULT, + LOAD_DEFAULT, + false, + BLOCK_SCAN_RAKING> + SpmvPolicyT; + + typedef AgentSegmentFixupPolicy< + 128, + 4, + BLOCK_LOAD_VECTORIZE, + LOAD_DEFAULT, + BLOCK_SCAN_WARP_SCANS> + SegmentFixupPolicyT; + + }; + + + + /// SM30 + struct Policy300 + { + typedef AgentSpmvPolicy< + 96, + 6, + LOAD_DEFAULT, + LOAD_DEFAULT, + LOAD_DEFAULT, + LOAD_DEFAULT, + LOAD_DEFAULT, + false, + BLOCK_SCAN_WARP_SCANS> + SpmvPolicyT; + + typedef AgentSegmentFixupPolicy< + 128, + 4, + BLOCK_LOAD_VECTORIZE, + LOAD_DEFAULT, + BLOCK_SCAN_WARP_SCANS> + SegmentFixupPolicyT; + + }; + + + /// SM35 + struct Policy350 + { + typedef AgentSpmvPolicy< + (sizeof(ValueT) > 4) ? 96 : 128, + (sizeof(ValueT) > 4) ? 4 : 7, + LOAD_LDG, + LOAD_CA, + LOAD_LDG, + LOAD_LDG, + LOAD_LDG, + (sizeof(ValueT) > 4) ? 
true : false, + BLOCK_SCAN_WARP_SCANS> + SpmvPolicyT; + + typedef AgentSegmentFixupPolicy< + 128, + 3, + BLOCK_LOAD_VECTORIZE, + LOAD_LDG, + BLOCK_SCAN_WARP_SCANS> + SegmentFixupPolicyT; + }; + + + /// SM37 + struct Policy370 + { + + typedef AgentSpmvPolicy< + (sizeof(ValueT) > 4) ? 128 : 128, + (sizeof(ValueT) > 4) ? 9 : 14, + LOAD_LDG, + LOAD_CA, + LOAD_LDG, + LOAD_LDG, + LOAD_LDG, + false, + BLOCK_SCAN_WARP_SCANS> + SpmvPolicyT; + + typedef AgentSegmentFixupPolicy< + 128, + 3, + BLOCK_LOAD_VECTORIZE, + LOAD_LDG, + BLOCK_SCAN_WARP_SCANS> + SegmentFixupPolicyT; + }; + + /// SM50 + struct Policy500 + { + typedef AgentSpmvPolicy< + (sizeof(ValueT) > 4) ? 64 : 128, + (sizeof(ValueT) > 4) ? 6 : 7, + LOAD_LDG, + LOAD_DEFAULT, + (sizeof(ValueT) > 4) ? LOAD_LDG : LOAD_DEFAULT, + (sizeof(ValueT) > 4) ? LOAD_LDG : LOAD_DEFAULT, + LOAD_LDG, + (sizeof(ValueT) > 4) ? true : false, + (sizeof(ValueT) > 4) ? BLOCK_SCAN_WARP_SCANS : BLOCK_SCAN_RAKING_MEMOIZE> + SpmvPolicyT; + + + typedef AgentSegmentFixupPolicy< + 128, + 3, + BLOCK_LOAD_VECTORIZE, + LOAD_LDG, + BLOCK_SCAN_RAKING_MEMOIZE> + SegmentFixupPolicyT; + }; + + + /// SM60 + struct Policy600 + { + typedef AgentSpmvPolicy< + (sizeof(ValueT) > 4) ? 64 : 128, + (sizeof(ValueT) > 4) ? 5 : 7, + LOAD_DEFAULT, + LOAD_DEFAULT, + LOAD_DEFAULT, + LOAD_DEFAULT, + LOAD_DEFAULT, + false, + BLOCK_SCAN_WARP_SCANS> + SpmvPolicyT; + + + typedef AgentSegmentFixupPolicy< + 128, + 3, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + BLOCK_SCAN_WARP_SCANS> + SegmentFixupPolicyT; + }; + + + + //--------------------------------------------------------------------- + // Tuning policies of current PTX compiler pass + //--------------------------------------------------------------------- + +#if (CUB_PTX_ARCH >= 600) + typedef Policy600 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 500) + typedef Policy500 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 370) + typedef Policy370 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 350) + typedef Policy350 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 300) + typedef Policy300 PtxPolicy; + +#elif (CUB_PTX_ARCH >= 200) + typedef Policy200 PtxPolicy; + +#else + typedef Policy110 PtxPolicy; + +#endif + + // "Opaque" policies (whose parameterizations aren't reflected in the type signature) + struct PtxSpmvPolicyT : PtxPolicy::SpmvPolicyT {}; + struct PtxSegmentFixupPolicy : PtxPolicy::SegmentFixupPolicyT {}; + + + //--------------------------------------------------------------------- + // Utilities + //--------------------------------------------------------------------- + + /** + * Initialize kernel dispatch configurations with the policies corresponding to the PTX assembly we will use + */ + template + CUB_RUNTIME_FUNCTION __forceinline__ + static void InitConfigs( + int ptx_version, + KernelConfig &spmv_config, + KernelConfig &segment_fixup_config) + { + #if (CUB_PTX_ARCH > 0) + + // We're on the device, so initialize the kernel dispatch configurations with the current PTX policy + spmv_config.template Init(); + segment_fixup_config.template Init(); + + #else + + // We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version + if (ptx_version >= 600) + { + spmv_config.template Init(); + segment_fixup_config.template Init(); + } + else if (ptx_version >= 500) + { + spmv_config.template Init(); + segment_fixup_config.template Init(); + } + else if (ptx_version >= 370) + { + spmv_config.template Init(); + segment_fixup_config.template Init(); + } + else if (ptx_version >= 350) + { + spmv_config.template Init(); + 
segment_fixup_config.template Init(); + } + else if (ptx_version >= 300) + { + spmv_config.template Init(); + segment_fixup_config.template Init(); + + } + else if (ptx_version >= 200) + { + spmv_config.template Init(); + segment_fixup_config.template Init(); + } + else + { + spmv_config.template Init(); + segment_fixup_config.template Init(); + } + + #endif + } + + + /** + * Kernel kernel dispatch configuration. + */ + struct KernelConfig + { + int block_threads; + int items_per_thread; + int tile_items; + + template + CUB_RUNTIME_FUNCTION __forceinline__ + void Init() + { + block_threads = PolicyT::BLOCK_THREADS; + items_per_thread = PolicyT::ITEMS_PER_THREAD; + tile_items = block_threads * items_per_thread; + } + }; + + + //--------------------------------------------------------------------- + // Dispatch entrypoints + //--------------------------------------------------------------------- + + /** + * Internal dispatch routine for computing a device-wide reduction using the + * specified kernel functions. + * + * If the input is larger than a single tile, this method uses two-passes of + * kernel invocations. + */ + template < + typename Spmv1ColKernelT, ///< Function type of cub::DeviceSpmv1ColKernel + typename SpmvSearchKernelT, ///< Function type of cub::AgentSpmvSearchKernel + typename SpmvKernelT, ///< Function type of cub::AgentSpmvKernel + typename SegmentFixupKernelT> ///< Function type of cub::DeviceSegmentFixupKernelT + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SpmvParamsT& spmv_params, ///< SpMV input parameter bundle + cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous, ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. 
+ Spmv1ColKernelT spmv_1col_kernel, ///< [in] Kernel function pointer to parameterization of DeviceSpmv1ColKernel + SpmvSearchKernelT spmv_search_kernel, ///< [in] Kernel function pointer to parameterization of AgentSpmvSearchKernel + SpmvKernelT spmv_kernel, ///< [in] Kernel function pointer to parameterization of AgentSpmvKernel + SegmentFixupKernelT segment_fixup_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceSegmentFixupKernel + KernelConfig spmv_config, ///< [in] Dispatch parameters that match the policy that \p spmv_kernel was compiled for + KernelConfig segment_fixup_config) ///< [in] Dispatch parameters that match the policy that \p segment_fixup_kernel was compiled for + { +#ifndef CUB_RUNTIME_ENABLED + + // Kernel launch not supported from this device + return CubDebug(cudaErrorNotSupported ); + +#else + cudaError error = cudaSuccess; + do + { + if (spmv_params.num_cols == 1) + { + if (d_temp_storage == NULL) + { + // Return if the caller is simply requesting the size of the storage allocation + temp_storage_bytes = 1; + break; + } + + // Get search/init grid dims + int degen_col_kernel_block_size = INIT_KERNEL_THREADS; + int degen_col_kernel_grid_size = (spmv_params.num_rows + degen_col_kernel_block_size - 1) / degen_col_kernel_block_size; + + if (debug_synchronous) _CubLog("Invoking spmv_1col_kernel<<<%d, %d, 0, %lld>>>()\n", + degen_col_kernel_grid_size, degen_col_kernel_block_size, (long long) stream); + + // Invoke spmv_search_kernel + spmv_1col_kernel<<>>( + spmv_params); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + break; + } + + // Get device ordinal + int device_ordinal; + if (CubDebug(error = cudaGetDevice(&device_ordinal))) break; + + // Get SM count + int sm_count; + if (CubDebug(error = cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal))) break; + + // Get max x-dimension of grid + int max_dim_x; + if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, cudaDevAttrMaxGridDimX, device_ordinal))) break;; + + // Total number of spmv work items + int num_merge_items = spmv_params.num_rows + spmv_params.num_nonzeros; + + // Tile sizes of kernels + int merge_tile_size = spmv_config.block_threads * spmv_config.items_per_thread; + int segment_fixup_tile_size = segment_fixup_config.block_threads * segment_fixup_config.items_per_thread; + + // Number of tiles for kernels + unsigned int num_merge_tiles = (num_merge_items + merge_tile_size - 1) / merge_tile_size; + unsigned int num_segment_fixup_tiles = (num_merge_tiles + segment_fixup_tile_size - 1) / segment_fixup_tile_size; + + // Get SM occupancy for kernels + int spmv_sm_occupancy; + if (CubDebug(error = MaxSmOccupancy( + spmv_sm_occupancy, + spmv_kernel, + spmv_config.block_threads))) break; + + int segment_fixup_sm_occupancy; + if (CubDebug(error = MaxSmOccupancy( + segment_fixup_sm_occupancy, + segment_fixup_kernel, + segment_fixup_config.block_threads))) break; + + // Get grid dimensions + dim3 spmv_grid_size( + CUB_MIN(num_merge_tiles, max_dim_x), + (num_merge_tiles + max_dim_x - 1) / max_dim_x, + 1); + + dim3 segment_fixup_grid_size( + CUB_MIN(num_segment_fixup_tiles, max_dim_x), + (num_segment_fixup_tiles + max_dim_x - 1) / max_dim_x, + 1); + + // Get the temporary storage allocation requirements + size_t allocation_sizes[3]; + if (CubDebug(error = 
ScanTileStateT::AllocationSize(num_segment_fixup_tiles, allocation_sizes[0]))) break; // bytes needed for reduce-by-key tile status descriptors + allocation_sizes[1] = num_merge_tiles * sizeof(KeyValuePairT); // bytes needed for block carry-out pairs + allocation_sizes[2] = (num_merge_tiles + 1) * sizeof(CoordinateT); // bytes needed for tile starting coordinates + + // Alias the temporary allocations from the single storage blob (or compute the necessary size of the blob) + void* allocations[3]; + if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; + if (d_temp_storage == NULL) + { + // Return if the caller is simply requesting the size of the storage allocation + break; + } + + // Construct the tile status interface + ScanTileStateT tile_state; + if (CubDebug(error = tile_state.Init(num_segment_fixup_tiles, allocations[0], allocation_sizes[0]))) break; + + // Alias the other allocations + KeyValuePairT* d_tile_carry_pairs = (KeyValuePairT*) allocations[1]; // Agent carry-out pairs + CoordinateT* d_tile_coordinates = (CoordinateT*) allocations[2]; // Agent starting coordinates + + // Get search/init grid dims + int search_block_size = INIT_KERNEL_THREADS; + int search_grid_size = (num_merge_tiles + 1 + search_block_size - 1) / search_block_size; + +#if (CUB_PTX_ARCH == 0) + // Init textures + if (CubDebug(error = spmv_params.t_vector_x.BindTexture(spmv_params.d_vector_x))) break; +#endif + + if (search_grid_size < sm_count) +// if (num_merge_tiles < spmv_sm_occupancy * sm_count) + { + // Not enough spmv tiles to saturate the device: have spmv blocks search their own staring coords + d_tile_coordinates = NULL; + } + else + { + // Use separate search kernel if we have enough spmv tiles to saturate the device + + // Log spmv_search_kernel configuration + if (debug_synchronous) _CubLog("Invoking spmv_search_kernel<<<%d, %d, 0, %lld>>>()\n", + search_grid_size, search_block_size, (long long) stream); + + // Invoke spmv_search_kernel + spmv_search_kernel<<>>( + num_merge_tiles, + d_tile_coordinates, + spmv_params); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + } + + // Log spmv_kernel configuration + if (debug_synchronous) _CubLog("Invoking spmv_kernel<<<{%d,%d,%d}, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy\n", + spmv_grid_size.x, spmv_grid_size.y, spmv_grid_size.z, spmv_config.block_threads, (long long) stream, spmv_config.items_per_thread, spmv_sm_occupancy); + + // Invoke spmv_kernel + spmv_kernel<<>>( + spmv_params, + d_tile_coordinates, + d_tile_carry_pairs, + num_merge_tiles, + tile_state, + num_segment_fixup_tiles); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + + // Run reduce-by-key fixup if necessary + if (num_merge_tiles > 1) + { + // Log segment_fixup_kernel configuration + if (debug_synchronous) _CubLog("Invoking segment_fixup_kernel<<<{%d,%d,%d}, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy\n", + segment_fixup_grid_size.x, segment_fixup_grid_size.y, segment_fixup_grid_size.z, segment_fixup_config.block_threads, (long long) stream, segment_fixup_config.items_per_thread, segment_fixup_sm_occupancy); + + // Invoke segment_fixup_kernel + 
segment_fixup_kernel<<>>( + d_tile_carry_pairs, + spmv_params.d_vector_y, + num_merge_tiles, + num_segment_fixup_tiles, + tile_state); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + // Sync the stream if specified to flush runtime errors + if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; + } + +#if (CUB_PTX_ARCH == 0) + // Free textures + if (CubDebug(error = spmv_params.t_vector_x.UnbindTexture())) break; +#endif + } + while (0); + + return error; + +#endif // CUB_RUNTIME_ENABLED + } + + + /** + * Internal dispatch routine for computing a device-wide reduction + */ + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Dispatch( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + SpmvParamsT& spmv_params, ///< SpMV input parameter bundle + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version; + #if (CUB_PTX_ARCH == 0) + if (CubDebug(error = PtxVersion(ptx_version))) break; + #else + ptx_version = CUB_PTX_ARCH; + #endif + + // Get kernel kernel dispatch configurations + KernelConfig spmv_config, segment_fixup_config; + InitConfigs(ptx_version, spmv_config, segment_fixup_config); + + if (CubDebug(error = Dispatch( + d_temp_storage, temp_storage_bytes, spmv_params, stream, debug_synchronous, + DeviceSpmv1ColKernel, + DeviceSpmvSearchKernel, + DeviceSpmvKernel, + DeviceSegmentFixupKernel, + spmv_config, segment_fixup_config))) break; + + } + while (0); + + return error; + } +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/grid/grid_barrier.cuh b/fastertransformer/cuda/cub/grid/grid_barrier.cuh new file mode 100644 index 000000000..461fb4421 --- /dev/null +++ b/fastertransformer/cuda/cub/grid/grid_barrier.cuh @@ -0,0 +1,211 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
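The DispatchSpmv utility that ends above is normally reached through cub::DeviceSpmv::CsrMV, which in upstream CUB packs the CSR arrays into the SpmvParams bundle and relies on the same NULL-query/allocate/run convention. A host-side sketch; buffer names are illustrative and assumed to be device pointers that are already populated:

    #include <cub/cub.cuh>   // cub::DeviceSpmv

    // y = A * x for a CSR matrix with num_rows rows, num_cols columns and num_nonzeros entries
    void csr_spmv(float *d_values, int *d_row_offsets, int *d_column_indices,
                  float *d_x, float *d_y,
                  int num_rows, int num_cols, int num_nonzeros, cudaStream_t stream)
    {
        // Phase 1: size query only (writes temp_storage_bytes, does no work)
        void   *d_temp_storage     = NULL;
        size_t  temp_storage_bytes = 0;
        cub::DeviceSpmv::CsrMV(d_temp_storage, temp_storage_bytes,
                               d_values, d_row_offsets, d_column_indices,
                               d_x, d_y, num_rows, num_cols, num_nonzeros, stream);

        // Phase 2: allocate the blob (tile status + carry-out pairs + tile coordinates) and run
        cudaMalloc(&d_temp_storage, temp_storage_bytes);
        cub::DeviceSpmv::CsrMV(d_temp_storage, temp_storage_bytes,
                               d_values, d_row_offsets, d_column_indices,
                               d_x, d_y, num_rows, num_cols, num_nonzeros, stream);
        cudaFree(d_temp_storage);
    }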
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::GridBarrier implements a software global barrier among thread blocks within a CUDA grid + */ + +#pragma once + +#include "../util_debug.cuh" +#include "../util_namespace.cuh" +#include "../thread/thread_load.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup GridModule + * @{ + */ + + +/** + * \brief GridBarrier implements a software global barrier among thread blocks within a CUDA grid + */ +class GridBarrier +{ +protected : + + typedef unsigned int SyncFlag; + + // Counters in global device memory + SyncFlag* d_sync; + +public: + + /** + * Constructor + */ + GridBarrier() : d_sync(NULL) {} + + + /** + * Synchronize + */ + __device__ __forceinline__ void Sync() const + { + volatile SyncFlag *d_vol_sync = d_sync; + + // Threadfence and syncthreads to make sure global writes are visible before + // thread-0 reports in with its sync counter + __threadfence(); + CTA_SYNC(); + + if (blockIdx.x == 0) + { + // Report in ourselves + if (threadIdx.x == 0) + { + d_vol_sync[blockIdx.x] = 1; + } + + CTA_SYNC(); + + // Wait for everyone else to report in + for (int peer_block = threadIdx.x; peer_block < gridDim.x; peer_block += blockDim.x) + { + while (ThreadLoad(d_sync + peer_block) == 0) + { + __threadfence_block(); + } + } + + CTA_SYNC(); + + // Let everyone know it's safe to proceed + for (int peer_block = threadIdx.x; peer_block < gridDim.x; peer_block += blockDim.x) + { + d_vol_sync[peer_block] = 0; + } + } + else + { + if (threadIdx.x == 0) + { + // Report in + d_vol_sync[blockIdx.x] = 1; + + // Wait for acknowledgment + while (ThreadLoad(d_sync + blockIdx.x) == 1) + { + __threadfence_block(); + } + } + + CTA_SYNC(); + } + } +}; + + +/** + * \brief GridBarrierLifetime extends GridBarrier to provide lifetime management of the temporary device storage needed for cooperation. + * + * Uses RAII for lifetime, i.e., device resources are reclaimed when + * the destructor is called. 
+ */ +class GridBarrierLifetime : public GridBarrier +{ +protected: + + // Number of bytes backed by d_sync + size_t sync_bytes; + +public: + + /** + * Constructor + */ + GridBarrierLifetime() : GridBarrier(), sync_bytes(0) {} + + + /** + * DeviceFrees and resets the progress counters + */ + cudaError_t HostReset() + { + cudaError_t retval = cudaSuccess; + if (d_sync) + { + CubDebug(retval = cudaFree(d_sync)); + d_sync = NULL; + } + sync_bytes = 0; + return retval; + } + + + /** + * Destructor + */ + virtual ~GridBarrierLifetime() + { + HostReset(); + } + + + /** + * Sets up the progress counters for the next kernel launch (lazily + * allocating and initializing them if necessary) + */ + cudaError_t Setup(int sweep_grid_size) + { + cudaError_t retval = cudaSuccess; + do { + size_t new_sync_bytes = sweep_grid_size * sizeof(SyncFlag); + if (new_sync_bytes > sync_bytes) + { + if (d_sync) + { + if (CubDebug(retval = cudaFree(d_sync))) break; + } + + sync_bytes = new_sync_bytes; + + // Allocate and initialize to zero + if (CubDebug(retval = cudaMalloc((void**) &d_sync, sync_bytes))) break; + if (CubDebug(retval = cudaMemset(d_sync, 0, new_sync_bytes))) break; + } + } while (0); + + return retval; + } +}; + + +/** @} */ // end group GridModule + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/grid/grid_even_share.cuh b/fastertransformer/cuda/cub/grid/grid_even_share.cuh new file mode 100644 index 000000000..f0b3a69ae --- /dev/null +++ b/fastertransformer/cuda/cub/grid/grid_even_share.cuh @@ -0,0 +1,222 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::GridEvenShare is a descriptor utility for distributing input among CUDA thread blocks in an "even-share" fashion. 
Each thread block gets roughly the same number of fixed-size work units (grains). + */ + + +#pragma once + +#include "../util_namespace.cuh" +#include "../util_macro.cuh" +#include "grid_mapping.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup GridModule + * @{ + */ + + +/** + * \brief GridEvenShare is a descriptor utility for distributing input among + * CUDA thread blocks in an "even-share" fashion. Each thread block gets roughly + * the same number of input tiles. + * + * \par Overview + * Each thread block is assigned a consecutive sequence of input tiles. To help + * preserve alignment and eliminate the overhead of guarded loads for all but the + * last thread block, to GridEvenShare assigns one of three different amounts of + * work to a given thread block: "big", "normal", or "last". The "big" workloads + * are one scheduling grain larger than "normal". The "last" work unit for the + * last thread block may be partially-full if the input is not an even multiple of + * the scheduling grain size. + * + * \par + * Before invoking a child grid, a parent thread will typically construct an + * instance of GridEvenShare. The instance can be passed to child thread blocks + * which can initialize their per-thread block offsets using \p BlockInit(). + */ +template +struct GridEvenShare +{ +private: + + OffsetT total_tiles; + int big_shares; + OffsetT big_share_items; + OffsetT normal_share_items; + OffsetT normal_base_offset; + +public: + + /// Total number of input items + OffsetT num_items; + + /// Grid size in thread blocks + int grid_size; + + /// OffsetT into input marking the beginning of the owning thread block's segment of input tiles + OffsetT block_offset; + + /// OffsetT into input of marking the end (one-past) of the owning thread block's segment of input tiles + OffsetT block_end; + + /// Stride between input tiles + OffsetT block_stride; + + + /** + * \brief Constructor. + */ + __host__ __device__ __forceinline__ GridEvenShare() : + total_tiles(0), + big_shares(0), + big_share_items(0), + normal_share_items(0), + normal_base_offset(0), + num_items(0), + grid_size(0), + block_offset(0), + block_end(0), + block_stride(0) + {} + + + /** + * \brief Dispatch initializer. To be called prior prior to kernel launch. + */ + __host__ __device__ __forceinline__ void DispatchInit( + OffsetT num_items, ///< Total number of input items + int max_grid_size, ///< Maximum grid size allowable (actual grid size may be less if not warranted by the the number of input items) + int tile_items) ///< Number of data items per input tile + { + this->block_offset = num_items; // Initialize past-the-end + this->block_end = num_items; // Initialize past-the-end + this->num_items = num_items; + this->total_tiles = (num_items + tile_items - 1) / tile_items; + this->grid_size = CUB_MIN(total_tiles, max_grid_size); + OffsetT avg_tiles_per_block = total_tiles / grid_size; + this->big_shares = total_tiles - (avg_tiles_per_block * grid_size); // leftover grains go to big blocks + this->normal_share_items = avg_tiles_per_block * tile_items; + this->normal_base_offset = big_shares * tile_items; + this->big_share_items = normal_share_items + tile_items; + } + + + /** + * \brief Initializes ranges for the specified thread block index. Specialized + * for a "raking" access pattern in which each thread block is assigned a + * consecutive sequence of input tiles. 
+ */ + template + __device__ __forceinline__ void BlockInit( + int block_id, + Int2Type /*strategy_tag*/) + { + block_stride = TILE_ITEMS; + if (block_id < big_shares) + { + // This thread block gets a big share of grains (avg_tiles_per_block + 1) + block_offset = (block_id * big_share_items); + block_end = block_offset + big_share_items; + } + else if (block_id < total_tiles) + { + // This thread block gets a normal share of grains (avg_tiles_per_block) + block_offset = normal_base_offset + (block_id * normal_share_items); + block_end = CUB_MIN(num_items, block_offset + normal_share_items); + } + // Else default past-the-end + } + + + /** + * \brief Block-initialization, specialized for a "raking" access + * pattern in which each thread block is assigned a consecutive sequence + * of input tiles. + */ + template + __device__ __forceinline__ void BlockInit( + int block_id, + Int2Type /*strategy_tag*/) + { + block_stride = grid_size * TILE_ITEMS; + block_offset = (block_id * TILE_ITEMS); + block_end = num_items; + } + + + /** + * \brief Block-initialization, specialized for "strip mining" access + * pattern in which the input tiles assigned to each thread block are + * separated by a stride equal to the the extent of the grid. + */ + template < + int TILE_ITEMS, + GridMappingStrategy STRATEGY> + __device__ __forceinline__ void BlockInit() + { + BlockInit(blockIdx.x, Int2Type()); + } + + + /** + * \brief Block-initialization, specialized for a "raking" access + * pattern in which each thread block is assigned a consecutive sequence + * of input tiles. + */ + template + __device__ __forceinline__ void BlockInit( + OffsetT block_offset, ///< [in] Threadblock begin offset (inclusive) + OffsetT block_end) ///< [in] Threadblock end offset (exclusive) + { + this->block_offset = block_offset; + this->block_end = block_end; + this->block_stride = TILE_ITEMS; + } + + +}; + + + + + +/** @} */ // end group GridModule + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/grid/grid_mapping.cuh b/fastertransformer/cuda/cub/grid/grid_mapping.cuh new file mode 100644 index 000000000..f0e9fded2 --- /dev/null +++ b/fastertransformer/cuda/cub/grid/grid_mapping.cuh @@ -0,0 +1,113 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
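To make the GridEvenShare contract above concrete: the host calls DispatchInit() once before launch, and every thread block calls BlockInit() and then walks its [block_offset, block_end) range in block_stride steps. A sketch under the assumption that num_items > 0; the kernel, tile size, and reduction are illustrative and not part of this patch:

    #include <cub/cub.cuh>   // cub::GridEvenShare, cub::GRID_MAPPING_RAKE

    const int BLOCK_THREADS    = 128;
    const int ITEMS_PER_THREAD = 4;
    const int TILE_ITEMS       = BLOCK_THREADS * ITEMS_PER_THREAD;

    __global__ void EvenShareSumKernel(const int *d_in, unsigned long long *d_sum,
                                       cub::GridEvenShare<int> even_share)
    {
        // Each block resolves its own consecutive tile range from the shared descriptor
        even_share.BlockInit<TILE_ITEMS, cub::GRID_MAPPING_RAKE>();

        unsigned long long thread_sum = 0;
        for (int tile = even_share.block_offset; tile < even_share.block_end; tile += even_share.block_stride)
        {
            int idx = tile + threadIdx.x;
            for (int i = 0; i < ITEMS_PER_THREAD; ++i, idx += BLOCK_THREADS)
                if (idx < even_share.num_items)    // guard the final, possibly partial tile
                    thread_sum += d_in[idx];
        }
        atomicAdd(d_sum, thread_sum);
    }

    void launch_even_share_sum(const int *d_in, unsigned long long *d_sum, int num_items, int max_grid_size)
    {
        cub::GridEvenShare<int> even_share;
        even_share.DispatchInit(num_items, max_grid_size, TILE_ITEMS);   // host-side setup
        EvenShareSumKernel<<<even_share.grid_size, BLOCK_THREADS>>>(d_in, d_sum, even_share);
    }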
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::GridMappingStrategy enumerates alternative strategies for mapping constant-sized tiles of device-wide data onto a grid of CUDA thread blocks. + */ + +#pragma once + +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup GridModule + * @{ + */ + + +/****************************************************************************** + * Mapping policies + *****************************************************************************/ + + +/** + * \brief cub::GridMappingStrategy enumerates alternative strategies for mapping constant-sized tiles of device-wide data onto a grid of CUDA thread blocks. + */ +enum GridMappingStrategy +{ + /** + * \brief An a "raking" access pattern in which each thread block is + * assigned a consecutive sequence of input tiles + * + * \par Overview + * The input is evenly partitioned into \p p segments, where \p p is + * constant and corresponds loosely to the number of thread blocks that may + * actively reside on the target device. Each segment is comprised of + * consecutive tiles, where a tile is a small, constant-sized unit of input + * to be processed to completion before the thread block terminates or + * obtains more work. The kernel invokes \p p thread blocks, each + * of which iteratively consumes a segment of n/p elements + * in tile-size increments. + */ + GRID_MAPPING_RAKE, + + /** + * \brief An a "strip mining" access pattern in which the input tiles assigned + * to each thread block are separated by a stride equal to the the extent of + * the grid. + * + * \par Overview + * The input is evenly partitioned into \p p sets, where \p p is + * constant and corresponds loosely to the number of thread blocks that may + * actively reside on the target device. Each set is comprised of + * data tiles separated by stride \p tiles, where a tile is a small, + * constant-sized unit of input to be processed to completion before the + * thread block terminates or obtains more work. The kernel invokes \p p + * thread blocks, each of which iteratively consumes a segment of + * n/p elements in tile-size increments. + */ + GRID_MAPPING_STRIP_MINE, + + /** + * \brief A dynamic "queue-based" strategy for assigning input tiles to thread blocks. + * + * \par Overview + * The input is treated as a queue to be dynamically consumed by a grid of + * thread blocks. Work is atomically dequeued in tiles, where a tile is a + * unit of input to be processed to completion before the thread block + * terminates or obtains more work. The grid size \p p is constant, + * loosely corresponding to the number of thread blocks that may actively + * reside on the target device. 
+ */ + GRID_MAPPING_DYNAMIC, +}; + + +/** @} */ // end group GridModule + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/grid/grid_queue.cuh b/fastertransformer/cuda/cub/grid/grid_queue.cuh new file mode 100644 index 000000000..9615b14db --- /dev/null +++ b/fastertransformer/cuda/cub/grid/grid_queue.cuh @@ -0,0 +1,220 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::GridQueue is a descriptor utility for dynamic queue management. + */ + +#pragma once + +#include "../util_namespace.cuh" +#include "../util_debug.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup GridModule + * @{ + */ + + +/** + * \brief GridQueue is a descriptor utility for dynamic queue management. + * + * \par Overview + * GridQueue descriptors provides abstractions for "filling" or + * "draining" globally-shared vectors. + * + * \par + * A "filling" GridQueue works by atomically-adding to a zero-initialized counter, + * returning a unique offset for the calling thread to write its items. + * The GridQueue maintains the total "fill-size". The fill counter must be reset + * using GridQueue::ResetFill by the host or kernel instance prior to the kernel instance that + * will be filling. + * + * \par + * Similarly, a "draining" GridQueue works by works by atomically-incrementing a + * zero-initialized counter, returning a unique offset for the calling thread to + * read its items. Threads can safely drain until the array's logical fill-size is + * exceeded. The drain counter must be reset using GridQueue::ResetDrain or + * GridQueue::FillAndResetDrain by the host or kernel instance prior to the kernel instance that + * will be filling. 
(For dynamic work distribution of existing data, the corresponding fill-size + * is simply the number of elements in the array.) + * + * \par + * Iterative work management can be implemented simply with a pair of flip-flopping + * work buffers, each with an associated set of fill and drain GridQueue descriptors. + * + * \tparam OffsetT Signed integer type for global offsets + */ +template +class GridQueue +{ +private: + + /// Counter indices + enum + { + FILL = 0, + DRAIN = 1, + }; + + /// Pair of counters + OffsetT *d_counters; + +public: + + /// Returns the device allocation size in bytes needed to construct a GridQueue instance + __host__ __device__ __forceinline__ + static size_t AllocationSize() + { + return sizeof(OffsetT) * 2; + } + + + /// Constructs an invalid GridQueue descriptor + __host__ __device__ __forceinline__ GridQueue() + : + d_counters(NULL) + {} + + + /// Constructs a GridQueue descriptor around the device storage allocation + __host__ __device__ __forceinline__ GridQueue( + void *d_storage) ///< Device allocation to back the GridQueue. Must be at least as big as AllocationSize(). + : + d_counters((OffsetT*) d_storage) + {} + + + /// This operation sets the fill-size and resets the drain counter, preparing the GridQueue for draining in the next kernel instance. To be called by the host or by a kernel prior to that which will be draining. + __host__ __device__ __forceinline__ cudaError_t FillAndResetDrain( + OffsetT fill_size, + cudaStream_t stream = 0) + { +#if (CUB_PTX_ARCH > 0) + (void)stream; + d_counters[FILL] = fill_size; + d_counters[DRAIN] = 0; + return cudaSuccess; +#else + OffsetT counters[2]; + counters[FILL] = fill_size; + counters[DRAIN] = 0; + return CubDebug(cudaMemcpyAsync(d_counters, counters, sizeof(OffsetT) * 2, cudaMemcpyHostToDevice, stream)); +#endif + } + + + /// This operation resets the drain so that it may advance to meet the existing fill-size. To be called by the host or by a kernel prior to that which will be draining. + __host__ __device__ __forceinline__ cudaError_t ResetDrain(cudaStream_t stream = 0) + { +#if (CUB_PTX_ARCH > 0) + (void)stream; + d_counters[DRAIN] = 0; + return cudaSuccess; +#else + return CubDebug(cudaMemsetAsync(d_counters + DRAIN, 0, sizeof(OffsetT), stream)); +#endif + } + + + /// This operation resets the fill counter. To be called by the host or by a kernel prior to that which will be filling. + __host__ __device__ __forceinline__ cudaError_t ResetFill(cudaStream_t stream = 0) + { +#if (CUB_PTX_ARCH > 0) + (void)stream; + d_counters[FILL] = 0; + return cudaSuccess; +#else + return CubDebug(cudaMemsetAsync(d_counters + FILL, 0, sizeof(OffsetT), stream)); +#endif + } + + + /// Returns the fill-size established by the parent or by the previous kernel. + __host__ __device__ __forceinline__ cudaError_t FillSize( + OffsetT &fill_size, + cudaStream_t stream = 0) + { +#if (CUB_PTX_ARCH > 0) + (void)stream; + fill_size = d_counters[FILL]; + return cudaSuccess; +#else + return CubDebug(cudaMemcpyAsync(&fill_size, d_counters + FILL, sizeof(OffsetT), cudaMemcpyDeviceToHost, stream)); +#endif + } + + + /// Drain \p num_items from the queue. Returns offset from which to read items. To be called from CUDA kernel. + __device__ __forceinline__ OffsetT Drain(OffsetT num_items) + { + return atomicAdd(d_counters + DRAIN, num_items); + } + + + /// Fill \p num_items into the queue. Returns offset from which to write items. To be called from CUDA kernel. 
+ __device__ __forceinline__ OffsetT Fill(OffsetT num_items) + { + return atomicAdd(d_counters + FILL, num_items); + } +}; + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + +/** + * Reset grid queue (call with 1 block of 1 thread) + */ +template +__global__ void FillAndResetDrainKernel( + GridQueue grid_queue, + OffsetT num_items) +{ + grid_queue.FillAndResetDrain(num_items); +} + + + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + +/** @} */ // end group GridModule + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + + diff --git a/fastertransformer/cuda/cub/host/mutex.cuh b/fastertransformer/cuda/cub/host/mutex.cuh new file mode 100644 index 000000000..ff7ec90dd --- /dev/null +++ b/fastertransformer/cuda/cub/host/mutex.cuh @@ -0,0 +1,171 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
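A small sketch of the drain pattern the GridQueue comments above describe: the host backs the queue with its AllocationSize() bytes and sets the fill-size, and device blocks atomically dequeue fixed-size tiles until the fill-size is exhausted. Kernel name, grid size, and tile size are illustrative:

    #include <cub/cub.cuh>   // cub::GridQueue

    const int TILE = 256;

    __global__ void DrainKernel(cub::GridQueue<int> queue, const float *d_in, float *d_out, int num_items)
    {
        __shared__ int tile_offset;
        while (true)
        {
            // One thread per block grabs the next tile offset; the whole block shares it
            if (threadIdx.x == 0)
                tile_offset = queue.Drain(TILE);
            __syncthreads();

            if (tile_offset >= num_items)
                break;                              // queue exhausted

            int idx = tile_offset + threadIdx.x;
            if (idx < num_items)
                d_out[idx] = d_in[idx] * 2.0f;      // stand-in for real per-tile work
            __syncthreads();                        // protect tile_offset before the next Drain
        }
    }

    void launch_drain(const float *d_in, float *d_out, int num_items, cudaStream_t stream)
    {
        // Back the queue with its required device allocation and set the fill-size
        void *d_queue_storage = NULL;
        cudaMalloc(&d_queue_storage, cub::GridQueue<int>::AllocationSize());
        cub::GridQueue<int> queue(d_queue_storage);
        queue.FillAndResetDrain(num_items, stream);     // host path issues a cudaMemcpyAsync

        DrainKernel<<<64, TILE, 0, stream>>>(queue, d_in, d_out, num_items);

        cudaStreamSynchronize(stream);
        cudaFree(d_queue_storage);
    }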
+ * + ******************************************************************************/ + +/** + * \file + * Simple portable mutex + */ + + +#pragma once + +#if (__cplusplus > 199711L) || (defined(_MSC_VER) && _MSC_VER >= 1800) + #include +#else + #if defined(_WIN32) || defined(_WIN64) + #include + + #define WIN32_LEAN_AND_MEAN + #define NOMINMAX + #include + #undef WIN32_LEAN_AND_MEAN + #undef NOMINMAX + + /** + * Compiler read/write barrier + */ + #pragma intrinsic(_ReadWriteBarrier) + + #endif +#endif + +#include "../util_namespace.cuh" + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * Simple portable mutex + * - Wraps std::mutex when compiled with C++11 or newer (supported on all platforms) + * - Uses GNU/Windows spinlock mechanisms for pre C++11 (supported on x86/x64 when compiled with cl.exe or g++) + */ +struct Mutex +{ +#if (__cplusplus > 199711L) || (defined(_MSC_VER) && _MSC_VER >= 1800) + + std::mutex mtx; + + void Lock() + { + mtx.lock(); + } + + void Unlock() + { + mtx.unlock(); + } + + void TryLock() + { + mtx.try_lock(); + } + +#else //__cplusplus > 199711L + + #if defined(_MSC_VER) + + // Microsoft VC++ + typedef long Spinlock; + + #else + + // GNU g++ + typedef int Spinlock; + + /** + * Compiler read/write barrier + */ + __forceinline__ void _ReadWriteBarrier() + { + __sync_synchronize(); + } + + /** + * Atomic exchange + */ + __forceinline__ long _InterlockedExchange(volatile int * const Target, const int Value) + { + // NOTE: __sync_lock_test_and_set would be an acquire barrier, so we force a full barrier + _ReadWriteBarrier(); + return __sync_lock_test_and_set(Target, Value); + } + + /** + * Pause instruction to prevent excess processor bus usage + */ + __forceinline__ void YieldProcessor() + { + } + + #endif // defined(_MSC_VER) + + /// Lock member + volatile Spinlock lock; + + /** + * Constructor + */ + Mutex() : lock(0) {} + + /** + * Return when the specified spinlock has been acquired + */ + __forceinline__ void Lock() + { + while (1) + { + if (!_InterlockedExchange(&lock, 1)) return; + while (lock) YieldProcessor(); + } + } + + + /** + * Release the specified spinlock + */ + __forceinline__ void Unlock() + { + _ReadWriteBarrier(); + lock = 0; + } + +#endif // __cplusplus > 199711L + +}; + + + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/fastertransformer/cuda/cub/iterator/arg_index_input_iterator.cuh b/fastertransformer/cuda/cub/iterator/arg_index_input_iterator.cuh new file mode 100644 index 000000000..95a84a579 --- /dev/null +++ b/fastertransformer/cuda/cub/iterator/arg_index_input_iterator.cuh @@ -0,0 +1,259 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
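For the host-side Mutex defined in the file above, usage is the usual lock/unlock bracket around a critical section; on C++11 toolchains it simply forwards to std::mutex, otherwise it spins. A trivial sketch (the counter and the cub/ include layout are assumptions):

    #include <cub/host/mutex.cuh>

    static cub::Mutex g_mutex;
    static long long  g_calls = 0;

    void count_call()
    {
        g_mutex.Lock();      // std::mutex or spinlock, depending on the toolchain
        ++g_calls;           // protected update
        g_mutex.Unlock();
    }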
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Random-access iterator types + */ + +#pragma once + +#include +#include + +#include "../thread/thread_load.cuh" +#include "../thread/thread_store.cuh" +#include "../util_device.cuh" +#include "../util_namespace.cuh" + +#include + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \addtogroup UtilIterator + * @{ + */ + + +/** + * \brief A random-access input wrapper for pairing dereferenced values with their corresponding indices (forming \p KeyValuePair tuples). + * + * \par Overview + * - ArgIndexInputIteratorTwraps a random access input iterator \p itr of type \p InputIteratorT. + * Dereferencing an ArgIndexInputIteratorTat offset \p i produces a \p KeyValuePair value whose + * \p key field is \p i and whose \p value field is itr[i]. + * - Can be used with any data type. + * - Can be constructed, manipulated, and exchanged within and between host and device + * functions. Wrapped host memory can only be dereferenced on the host, and wrapped + * device memory can only be dereferenced on the device. + * - Compatible with Thrust API v1.7 or newer. 
+ * + * \par Snippet + * The code snippet below illustrates the use of \p ArgIndexInputIteratorTto + * dereference an array of doubles + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize a device array + * double *d_in; // e.g., [8.0, 6.0, 7.0, 5.0, 3.0, 0.0, 9.0] + * + * // Create an iterator wrapper + * cub::ArgIndexInputIterator itr(d_in); + * + * // Within device code: + * typedef typename cub::ArgIndexInputIterator::value_type Tuple; + * Tuple item_offset_pair.key = *itr; + * printf("%f @ %d\n", + * item_offset_pair.value, + * item_offset_pair.key); // 8.0 @ 0 + * + * itr = itr + 6; + * item_offset_pair.key = *itr; + * printf("%f @ %d\n", + * item_offset_pair.value, + * item_offset_pair.key); // 9.0 @ 6 + * + * \endcode + * + * \tparam InputIteratorT The value type of the wrapped input iterator + * \tparam OffsetT The difference type of this iterator (Default: \p ptrdiff_t) + * \tparam OutputValueT The paired value type of the tuple (Default: value type of input iterator) + */ +template < + typename InputIteratorT, + typename OffsetT = ptrdiff_t, + typename OutputValueT = typename std::iterator_traits::value_type> +class ArgIndexInputIterator +{ +public: + + // Required iterator traits + typedef ArgIndexInputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of subtracting one iterator from another + typedef KeyValuePair value_type; ///< The type of the element the iterator can point to + typedef value_type* pointer; ///< The type of a pointer to an element the iterator can point to + typedef value_type reference; ///< The type of a reference to an element the iterator can point to + +#if (THRUST_VERSION >= 100700) + // Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods + typedef typename thrust::detail::iterator_facade_category< + thrust::any_system_tag, + thrust::random_access_traversal_tag, + value_type, + reference + >::type iterator_category; ///< The iterator category +#else + typedef std::random_access_iterator_tag iterator_category; ///< The iterator category +#endif // THRUST_VERSION + +private: + + InputIteratorT itr; + difference_type offset; + +public: + + /// Constructor + __host__ __device__ __forceinline__ ArgIndexInputIterator( + InputIteratorT itr, ///< Input iterator to wrap + difference_type offset = 0) ///< OffsetT (in items) from \p itr denoting the position of the iterator + : + itr(itr), + offset(offset) + {} + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) + { + self_type retval = *this; + offset++; + return retval; + } + + /// Prefix increment + __host__ __device__ __forceinline__ self_type operator++() + { + offset++; + return *this; + } + + /// Indirection + __host__ __device__ __forceinline__ reference operator*() const + { + value_type retval; + retval.value = itr[offset]; + retval.key = offset; + return retval; + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const + { + self_type retval(itr, offset + n); + return retval; + } + + /// Addition assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) + { + offset += n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const + { + self_type retval(itr, offset - n); + return retval; + } + + /// Subtraction assignment + template + __host__ __device__ __forceinline__ 
self_type& operator-=(Distance n) + { + offset -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type operator-(self_type other) const + { + return offset - other.offset; + } + + /// Array subscript + template + __host__ __device__ __forceinline__ reference operator[](Distance n) const + { + self_type offset = (*this) + n; + return *offset; + } + + /// Structure dereference + __host__ __device__ __forceinline__ pointer operator->() + { + return &(*(*this)); + } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) + { + return ((itr == rhs.itr) && (offset == rhs.offset)); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) + { + return ((itr != rhs.itr) || (offset != rhs.offset)); + } + + /// Normalize + __host__ __device__ __forceinline__ void normalize() + { + itr += offset; + offset = 0; + } + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& /*itr*/) + { + return os; + } +}; + + + +/** @} */ // end group UtilIterator + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/iterator/cache_modified_input_iterator.cuh b/fastertransformer/cuda/cub/iterator/cache_modified_input_iterator.cuh new file mode 100644 index 000000000..b4ad91e2f --- /dev/null +++ b/fastertransformer/cuda/cub/iterator/cache_modified_input_iterator.cuh @@ -0,0 +1,240 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
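// --- Illustrative sketch (not part of the patch): a complete host/device round trip for the
// --- ArgIndexInputIterator defined above. Dereferencing it in the kernel yields KeyValuePair
// --- elements whose .key is the offset and whose .value is d_in[key]. Assumes the repository
// --- root is on the include path for the vendored header.
#include "fastertransformer/cuda/cub/iterator/arg_index_input_iterator.cuh"
#include <cstdio>

__global__ void PrintPairs(cub::ArgIndexInputIterator<const double*> itr, int num_items)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_items)
    {
        // value_type is a KeyValuePair of (offset, element)
        cub::ArgIndexInputIterator<const double*>::value_type pair = itr[i];
        printf("%f @ %d\n", pair.value, (int) pair.key);
    }
}

int main()
{
    const int num_items = 7;
    const double h_in[num_items] = {8.0, 6.0, 7.0, 5.0, 3.0, 0.0, 9.0};

    double *d_in = NULL;
    cudaMalloc(&d_in, num_items * sizeof(double));
    cudaMemcpy(d_in, h_in, num_items * sizeof(double), cudaMemcpyHostToDevice);

    cub::ArgIndexInputIterator<const double*> itr(d_in);   // key = offset, value = d_in[offset]
    PrintPairs<<<1, 32>>>(itr, num_items);
    cudaDeviceSynchronize();

    cudaFree(d_in);
    return 0;
}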
+ * + ******************************************************************************/ + +/** + * \file + * Random-access iterator types + */ + +#pragma once + +#include +#include + +#include "../thread/thread_load.cuh" +#include "../thread/thread_store.cuh" +#include "../util_device.cuh" +#include "../util_namespace.cuh" + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + + +/** + * \addtogroup UtilIterator + * @{ + */ + + +/** + * \brief A random-access input wrapper for dereferencing array values using a PTX cache load modifier. + * + * \par Overview + * - CacheModifiedInputIteratorTis a random-access input iterator that wraps a native + * device pointer of type ValueType*. \p ValueType references are + * made by reading \p ValueType values through loads modified by \p MODIFIER. + * - Can be used to load any data type from memory using PTX cache load modifiers (e.g., "LOAD_LDG", + * "LOAD_CG", "LOAD_CA", "LOAD_CS", "LOAD_CV", etc.). + * - Can be constructed, manipulated, and exchanged within and between host and device + * functions, but can only be dereferenced within device functions. + * - Compatible with Thrust API v1.7 or newer. + * + * \par Snippet + * The code snippet below illustrates the use of \p CacheModifiedInputIteratorTto + * dereference a device array of double using the "ldg" PTX load modifier + * (i.e., load values through texture cache). + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize a device array + * double *d_in; // e.g., [8.0, 6.0, 7.0, 5.0, 3.0, 0.0, 9.0] + * + * // Create an iterator wrapper + * cub::CacheModifiedInputIterator itr(d_in); + * + * // Within device code: + * printf("%f\n", itr[0]); // 8.0 + * printf("%f\n", itr[1]); // 6.0 + * printf("%f\n", itr[6]); // 9.0 + * + * \endcode + * + * \tparam CacheLoadModifier The cub::CacheLoadModifier to use when accessing data + * \tparam ValueType The value type of this iterator + * \tparam OffsetT The difference type of this iterator (Default: \p ptrdiff_t) + */ +template < + CacheLoadModifier MODIFIER, + typename ValueType, + typename OffsetT = ptrdiff_t> +class CacheModifiedInputIterator +{ +public: + + // Required iterator traits + typedef CacheModifiedInputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of subtracting one iterator from another + typedef ValueType value_type; ///< The type of the element the iterator can point to + typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to + typedef ValueType reference; ///< The type of a reference to an element the iterator can point to + +#if (THRUST_VERSION >= 100700) + // Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods + typedef typename thrust::detail::iterator_facade_category< + thrust::device_system_tag, + thrust::random_access_traversal_tag, + value_type, + reference + >::type iterator_category; ///< The iterator category +#else + typedef std::random_access_iterator_tag iterator_category; ///< The iterator category +#endif // THRUST_VERSION + + +public: + + /// Wrapped native pointer + ValueType* ptr; + + /// Constructor + template + __host__ __device__ __forceinline__ CacheModifiedInputIterator( + QualifiedValueType* ptr) ///< Native pointer to wrap + : + ptr(const_cast::Type *>(ptr)) 
+ {} + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) + { + self_type retval = *this; + ptr++; + return retval; + } + + /// Prefix increment + __host__ __device__ __forceinline__ self_type operator++() + { + ptr++; + return *this; + } + + /// Indirection + __device__ __forceinline__ reference operator*() const + { + return ThreadLoad(ptr); + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const + { + self_type retval(ptr + n); + return retval; + } + + /// Addition assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) + { + ptr += n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const + { + self_type retval(ptr - n); + return retval; + } + + /// Subtraction assignment + template + __host__ __device__ __forceinline__ self_type& operator-=(Distance n) + { + ptr -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type operator-(self_type other) const + { + return ptr - other.ptr; + } + + /// Array subscript + template + __device__ __forceinline__ reference operator[](Distance n) const + { + return ThreadLoad(ptr + n); + } + + /// Structure dereference + __device__ __forceinline__ pointer operator->() + { + return &ThreadLoad(ptr); + } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) + { + return (ptr == rhs.ptr); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) + { + return (ptr != rhs.ptr); + } + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& /*itr*/) + { + return os; + } +}; + + + +/** @} */ // end group UtilIterator + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/iterator/cache_modified_output_iterator.cuh b/fastertransformer/cuda/cub/iterator/cache_modified_output_iterator.cuh new file mode 100644 index 000000000..c3e3321d3 --- /dev/null +++ b/fastertransformer/cuda/cub/iterator/cache_modified_output_iterator.cuh @@ -0,0 +1,254 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
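// --- Illustrative sketch (not part of the patch): reading a device array through the
// --- CacheModifiedInputIterator defined above with the LOAD_LDG modifier, so loads go through
// --- the read-only data cache. Dereferencing is only valid in device code. Assumes the
// --- repository root is on the include path for the vendored header.
#include "fastertransformer/cuda/cub/iterator/cache_modified_input_iterator.cuh"
#include <cstdio>

__global__ void SumKernel(cub::CacheModifiedInputIterator<cub::LOAD_LDG, float> in,
                          float *out, int num_items)
{
    if (threadIdx.x == 0 && blockIdx.x == 0)
    {
        float sum = 0.f;
        for (int i = 0; i < num_items; ++i)
            sum += in[i];               // each read is a modified (LDG) load
        *out = sum;
    }
}

int main()
{
    const int num_items = 7;
    const float h_in[num_items] = {8, 6, 7, 5, 3, 0, 9};

    float *d_in = NULL, *d_out = NULL;
    cudaMalloc(&d_in, num_items * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, num_items * sizeof(float), cudaMemcpyHostToDevice);

    cub::CacheModifiedInputIterator<cub::LOAD_LDG, float> itr(d_in);
    SumKernel<<<1, 1>>>(itr, d_out, num_items);

    float h_out = 0.f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %f\n", h_out);        // expected 38.0

    cudaFree(d_in); cudaFree(d_out);
    return 0;
}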
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Random-access iterator types + */ + +#pragma once + +#include +#include + +#include "../thread/thread_load.cuh" +#include "../thread/thread_store.cuh" +#include "../util_device.cuh" +#include "../util_namespace.cuh" + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup UtilIterator + * @{ + */ + + +/** + * \brief A random-access output wrapper for storing array values using a PTX cache-modifier. + * + * \par Overview + * - CacheModifiedOutputIterator is a random-access output iterator that wraps a native + * device pointer of type ValueType*. \p ValueType references are + * made by writing \p ValueType values through stores modified by \p MODIFIER. + * - Can be used to store any data type to memory using PTX cache store modifiers (e.g., "STORE_WB", + * "STORE_CG", "STORE_CS", "STORE_WT", etc.). + * - Can be constructed, manipulated, and exchanged within and between host and device + * functions, but can only be dereferenced within device functions. + * - Compatible with Thrust API v1.7 or newer. + * + * \par Snippet + * The code snippet below illustrates the use of \p CacheModifiedOutputIterator to + * dereference a device array of doubles using the "wt" PTX load modifier + * (i.e., write-through to system memory). 
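// --- Illustrative sketch (not part of the patch): writing a device array through the
// --- CacheModifiedOutputIterator documented here with the STORE_WT (write-through) modifier.
// --- The proxy returned by operator[] routes each assignment through a modified store, and
// --- stores are only valid in device code. Assumes the repository root is on the include path.
#include "fastertransformer/cuda/cub/iterator/cache_modified_output_iterator.cuh"
#include <cstdio>

__global__ void FillKernel(cub::CacheModifiedOutputIterator<cub::STORE_WT, float> out,
                           int num_items)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_items)
        out[i] = 2.0f * i;              // assignment goes through the write-through store path
}

int main()
{
    const int num_items = 8;

    float *d_out = NULL;
    cudaMalloc(&d_out, num_items * sizeof(float));

    cub::CacheModifiedOutputIterator<cub::STORE_WT, float> itr(d_out);
    FillKernel<<<1, 32>>>(itr, num_items);

    float h_out[num_items];
    cudaMemcpy(h_out, d_out, num_items * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < num_items; ++i)
        printf("%f\n", h_out[i]);       // 0.0, 2.0, 4.0, ...

    cudaFree(d_out);
    return 0;
}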
+ * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize a device array + * double *d_out; // e.g., [, , , , , , ] + * + * // Create an iterator wrapper + * cub::CacheModifiedOutputIterator itr(d_out); + * + * // Within device code: + * itr[0] = 8.0; + * itr[1] = 66.0; + * itr[55] = 24.0; + * + * \endcode + * + * \par Usage Considerations + * - Can only be dereferenced within device code + * + * \tparam CacheStoreModifier The cub::CacheStoreModifier to use when accessing data + * \tparam ValueType The value type of this iterator + * \tparam OffsetT The difference type of this iterator (Default: \p ptrdiff_t) + */ +template < + CacheStoreModifier MODIFIER, + typename ValueType, + typename OffsetT = ptrdiff_t> +class CacheModifiedOutputIterator +{ +private: + + // Proxy object + struct Reference + { + ValueType* ptr; + + /// Constructor + __host__ __device__ __forceinline__ Reference(ValueType* ptr) : ptr(ptr) {} + + /// Assignment + __device__ __forceinline__ ValueType operator =(ValueType val) + { + ThreadStore(ptr, val); + return val; + } + }; + +public: + + // Required iterator traits + typedef CacheModifiedOutputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of subtracting one iterator from another + typedef void value_type; ///< The type of the element the iterator can point to + typedef void pointer; ///< The type of a pointer to an element the iterator can point to + typedef Reference reference; ///< The type of a reference to an element the iterator can point to + +#if (THRUST_VERSION >= 100700) + // Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods + typedef typename thrust::detail::iterator_facade_category< + thrust::device_system_tag, + thrust::random_access_traversal_tag, + value_type, + reference + >::type iterator_category; ///< The iterator category +#else + typedef std::random_access_iterator_tag iterator_category; ///< The iterator category +#endif // THRUST_VERSION + +private: + + ValueType* ptr; + +public: + + /// Constructor + template + __host__ __device__ __forceinline__ CacheModifiedOutputIterator( + QualifiedValueType* ptr) ///< Native pointer to wrap + : + ptr(const_cast::Type *>(ptr)) + {} + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) + { + self_type retval = *this; + ptr++; + return retval; + } + + + /// Prefix increment + __host__ __device__ __forceinline__ self_type operator++() + { + ptr++; + return *this; + } + + /// Indirection + __host__ __device__ __forceinline__ reference operator*() const + { + return Reference(ptr); + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const + { + self_type retval(ptr + n); + return retval; + } + + /// Addition assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) + { + ptr += n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const + { + self_type retval(ptr - n); + return retval; + } + + /// Subtraction assignment + template + __host__ __device__ __forceinline__ self_type& operator-=(Distance n) + { + ptr -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type operator-(self_type other) const + { + return ptr - other.ptr; + } + + /// Array subscript + template + __host__ __device__ __forceinline__ reference operator[](Distance n) const + 
{ + return Reference(ptr + n); + } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) + { + return (ptr == rhs.ptr); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) + { + return (ptr != rhs.ptr); + } + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& itr) + { + return os; + } +}; + + +/** @} */ // end group UtilIterator + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/iterator/constant_input_iterator.cuh b/fastertransformer/cuda/cub/iterator/constant_input_iterator.cuh new file mode 100644 index 000000000..1e0a91044 --- /dev/null +++ b/fastertransformer/cuda/cub/iterator/constant_input_iterator.cuh @@ -0,0 +1,235 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Random-access iterator types + */ + +#pragma once + +#include +#include + +#include "../thread/thread_load.cuh" +#include "../thread/thread_store.cuh" +#include "../util_namespace.cuh" + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup UtilIterator + * @{ + */ + + +/** + * \brief A random-access input generator for dereferencing a sequence of homogeneous values + * + * \par Overview + * - Read references to a ConstantInputIteratorTiterator always return the supplied constant + * of type \p ValueType. + * - Can be used with any data type. + * - Can be constructed, manipulated, dereferenced, and exchanged within and between host and device + * functions. + * - Compatible with Thrust API v1.7 or newer. 
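// --- Illustrative sketch (not part of the patch): using the ConstantInputIterator documented
// --- here as a homogeneous input sequence. Every dereference returns the same constant, so the
// --- kernel below adds a bias to an array without materializing a filled buffer. Assumes the
// --- repository root is on the include path for the vendored header.
#include "fastertransformer/cuda/cub/iterator/constant_input_iterator.cuh"
#include <cstdio>

__global__ void AddBias(const float *in, cub::ConstantInputIterator<float> bias,
                        float *out, int num_items)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_items)
        out[i] = in[i] + bias[i];       // bias[i] is always the constructed constant
}

int main()
{
    const int num_items = 4;
    const float h_in[num_items] = {1.f, 2.f, 3.f, 4.f};

    float *d_in = NULL, *d_out = NULL;
    cudaMalloc(&d_in, num_items * sizeof(float));
    cudaMalloc(&d_out, num_items * sizeof(float));
    cudaMemcpy(d_in, h_in, num_items * sizeof(float), cudaMemcpyHostToDevice);

    cub::ConstantInputIterator<float> bias(5.0f);   // bias[i] == 5.0f for every i
    AddBias<<<1, 32>>>(d_in, bias, d_out, num_items);

    float h_out[num_items];
    cudaMemcpy(h_out, d_out, num_items * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < num_items; ++i)
        printf("%f\n", h_out[i]);       // 6.0, 7.0, 8.0, 9.0

    cudaFree(d_in); cudaFree(d_out);
    return 0;
}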
+ * + * \par Snippet + * The code snippet below illustrates the use of \p ConstantInputIteratorTto + * dereference a sequence of homogeneous doubles. + * \par + * \code + * #include // or equivalently + * + * cub::ConstantInputIterator itr(5.0); + * + * printf("%f\n", itr[0]); // 5.0 + * printf("%f\n", itr[1]); // 5.0 + * printf("%f\n", itr[2]); // 5.0 + * printf("%f\n", itr[50]); // 5.0 + * + * \endcode + * + * \tparam ValueType The value type of this iterator + * \tparam OffsetT The difference type of this iterator (Default: \p ptrdiff_t) + */ +template < + typename ValueType, + typename OffsetT = ptrdiff_t> +class ConstantInputIterator +{ +public: + + // Required iterator traits + typedef ConstantInputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of subtracting one iterator from another + typedef ValueType value_type; ///< The type of the element the iterator can point to + typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to + typedef ValueType reference; ///< The type of a reference to an element the iterator can point to + +#if (THRUST_VERSION >= 100700) + // Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods + typedef typename thrust::detail::iterator_facade_category< + thrust::any_system_tag, + thrust::random_access_traversal_tag, + value_type, + reference + >::type iterator_category; ///< The iterator category +#else + typedef std::random_access_iterator_tag iterator_category; ///< The iterator category +#endif // THRUST_VERSION + +private: + + ValueType val; + OffsetT offset; +#ifdef _WIN32 + OffsetT pad[CUB_MAX(1, (16 / sizeof(OffsetT) - 1))]; // Workaround for win32 parameter-passing bug (ulonglong2 argmin DeviceReduce) +#endif + +public: + + /// Constructor + __host__ __device__ __forceinline__ ConstantInputIterator( + ValueType val, ///< Starting value for the iterator instance to report + OffsetT offset = 0) ///< Base offset + : + val(val), + offset(offset) + {} + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) + { + self_type retval = *this; + offset++; + return retval; + } + + /// Prefix increment + __host__ __device__ __forceinline__ self_type operator++() + { + offset++; + return *this; + } + + /// Indirection + __host__ __device__ __forceinline__ reference operator*() const + { + return val; + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const + { + self_type retval(val, offset + n); + return retval; + } + + /// Addition assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) + { + offset += n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const + { + self_type retval(val, offset - n); + return retval; + } + + /// Subtraction assignment + template + __host__ __device__ __forceinline__ self_type& operator-=(Distance n) + { + offset -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type operator-(self_type other) const + { + return offset - other.offset; + } + + /// Array subscript + template + __host__ __device__ __forceinline__ reference operator[](Distance /*n*/) const + { + return val; + } + + /// Structure dereference + __host__ __device__ __forceinline__ pointer operator->() + { + return &val; + } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const 
self_type& rhs) + { + return (offset == rhs.offset) && ((val == rhs.val)); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) + { + return (offset != rhs.offset) || (val!= rhs.val); + } + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& itr) + { + os << "[" << itr.val << "," << itr.offset << "]"; + return os; + } + +}; + + +/** @} */ // end group UtilIterator + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/iterator/counting_input_iterator.cuh b/fastertransformer/cuda/cub/iterator/counting_input_iterator.cuh new file mode 100644 index 000000000..7f49348d6 --- /dev/null +++ b/fastertransformer/cuda/cub/iterator/counting_input_iterator.cuh @@ -0,0 +1,228 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Random-access iterator types + */ + +#pragma once + +#include +#include + +#include "../thread/thread_load.cuh" +#include "../thread/thread_store.cuh" +#include "../util_device.cuh" +#include "../util_namespace.cuh" + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \addtogroup UtilIterator + * @{ + */ + +/** + * \brief A random-access input generator for dereferencing a sequence of incrementing integer values. + * + * \par Overview + * - After initializing a CountingInputIteratorTto a certain integer \p base, read references + * at \p offset will return the value \p base + \p offset. + * - Can be constructed, manipulated, dereferenced, and exchanged within and between host and device + * functions. 
+ * - Compatible with Thrust API v1.7 or newer. + * + * \par Snippet + * The code snippet below illustrates the use of \p CountingInputIteratorTto + * dereference a sequence of incrementing integers. + * \par + * \code + * #include // or equivalently + * + * cub::CountingInputIterator itr(5); + * + * printf("%d\n", itr[0]); // 5 + * printf("%d\n", itr[1]); // 6 + * printf("%d\n", itr[2]); // 7 + * printf("%d\n", itr[50]); // 55 + * + * \endcode + * + * \tparam ValueType The value type of this iterator + * \tparam OffsetT The difference type of this iterator (Default: \p ptrdiff_t) + */ +template < + typename ValueType, + typename OffsetT = ptrdiff_t> +class CountingInputIterator +{ +public: + + // Required iterator traits + typedef CountingInputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of subtracting one iterator from another + typedef ValueType value_type; ///< The type of the element the iterator can point to + typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to + typedef ValueType reference; ///< The type of a reference to an element the iterator can point to + +#if (THRUST_VERSION >= 100700) + // Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods + typedef typename thrust::detail::iterator_facade_category< + thrust::any_system_tag, + thrust::random_access_traversal_tag, + value_type, + reference + >::type iterator_category; ///< The iterator category +#else + typedef std::random_access_iterator_tag iterator_category; ///< The iterator category +#endif // THRUST_VERSION + +private: + + ValueType val; + +public: + + /// Constructor + __host__ __device__ __forceinline__ CountingInputIterator( + const ValueType &val) ///< Starting value for the iterator instance to report + : + val(val) + {} + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) + { + self_type retval = *this; + val++; + return retval; + } + + /// Prefix increment + __host__ __device__ __forceinline__ self_type operator++() + { + val++; + return *this; + } + + /// Indirection + __host__ __device__ __forceinline__ reference operator*() const + { + return val; + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const + { + self_type retval(val + (ValueType) n); + return retval; + } + + /// Addition assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) + { + val += (ValueType) n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const + { + self_type retval(val - (ValueType) n); + return retval; + } + + /// Subtraction assignment + template + __host__ __device__ __forceinline__ self_type& operator-=(Distance n) + { + val -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type operator-(self_type other) const + { + return (difference_type) (val - other.val); + } + + /// Array subscript + template + __host__ __device__ __forceinline__ reference operator[](Distance n) const + { + return val + (ValueType) n; + } + + /// Structure dereference + __host__ __device__ __forceinline__ pointer operator->() + { + return &val; + } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) + { + return (val == rhs.val); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) + { + 
return (val != rhs.val); + } + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& itr) + { + os << "[" << itr.val << "]"; + return os; + } + +}; + + + +/** @} */ // end group UtilIterator + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/iterator/discard_output_iterator.cuh b/fastertransformer/cuda/cub/iterator/discard_output_iterator.cuh new file mode 100644 index 000000000..28473e5f2 --- /dev/null +++ b/fastertransformer/cuda/cub/iterator/discard_output_iterator.cuh @@ -0,0 +1,220 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
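// --- Illustrative sketch (not part of the patch): using the CountingInputIterator defined above
// --- as a lazily generated index sequence. Reading counting[i] evaluates to base + i, so the
// --- kernel emits each element's logical index without touching global memory for its input.
// --- Assumes the repository root is on the include path for the vendored header.
#include "fastertransformer/cuda/cub/iterator/counting_input_iterator.cuh"
#include <cstdio>

__global__ void WriteIndices(cub::CountingInputIterator<int> counting, int *out, int num_items)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_items)
        out[i] = counting[i];           // base + i, computed on the fly
}

int main()
{
    const int num_items = 5;

    int *d_out = NULL;
    cudaMalloc(&d_out, num_items * sizeof(int));

    cub::CountingInputIterator<int> counting(100);  // 100, 101, 102, ...
    WriteIndices<<<1, 32>>>(counting, d_out, num_items);

    int h_out[num_items];
    cudaMemcpy(h_out, d_out, num_items * sizeof(int), cudaMemcpyDeviceToHost);
    for (int i = 0; i < num_items; ++i)
        printf("%d\n", h_out[i]);       // 100, 101, 102, 103, 104

    cudaFree(d_out);
    return 0;
}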
+ * + ******************************************************************************/ + +/** + * \file + * Random-access iterator types + */ + +#pragma once + +#include +#include + +#include "../util_namespace.cuh" +#include "../util_macro.cuh" + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup UtilIterator + * @{ + */ + + +/** + * \brief A discard iterator + */ +template +class DiscardOutputIterator +{ +public: + + // Required iterator traits + typedef DiscardOutputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of subtracting one iterator from another + typedef void value_type; ///< The type of the element the iterator can point to + typedef void pointer; ///< The type of a pointer to an element the iterator can point to + typedef void reference; ///< The type of a reference to an element the iterator can point to + +#if (THRUST_VERSION >= 100700) + // Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods + typedef typename thrust::detail::iterator_facade_category< + thrust::any_system_tag, + thrust::random_access_traversal_tag, + value_type, + reference + >::type iterator_category; ///< The iterator category +#else + typedef std::random_access_iterator_tag iterator_category; ///< The iterator category +#endif // THRUST_VERSION + +private: + + OffsetT offset; + +#if defined(_WIN32) || !defined(_WIN64) + // Workaround for win32 parameter-passing bug (ulonglong2 argmin DeviceReduce) + OffsetT pad[CUB_MAX(1, (16 / sizeof(OffsetT) - 1))]; +#endif + +public: + + /// Constructor + __host__ __device__ __forceinline__ DiscardOutputIterator( + OffsetT offset = 0) ///< Base offset + : + offset(offset) + {} + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) + { + self_type retval = *this; + offset++; + return retval; + } + + /// Prefix increment + __host__ __device__ __forceinline__ self_type operator++() + { + offset++; + return *this; + } + + /// Indirection + __host__ __device__ __forceinline__ self_type& operator*() + { + // return self reference, which can be assigned to anything + return *this; + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const + { + self_type retval(offset + n); + return retval; + } + + /// Addition assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) + { + offset += n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const + { + self_type retval(offset - n); + return retval; + } + + /// Subtraction assignment + template + __host__ __device__ __forceinline__ self_type& operator-=(Distance n) + { + offset -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type operator-(self_type other) const + { + return offset - other.offset; + } + + /// Array subscript + template + __host__ __device__ __forceinline__ self_type& operator[](Distance n) + { + // return self reference, which can be assigned to anything + return *this; + } + + /// Structure dereference + __host__ __device__ __forceinline__ pointer operator->() + { + return; + } + + /// Assignment to self (no-op) + __host__ __device__ __forceinline__ void operator=(self_type const& 
other) + { + offset = other.offset; + } + + /// Assignment to anything else (no-op) + template + __host__ __device__ __forceinline__ void operator=(T const&) + {} + + /// Cast to void* operator + __host__ __device__ __forceinline__ operator void*() const { return NULL; } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) + { + return (offset == rhs.offset); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) + { + return (offset != rhs.offset); + } + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& itr) + { + os << "[" << itr.offset << "]"; + return os; + } + +}; + + +/** @} */ // end group UtilIterator + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/iterator/tex_obj_input_iterator.cuh b/fastertransformer/cuda/cub/iterator/tex_obj_input_iterator.cuh new file mode 100644 index 000000000..b99103ec5 --- /dev/null +++ b/fastertransformer/cuda/cub/iterator/tex_obj_input_iterator.cuh @@ -0,0 +1,310 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Random-access iterator types + */ + +#pragma once + +#include +#include + +#include "../thread/thread_load.cuh" +#include "../thread/thread_store.cuh" +#include "../util_device.cuh" +#include "../util_debug.cuh" +#include "../util_namespace.cuh" + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \addtogroup UtilIterator + * @{ + */ + + + +/** + * \brief A random-access input wrapper for dereferencing array values through texture cache. 
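// --- Illustrative sketch (not part of the patch): the DiscardOutputIterator defined above accepts
// --- assignments of any type and silently drops them, which makes it a convenient sink when an
// --- algorithm produces an output stream the caller does not need. The kernel below "writes"
// --- through it with no memory traffic. Assumes the repository root is on the include path.
#include "fastertransformer/cuda/cub/iterator/discard_output_iterator.cuh"

__global__ void WriteAndDiscard(cub::DiscardOutputIterator<ptrdiff_t> discard, int num_items)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_items)
        discard[i] = i * 2.5f;          // assignment is a no-op; nothing is stored
}

int main()
{
    cub::DiscardOutputIterator<ptrdiff_t> discard;
    WriteAndDiscard<<<1, 32>>>(discard, 32);
    cudaDeviceSynchronize();
    return 0;
}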
Uses newer Kepler-style texture objects. + * + * \par Overview + * - TexObjInputIteratorTwraps a native device pointer of type ValueType*. References + * to elements are to be loaded through texture cache. + * - Can be used to load any data type from memory through texture cache. + * - Can be manipulated and exchanged within and between host and device + * functions, can only be constructed within host functions, and can only be + * dereferenced within device functions. + * - With regard to nested/dynamic parallelism, TexObjInputIteratorTiterators may only be + * created by the host thread, but can be used by any descendant kernel. + * - Compatible with Thrust API v1.7 or newer. + * + * \par Snippet + * The code snippet below illustrates the use of \p TexRefInputIteratorTto + * dereference a device array of doubles through texture cache. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize a device array + * int num_items; // e.g., 7 + * double *d_in; // e.g., [8.0, 6.0, 7.0, 5.0, 3.0, 0.0, 9.0] + * + * // Create an iterator wrapper + * cub::TexObjInputIterator itr; + * itr.BindTexture(d_in, sizeof(double) * num_items); + * ... + * + * // Within device code: + * printf("%f\n", itr[0]); // 8.0 + * printf("%f\n", itr[1]); // 6.0 + * printf("%f\n", itr[6]); // 9.0 + * + * ... + * itr.UnbindTexture(); + * + * \endcode + * + * \tparam T The value type of this iterator + * \tparam OffsetT The difference type of this iterator (Default: \p ptrdiff_t) + */ +template < + typename T, + typename OffsetT = ptrdiff_t> +class TexObjInputIterator +{ +public: + + // Required iterator traits + typedef TexObjInputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of subtracting one iterator from another + typedef T value_type; ///< The type of the element the iterator can point to + typedef T* pointer; ///< The type of a pointer to an element the iterator can point to + typedef T reference; ///< The type of a reference to an element the iterator can point to + +#if (THRUST_VERSION >= 100700) + // Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods + typedef typename thrust::detail::iterator_facade_category< + thrust::device_system_tag, + thrust::random_access_traversal_tag, + value_type, + reference + >::type iterator_category; ///< The iterator category +#else + typedef std::random_access_iterator_tag iterator_category; ///< The iterator category +#endif // THRUST_VERSION + +private: + + // Largest texture word we can use in device + typedef typename UnitWord::TextureWord TextureWord; + + // Number of texture words per T + enum { + TEXTURE_MULTIPLE = sizeof(T) / sizeof(TextureWord) + }; + +private: + + T* ptr; + difference_type tex_offset; + cudaTextureObject_t tex_obj; + +public: + + /// Constructor + __host__ __device__ __forceinline__ TexObjInputIterator() + : + ptr(NULL), + tex_offset(0), + tex_obj(0) + {} + + /// Use this iterator to bind \p ptr with a texture reference + template + cudaError_t BindTexture( + QualifiedT *ptr, ///< Native pointer to wrap that is aligned to cudaDeviceProp::textureAlignment + size_t bytes = size_t(-1), ///< Number of bytes in the range + size_t tex_offset = 0) ///< OffsetT (in items) from \p ptr denoting the position of the iterator + { + this->ptr = const_cast::Type *>(ptr); + this->tex_offset = tex_offset; + + cudaChannelFormatDesc channel_desc = cudaCreateChannelDesc(); + cudaResourceDesc res_desc; + cudaTextureDesc tex_desc; + 
memset(&res_desc, 0, sizeof(cudaResourceDesc)); + memset(&tex_desc, 0, sizeof(cudaTextureDesc)); + res_desc.resType = cudaResourceTypeLinear; + res_desc.res.linear.devPtr = this->ptr; + res_desc.res.linear.desc = channel_desc; + res_desc.res.linear.sizeInBytes = bytes; + tex_desc.readMode = cudaReadModeElementType; + return cudaCreateTextureObject(&tex_obj, &res_desc, &tex_desc, NULL); + } + + /// Unbind this iterator from its texture reference + cudaError_t UnbindTexture() + { + return cudaDestroyTextureObject(tex_obj); + } + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) + { + self_type retval = *this; + tex_offset++; + return retval; + } + + /// Prefix increment + __host__ __device__ __forceinline__ self_type operator++() + { + tex_offset++; + return *this; + } + + /// Indirection + __host__ __device__ __forceinline__ reference operator*() const + { +#if (CUB_PTX_ARCH == 0) + // Simply dereference the pointer on the host + return ptr[tex_offset]; +#else + // Move array of uninitialized words, then alias and assign to return value + TextureWord words[TEXTURE_MULTIPLE]; + + #pragma unroll + for (int i = 0; i < TEXTURE_MULTIPLE; ++i) + { + words[i] = tex1Dfetch( + tex_obj, + (tex_offset * TEXTURE_MULTIPLE) + i); + } + + // Load from words + return *reinterpret_cast(words); +#endif + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const + { + self_type retval; + retval.ptr = ptr; + retval.tex_obj = tex_obj; + retval.tex_offset = tex_offset + n; + return retval; + } + + /// Addition assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) + { + tex_offset += n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const + { + self_type retval; + retval.ptr = ptr; + retval.tex_obj = tex_obj; + retval.tex_offset = tex_offset - n; + return retval; + } + + /// Subtraction assignment + template + __host__ __device__ __forceinline__ self_type& operator-=(Distance n) + { + tex_offset -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type operator-(self_type other) const + { + return tex_offset - other.tex_offset; + } + + /// Array subscript + template + __host__ __device__ __forceinline__ reference operator[](Distance n) const + { + self_type offset = (*this) + n; + return *offset; + } + + /// Structure dereference + __host__ __device__ __forceinline__ pointer operator->() + { + return &(*(*this)); + } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) + { + return ((ptr == rhs.ptr) && (tex_offset == rhs.tex_offset) && (tex_obj == rhs.tex_obj)); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) + { + return ((ptr != rhs.ptr) || (tex_offset != rhs.tex_offset) || (tex_obj != rhs.tex_obj)); + } + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& itr) + { + return os; + } + +}; + + + +/** @} */ // end group UtilIterator + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/iterator/tex_ref_input_iterator.cuh b/fastertransformer/cuda/cub/iterator/tex_ref_input_iterator.cuh new file mode 100644 index 000000000..95d0ffbc9 --- /dev/null +++ b/fastertransformer/cuda/cub/iterator/tex_ref_input_iterator.cuh @@ -0,0 +1,374 @@ 
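// --- Illustrative sketch (not part of the patch): binding the TexObjInputIterator defined above
// --- to a device array and reading it through the texture path inside a kernel. BindTexture and
// --- UnbindTexture may only be called from host code, and texture fetches only happen in device
// --- code. Assumes the repository root is on the include path for the vendored header.
#include "fastertransformer/cuda/cub/iterator/tex_obj_input_iterator.cuh"
#include <cstdio>

__global__ void PrintThroughTexture(cub::TexObjInputIterator<float> itr, int num_items)
{
    if (threadIdx.x == 0 && blockIdx.x == 0)
        for (int i = 0; i < num_items; ++i)
            printf("%f\n", itr[i]);     // fetched through the bound texture object
}

int main()
{
    const int num_items = 7;
    const float h_in[num_items] = {8, 6, 7, 5, 3, 0, 9};

    float *d_in = NULL;
    cudaMalloc(&d_in, num_items * sizeof(float));
    cudaMemcpy(d_in, h_in, num_items * sizeof(float), cudaMemcpyHostToDevice);

    cub::TexObjInputIterator<float> itr;
    itr.BindTexture(d_in, num_items * sizeof(float));   // creates the cudaTextureObject_t

    PrintThroughTexture<<<1, 1>>>(itr, num_items);
    cudaDeviceSynchronize();

    itr.UnbindTexture();                                // destroys the texture object
    cudaFree(d_in);
    return 0;
}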
+/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Random-access iterator types + */ + +#pragma once + +#include +#include + +#include "../thread/thread_load.cuh" +#include "../thread/thread_store.cuh" +#include "../util_device.cuh" +#include "../util_debug.cuh" +#include "../util_namespace.cuh" + +#if (CUDA_VERSION >= 5050) || defined(DOXYGEN_ACTIVE) // This iterator is compatible with CUDA 5.5 and newer + +#if (THRUST_VERSION >= 100700) // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/****************************************************************************** + * Static file-scope Tesla/Fermi-style texture references + *****************************************************************************/ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +// Anonymous namespace +namespace { + +/// Global texture reference specialized by type +template +struct IteratorTexRef +{ + /// And by unique ID + template + struct TexId + { + // Largest texture word we can use in device + typedef typename UnitWord::DeviceWord DeviceWord; + typedef typename UnitWord::TextureWord TextureWord; + + // Number of texture words per T + enum { + DEVICE_MULTIPLE = sizeof(T) / sizeof(DeviceWord), + TEXTURE_MULTIPLE = sizeof(T) / sizeof(TextureWord) + }; + + // Texture reference type + typedef texture TexRef; + + // Texture reference + static TexRef ref; + + /// Bind texture + static cudaError_t BindTexture(void *d_in, size_t &offset) + { + if (d_in) + { + cudaChannelFormatDesc tex_desc = cudaCreateChannelDesc(); + ref.channelDesc = tex_desc; + return (CubDebug(cudaBindTexture(&offset, ref, d_in))); + } + + return cudaSuccess; + } 
+ + /// Unbind texture + static cudaError_t UnbindTexture() + { + return CubDebug(cudaUnbindTexture(ref)); + } + + /// Fetch element + template + static __device__ __forceinline__ T Fetch(Distance tex_offset) + { + DeviceWord temp[DEVICE_MULTIPLE]; + TextureWord *words = reinterpret_cast(temp); + + #pragma unroll + for (int i = 0; i < TEXTURE_MULTIPLE; ++i) + { + words[i] = tex1Dfetch(ref, (tex_offset * TEXTURE_MULTIPLE) + i); + } + + return reinterpret_cast(temp); + } + }; +}; + +// Texture reference definitions +template +template +typename IteratorTexRef::template TexId::TexRef IteratorTexRef::template TexId::ref = 0; + + +} // Anonymous namespace + + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + + +/** + * \addtogroup UtilIterator + * @{ + */ + + + +/** + * \brief A random-access input wrapper for dereferencing array values through texture cache. Uses older Tesla/Fermi-style texture references. + * + * \par Overview + * - TexRefInputIteratorTwraps a native device pointer of type ValueType*. References + * to elements are to be loaded through texture cache. + * - Can be used to load any data type from memory through texture cache. + * - Can be manipulated and exchanged within and between host and device + * functions, can only be constructed within host functions, and can only be + * dereferenced within device functions. + * - The \p UNIQUE_ID template parameter is used to statically name the underlying texture + * reference. Only one TexRefInputIteratorTinstance can be bound at any given time for a + * specific combination of (1) data type \p T, (2) \p UNIQUE_ID, (3) host + * thread, and (4) compilation .o unit. + * - With regard to nested/dynamic parallelism, TexRefInputIteratorTiterators may only be + * created by the host thread and used by a top-level kernel (i.e. the one which is launched + * from the host). + * - Compatible with Thrust API v1.7 or newer. + * - Compatible with CUDA toolkit v5.5 or newer. + * + * \par Snippet + * The code snippet below illustrates the use of \p TexRefInputIteratorTto + * dereference a device array of doubles through texture cache. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize a device array + * int num_items; // e.g., 7 + * double *d_in; // e.g., [8.0, 6.0, 7.0, 5.0, 3.0, 0.0, 9.0] + * + * // Create an iterator wrapper + * cub::TexRefInputIterator itr; + * itr.BindTexture(d_in, sizeof(double) * num_items); + * ... + * + * // Within device code: + * printf("%f\n", itr[0]); // 8.0 + * printf("%f\n", itr[1]); // 6.0 + * printf("%f\n", itr[6]); // 9.0 + * + * ... 
+ * itr.UnbindTexture(); + * + * \endcode + * + * \tparam T The value type of this iterator + * \tparam UNIQUE_ID A globally-unique identifier (within the compilation unit) to name the underlying texture reference + * \tparam OffsetT The difference type of this iterator (Default: \p ptrdiff_t) + */ +template < + typename T, + int UNIQUE_ID, + typename OffsetT = ptrdiff_t> +class TexRefInputIterator +{ +public: + + // Required iterator traits + typedef TexRefInputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of subtracting one iterator from another + typedef T value_type; ///< The type of the element the iterator can point to + typedef T* pointer; ///< The type of a pointer to an element the iterator can point to + typedef T reference; ///< The type of a reference to an element the iterator can point to + +#if (THRUST_VERSION >= 100700) + // Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods + typedef typename thrust::detail::iterator_facade_category< + thrust::device_system_tag, + thrust::random_access_traversal_tag, + value_type, + reference + >::type iterator_category; ///< The iterator category +#else + typedef std::random_access_iterator_tag iterator_category; ///< The iterator category +#endif // THRUST_VERSION + +private: + + T* ptr; + difference_type tex_offset; + + // Texture reference wrapper (old Tesla/Fermi-style textures) + typedef typename IteratorTexRef::template TexId TexId; + +public: +/* + /// Constructor + __host__ __device__ __forceinline__ TexRefInputIterator() + : + ptr(NULL), + tex_offset(0) + {} +*/ + /// Use this iterator to bind \p ptr with a texture reference + template + cudaError_t BindTexture( + QualifiedT *ptr, ///< Native pointer to wrap that is aligned to cudaDeviceProp::textureAlignment + size_t bytes = size_t(-1), ///< Number of bytes in the range + size_t tex_offset = 0) ///< OffsetT (in items) from \p ptr denoting the position of the iterator + { + this->ptr = const_cast::Type *>(ptr); + size_t offset; + cudaError_t retval = TexId::BindTexture(this->ptr + tex_offset, offset); + this->tex_offset = (difference_type) (offset / sizeof(QualifiedT)); + return retval; + } + + /// Unbind this iterator from its texture reference + cudaError_t UnbindTexture() + { + return TexId::UnbindTexture(); + } + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) + { + self_type retval = *this; + tex_offset++; + return retval; + } + + /// Prefix increment + __host__ __device__ __forceinline__ self_type operator++() + { + tex_offset++; + return *this; + } + + /// Indirection + __host__ __device__ __forceinline__ reference operator*() const + { +#if (CUB_PTX_ARCH == 0) + // Simply dereference the pointer on the host + return ptr[tex_offset]; +#else + // Use the texture reference + return TexId::Fetch(tex_offset); +#endif + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const + { + self_type retval; + retval.ptr = ptr; + retval.tex_offset = tex_offset + n; + return retval; + } + + /// Addition assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) + { + tex_offset += n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const + { + self_type retval; + retval.ptr = ptr; + retval.tex_offset = tex_offset - n; + return retval; + } + + /// Subtraction assignment + template + __host__ 
__device__ __forceinline__ self_type& operator-=(Distance n) + { + tex_offset -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type operator-(self_type other) const + { + return tex_offset - other.tex_offset; + } + + /// Array subscript + template + __host__ __device__ __forceinline__ reference operator[](Distance n) const + { + self_type offset = (*this) + n; + return *offset; + } + + /// Structure dereference + __host__ __device__ __forceinline__ pointer operator->() + { + return &(*(*this)); + } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) + { + return ((ptr == rhs.ptr) && (tex_offset == rhs.tex_offset)); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) + { + return ((ptr != rhs.ptr) || (tex_offset != rhs.tex_offset)); + } + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& itr) + { + return os; + } + +}; + + + +/** @} */ // end group UtilIterator + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + +#endif // CUDA_VERSION diff --git a/fastertransformer/cuda/cub/iterator/transform_input_iterator.cuh b/fastertransformer/cuda/cub/iterator/transform_input_iterator.cuh new file mode 100644 index 000000000..dad1f5004 --- /dev/null +++ b/fastertransformer/cuda/cub/iterator/transform_input_iterator.cuh @@ -0,0 +1,252 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * \file + * Random-access iterator types + */ + +#pragma once + +#include +#include + +#include "../thread/thread_load.cuh" +#include "../thread/thread_store.cuh" +#include "../util_device.cuh" +#include "../util_namespace.cuh" + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \addtogroup UtilIterator + * @{ + */ + + +/** + * \brief A random-access input wrapper for transforming dereferenced values. + * + * \par Overview + * - TransformInputIteratorTwraps a unary conversion functor of type \p + * ConversionOp and a random-access input iterator of type InputIteratorT, + * using the former to produce references of type \p ValueType from the latter. + * - Can be used with any data type. + * - Can be constructed, manipulated, and exchanged within and between host and device + * functions. Wrapped host memory can only be dereferenced on the host, and wrapped + * device memory can only be dereferenced on the device. + * - Compatible with Thrust API v1.7 or newer. + * + * \par Snippet + * The code snippet below illustrates the use of \p TransformInputIteratorTto + * dereference an array of integers, tripling the values and converting them to doubles. + * \par + * \code + * #include // or equivalently + * + * // Functor for tripling integer values and converting to doubles + * struct TripleDoubler + * { + * __host__ __device__ __forceinline__ + * double operator()(const int &a) const { + * return double(a * 3); + * } + * }; + * + * // Declare, allocate, and initialize a device array + * int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] + * TripleDoubler conversion_op; + * + * // Create an iterator wrapper + * cub::TransformInputIterator itr(d_in, conversion_op); + * + * // Within device code: + * printf("%f\n", itr[0]); // 24.0 + * printf("%f\n", itr[1]); // 18.0 + * printf("%f\n", itr[6]); // 27.0 + * + * \endcode + * + * \tparam ValueType The value type of this iterator + * \tparam ConversionOp Unary functor type for mapping objects of type \p InputType to type \p ValueType. Must have member ValueType operator()(const InputType &datum). 
+ * \tparam InputIteratorT The type of the wrapped input iterator + * \tparam OffsetT The difference type of this iterator (Default: \p ptrdiff_t) + * + */ +template < + typename ValueType, + typename ConversionOp, + typename InputIteratorT, + typename OffsetT = ptrdiff_t> +class TransformInputIterator +{ +public: + + // Required iterator traits + typedef TransformInputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of subtracting one iterator from another + typedef ValueType value_type; ///< The type of the element the iterator can point to + typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to + typedef ValueType reference; ///< The type of a reference to an element the iterator can point to + +#if (THRUST_VERSION >= 100700) + // Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods + typedef typename thrust::detail::iterator_facade_category< + thrust::any_system_tag, + thrust::random_access_traversal_tag, + value_type, + reference + >::type iterator_category; ///< The iterator category +#else + typedef std::random_access_iterator_tag iterator_category; ///< The iterator category +#endif // THRUST_VERSION + +private: + + ConversionOp conversion_op; + InputIteratorT input_itr; + +public: + + /// Constructor + __host__ __device__ __forceinline__ TransformInputIterator( + InputIteratorT input_itr, ///< Input iterator to wrap + ConversionOp conversion_op) ///< Conversion functor to wrap + : + conversion_op(conversion_op), + input_itr(input_itr) + {} + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) + { + self_type retval = *this; + input_itr++; + return retval; + } + + /// Prefix increment + __host__ __device__ __forceinline__ self_type operator++() + { + input_itr++; + return *this; + } + + /// Indirection + __host__ __device__ __forceinline__ reference operator*() const + { + return conversion_op(*input_itr); + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const + { + self_type retval(input_itr + n, conversion_op); + return retval; + } + + /// Addition assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) + { + input_itr += n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const + { + self_type retval(input_itr - n, conversion_op); + return retval; + } + + /// Subtraction assignment + template + __host__ __device__ __forceinline__ self_type& operator-=(Distance n) + { + input_itr -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type operator-(self_type other) const + { + return input_itr - other.input_itr; + } + + /// Array subscript + template + __host__ __device__ __forceinline__ reference operator[](Distance n) const + { + return conversion_op(input_itr[n]); + } + + /// Structure dereference + __host__ __device__ __forceinline__ pointer operator->() + { + return &conversion_op(*input_itr); + } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) + { + return (input_itr == rhs.input_itr); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) + { + return (input_itr != rhs.input_itr); + } + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& itr) + { + return os; + } +}; + + + 
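The wrapper above can be exercised end-to-end with only a few lines of host code. The sketch below is illustrative rather than part of CUB: the kernel and functor names are our own, the include path assumes the vendored location added by this commit, and the array contents mirror the snippet in the header comment.

    #include "fastertransformer/cuda/cub/iterator/transform_input_iterator.cuh"
    #include <cstdio>

    // Functor for tripling integer values and converting to doubles
    struct TripleDoubler
    {
        __host__ __device__ __forceinline__
        double operator()(const int &a) const { return double(a * 3); }
    };

    // Dereference the transformed values on the device
    __global__ void PrintTransformed(cub::TransformInputIterator<double, TripleDoubler, int*> itr)
    {
        if (threadIdx.x == 0)
            printf("%f %f %f\n", itr[0], itr[1], itr[6]);   // 24.0 18.0 27.0
    }

    int main()
    {
        int h_in[7] = {8, 6, 7, 5, 3, 0, 9};
        int *d_in = NULL;
        cudaMalloc(&d_in, sizeof(h_in));
        cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

        TripleDoubler conversion_op;
        cub::TransformInputIterator<double, TripleDoubler, int*> itr(d_in, conversion_op);
        PrintTransformed<<<1, 32>>>(itr);
        cudaDeviceSynchronize();

        cudaFree(d_in);
        return 0;
    }

Because the iterator stores only the wrapped input iterator and the conversion functor, it can be passed to the kernel by value.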
+/** @} */ // end group UtilIterator + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/thread/thread_load.cuh b/fastertransformer/cuda/cub/thread/thread_load.cuh new file mode 100644 index 000000000..b1ca412fa --- /dev/null +++ b/fastertransformer/cuda/cub/thread/thread_load.cuh @@ -0,0 +1,438 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Thread utilities for reading memory using PTX cache modifiers. + */ + +#pragma once + +#include + +#include + +#include "../util_ptx.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \addtogroup UtilIo + * @{ + */ + +//----------------------------------------------------------------------------- +// Tags and constants +//----------------------------------------------------------------------------- + +/** + * \brief Enumeration of cache modifiers for memory load operations. + */ +enum CacheLoadModifier +{ + LOAD_DEFAULT, ///< Default (no modifier) + LOAD_CA, ///< Cache at all levels + LOAD_CG, ///< Cache at global level + LOAD_CS, ///< Cache streaming (likely to be accessed once) + LOAD_CV, ///< Cache as volatile (including cached system lines) + LOAD_LDG, ///< Cache as texture + LOAD_VOLATILE, ///< Volatile (any memory space) +}; + + +/** + * \name Thread I/O (cache modified) + * @{ + */ + +/** + * \brief Thread utility for reading memory using cub::CacheLoadModifier cache modifiers. Can be used to load any data type. 
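Before the header's own snippet, a compact sketch of the calling convention; the kernel below is our own, and only cub::ThreadLoad and the LOAD_* enumerators defined above come from this header.

    #include "fastertransformer/cuda/cub/thread/thread_load.cuh"

    // One element per thread: read through the cache-global (.cg) path,
    // bypassing L1, and write the value back out.
    __global__ void CopyWithLoadCG(const int *d_in, int *d_out, int n)
    {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx < n)
            d_out[idx] = cub::ThreadLoad<cub::LOAD_CG>(d_in + idx);
    }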
+ * + * \par Example + * \code + * #include // or equivalently + * + * // 32-bit load using cache-global modifier: + * int *d_in; + * int val = cub::ThreadLoad(d_in + threadIdx.x); + * + * // 16-bit load using default modifier + * short *d_in; + * short val = cub::ThreadLoad(d_in + threadIdx.x); + * + * // 256-bit load using cache-volatile modifier + * double4 *d_in; + * double4 val = cub::ThreadLoad(d_in + threadIdx.x); + * + * // 96-bit load using cache-streaming modifier + * struct TestFoo { bool a; short b; }; + * TestFoo *d_struct; + * TestFoo val = cub::ThreadLoad(d_in + threadIdx.x); + * \endcode + * + * \tparam MODIFIER [inferred] CacheLoadModifier enumeration + * \tparam InputIteratorT [inferred] Input iterator type \iterator + */ +template < + CacheLoadModifier MODIFIER, + typename InputIteratorT> +__device__ __forceinline__ typename std::iterator_traits::value_type ThreadLoad(InputIteratorT itr); + + +//@} end member group + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + +/// Helper structure for templated load iteration (inductive case) +template +struct IterateThreadLoad +{ + template + static __device__ __forceinline__ void Load(T const *ptr, T *vals) + { + vals[COUNT] = ThreadLoad(ptr + COUNT); + IterateThreadLoad::template Load(ptr, vals); + } + + template + static __device__ __forceinline__ void Dereference(InputIteratorT itr, T *vals) + { + vals[COUNT] = itr[COUNT]; + IterateThreadLoad::Dereference(itr, vals); + } +}; + + +/// Helper structure for templated load iteration (termination case) +template +struct IterateThreadLoad +{ + template + static __device__ __forceinline__ void Load(T const * /*ptr*/, T * /*vals*/) {} + + template + static __device__ __forceinline__ void Dereference(InputIteratorT /*itr*/, T * /*vals*/) {} +}; + + +/** + * Define a uint4 (16B) ThreadLoad specialization for the given Cache load modifier + */ +#define _CUB_LOAD_16(cub_modifier, ptx_modifier) \ + template<> \ + __device__ __forceinline__ uint4 ThreadLoad(uint4 const *ptr) \ + { \ + uint4 retval; \ + asm volatile ("ld."#ptx_modifier".v4.u32 {%0, %1, %2, %3}, [%4];" : \ + "=r"(retval.x), \ + "=r"(retval.y), \ + "=r"(retval.z), \ + "=r"(retval.w) : \ + _CUB_ASM_PTR_(ptr)); \ + return retval; \ + } \ + template<> \ + __device__ __forceinline__ ulonglong2 ThreadLoad(ulonglong2 const *ptr) \ + { \ + ulonglong2 retval; \ + asm volatile ("ld."#ptx_modifier".v2.u64 {%0, %1}, [%2];" : \ + "=l"(retval.x), \ + "=l"(retval.y) : \ + _CUB_ASM_PTR_(ptr)); \ + return retval; \ + } + +/** + * Define a uint2 (8B) ThreadLoad specialization for the given Cache load modifier + */ +#define _CUB_LOAD_8(cub_modifier, ptx_modifier) \ + template<> \ + __device__ __forceinline__ ushort4 ThreadLoad(ushort4 const *ptr) \ + { \ + ushort4 retval; \ + asm volatile ("ld."#ptx_modifier".v4.u16 {%0, %1, %2, %3}, [%4];" : \ + "=h"(retval.x), \ + "=h"(retval.y), \ + "=h"(retval.z), \ + "=h"(retval.w) : \ + _CUB_ASM_PTR_(ptr)); \ + return retval; \ + } \ + template<> \ + __device__ __forceinline__ uint2 ThreadLoad(uint2 const *ptr) \ + { \ + uint2 retval; \ + asm volatile ("ld."#ptx_modifier".v2.u32 {%0, %1}, [%2];" : \ + "=r"(retval.x), \ + "=r"(retval.y) : \ + _CUB_ASM_PTR_(ptr)); \ + return retval; \ + } \ + template<> \ + __device__ __forceinline__ unsigned long long ThreadLoad(unsigned long long const *ptr) \ + { \ + unsigned long long retval; \ + asm volatile ("ld."#ptx_modifier".u64 %0, [%1];" : \ + "=l"(retval) : \ + _CUB_ASM_PTR_(ptr)); \ + return retval; \ + } + +/** + * Define a uint (4B) ThreadLoad 
specialization for the given Cache load modifier + */ +#define _CUB_LOAD_4(cub_modifier, ptx_modifier) \ + template<> \ + __device__ __forceinline__ unsigned int ThreadLoad(unsigned int const *ptr) \ + { \ + unsigned int retval; \ + asm volatile ("ld."#ptx_modifier".u32 %0, [%1];" : \ + "=r"(retval) : \ + _CUB_ASM_PTR_(ptr)); \ + return retval; \ + } + + +/** + * Define a unsigned short (2B) ThreadLoad specialization for the given Cache load modifier + */ +#define _CUB_LOAD_2(cub_modifier, ptx_modifier) \ + template<> \ + __device__ __forceinline__ unsigned short ThreadLoad(unsigned short const *ptr) \ + { \ + unsigned short retval; \ + asm volatile ("ld."#ptx_modifier".u16 %0, [%1];" : \ + "=h"(retval) : \ + _CUB_ASM_PTR_(ptr)); \ + return retval; \ + } + + +/** + * Define an unsigned char (1B) ThreadLoad specialization for the given Cache load modifier + */ +#define _CUB_LOAD_1(cub_modifier, ptx_modifier) \ + template<> \ + __device__ __forceinline__ unsigned char ThreadLoad(unsigned char const *ptr) \ + { \ + unsigned short retval; \ + asm volatile ( \ + "{" \ + " .reg .u8 datum;" \ + " ld."#ptx_modifier".u8 datum, [%1];" \ + " cvt.u16.u8 %0, datum;" \ + "}" : \ + "=h"(retval) : \ + _CUB_ASM_PTR_(ptr)); \ + return (unsigned char) retval; \ + } + + +/** + * Define powers-of-two ThreadLoad specializations for the given Cache load modifier + */ +#define _CUB_LOAD_ALL(cub_modifier, ptx_modifier) \ + _CUB_LOAD_16(cub_modifier, ptx_modifier) \ + _CUB_LOAD_8(cub_modifier, ptx_modifier) \ + _CUB_LOAD_4(cub_modifier, ptx_modifier) \ + _CUB_LOAD_2(cub_modifier, ptx_modifier) \ + _CUB_LOAD_1(cub_modifier, ptx_modifier) \ + + +/** + * Define powers-of-two ThreadLoad specializations for the various Cache load modifiers + */ +#if CUB_PTX_ARCH >= 200 + _CUB_LOAD_ALL(LOAD_CA, ca) + _CUB_LOAD_ALL(LOAD_CG, cg) + _CUB_LOAD_ALL(LOAD_CS, cs) + _CUB_LOAD_ALL(LOAD_CV, cv) +#else + _CUB_LOAD_ALL(LOAD_CA, global) + // Use volatile to ensure coherent reads when this PTX is JIT'd to run on newer architectures with L1 + _CUB_LOAD_ALL(LOAD_CG, volatile.global) + _CUB_LOAD_ALL(LOAD_CS, global) + _CUB_LOAD_ALL(LOAD_CV, volatile.global) +#endif + +#if CUB_PTX_ARCH >= 350 + _CUB_LOAD_ALL(LOAD_LDG, global.nc) +#else + _CUB_LOAD_ALL(LOAD_LDG, global) +#endif + + +// Macro cleanup +#undef _CUB_LOAD_ALL +#undef _CUB_LOAD_1 +#undef _CUB_LOAD_2 +#undef _CUB_LOAD_4 +#undef _CUB_LOAD_8 +#undef _CUB_LOAD_16 + + + +/** + * ThreadLoad definition for LOAD_DEFAULT modifier on iterator types + */ +template +__device__ __forceinline__ typename std::iterator_traits::value_type ThreadLoad( + InputIteratorT itr, + Int2Type /*modifier*/, + Int2Type /*is_pointer*/) +{ + return *itr; +} + + +/** + * ThreadLoad definition for LOAD_DEFAULT modifier on pointer types + */ +template +__device__ __forceinline__ T ThreadLoad( + T *ptr, + Int2Type /*modifier*/, + Int2Type /*is_pointer*/) +{ + return *ptr; +} + + +/** + * ThreadLoad definition for LOAD_VOLATILE modifier on primitive pointer types + */ +template +__device__ __forceinline__ T ThreadLoadVolatilePointer( + T *ptr, + Int2Type /*is_primitive*/) +{ + T retval = *reinterpret_cast(ptr); + return retval; +} + + +/** + * ThreadLoad definition for LOAD_VOLATILE modifier on non-primitive pointer types + */ +template +__device__ __forceinline__ T ThreadLoadVolatilePointer( + T *ptr, + Int2Type /*is_primitive*/) +{ + typedef typename UnitWord::VolatileWord VolatileWord; // Word type for memcopying + + const int VOLATILE_MULTIPLE = sizeof(T) / sizeof(VolatileWord); +/* + VolatileWord 
words[VOLATILE_MULTIPLE]; + + IterateThreadLoad<0, VOLATILE_MULTIPLE>::Dereference( + reinterpret_cast(ptr), + words); + + return *reinterpret_cast(words); +*/ + + T retval; + VolatileWord *words = reinterpret_cast(&retval); + IterateThreadLoad<0, VOLATILE_MULTIPLE>::Dereference( + reinterpret_cast(ptr), + words); + return retval; +} + + +/** + * ThreadLoad definition for LOAD_VOLATILE modifier on pointer types + */ +template +__device__ __forceinline__ T ThreadLoad( + T *ptr, + Int2Type /*modifier*/, + Int2Type /*is_pointer*/) +{ + // Apply tags for partial-specialization + return ThreadLoadVolatilePointer(ptr, Int2Type::PRIMITIVE>()); +} + + +/** + * ThreadLoad definition for generic modifiers on pointer types + */ +template +__device__ __forceinline__ T ThreadLoad( + T const *ptr, + Int2Type /*modifier*/, + Int2Type /*is_pointer*/) +{ + typedef typename UnitWord::DeviceWord DeviceWord; + + const int DEVICE_MULTIPLE = sizeof(T) / sizeof(DeviceWord); + + DeviceWord words[DEVICE_MULTIPLE]; + + IterateThreadLoad<0, DEVICE_MULTIPLE>::template Load( + reinterpret_cast(const_cast(ptr)), + words); + + return *reinterpret_cast(words); +} + + +/** + * ThreadLoad definition for generic modifiers + */ +template < + CacheLoadModifier MODIFIER, + typename InputIteratorT> +__device__ __forceinline__ typename std::iterator_traits::value_type ThreadLoad(InputIteratorT itr) +{ + // Apply tags for partial-specialization + return ThreadLoad( + itr, + Int2Type(), + Int2Type::VALUE>()); +} + + + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + +/** @} */ // end group UtilIo + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/thread/thread_operators.cuh b/fastertransformer/cuda/cub/thread/thread_operators.cuh new file mode 100644 index 000000000..76cd800f5 --- /dev/null +++ b/fastertransformer/cuda/cub/thread/thread_operators.cuh @@ -0,0 +1,317 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Simple binary operator functor types + */ + +/****************************************************************************** + * Simple functor operators + ******************************************************************************/ + +#pragma once + +#include "../util_macro.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup UtilModule + * @{ + */ + +/** + * \brief Default equality functor + */ +struct Equality +{ + /// Boolean equality operator, returns (a == b) + template + __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) const + { + return a == b; + } +}; + + +/** + * \brief Default inequality functor + */ +struct Inequality +{ + /// Boolean inequality operator, returns (a != b) + template + __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) const + { + return a != b; + } +}; + + +/** + * \brief Inequality functor (wraps equality functor) + */ +template +struct InequalityWrapper +{ + /// Wrapped equality operator + EqualityOp op; + + /// Constructor + __host__ __device__ __forceinline__ + InequalityWrapper(EqualityOp op) : op(op) {} + + /// Boolean inequality operator, returns (a != b) + template + __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) + { + return !op(a, b); + } +}; + + +/** + * \brief Default sum functor + */ +struct Sum +{ + /// Boolean sum operator, returns a + b + template + __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const + { + return a + b; + } +}; + + +/** + * \brief Default max functor + */ +struct Max +{ + /// Boolean max operator, returns (a > b) ? a : b + template + __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const + { + return CUB_MAX(a, b); + } +}; + + +/** + * \brief Arg max functor (keeps the value and offset of the first occurrence of the larger item) + */ +struct ArgMax +{ + /// Boolean max operator, preferring the item having the smaller offset in case of ties + template + __host__ __device__ __forceinline__ KeyValuePair operator()( + const KeyValuePair &a, + const KeyValuePair &b) const + { +// Mooch BUG (device reduce argmax gk110 3.2 million random fp32) +// return ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a; + + if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) + return b; + return a; + } +}; + + +/** + * \brief Default min functor + */ +struct Min +{ + /// Boolean min operator, returns (a < b) ? 
a : b + template + __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const + { + return CUB_MIN(a, b); + } +}; + + +/** + * \brief Arg min functor (keeps the value and offset of the first occurrence of the smallest item) + */ +struct ArgMin +{ + /// Boolean min operator, preferring the item having the smaller offset in case of ties + template + __host__ __device__ __forceinline__ KeyValuePair operator()( + const KeyValuePair &a, + const KeyValuePair &b) const + { +// Mooch BUG (device reduce argmax gk110 3.2 million random fp32) +// return ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a; + + if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) + return b; + return a; + } +}; + + +/** + * \brief Default cast functor + */ +template +struct CastOp +{ + /// Cast operator, returns (B) a + template + __host__ __device__ __forceinline__ B operator()(const A &a) const + { + return (B) a; + } +}; + + +/** + * \brief Binary operator wrapper for switching non-commutative scan arguments + */ +template +class SwizzleScanOp +{ +private: + + /// Wrapped scan operator + ScanOp scan_op; + +public: + + /// Constructor + __host__ __device__ __forceinline__ + SwizzleScanOp(ScanOp scan_op) : scan_op(scan_op) {} + + /// Switch the scan arguments + template + __host__ __device__ __forceinline__ + T operator()(const T &a, const T &b) + { + T _a(a); + T _b(b); + + return scan_op(_b, _a); + } +}; + + +/** + * \brief Reduce-by-segment functor. + * + * Given two cub::KeyValuePair inputs \p a and \p b and a + * binary associative combining operator \p f(const T &x, const T &y), + * an instance of this functor returns a cub::KeyValuePair whose \p key + * field is a.key + b.key, and whose \p value field + * is either b.value if b.key is non-zero, or f(a.value, b.value) otherwise. + * + * ReduceBySegmentOp is an associative, non-commutative binary combining operator + * for input sequences of cub::KeyValuePair pairings. Such + * sequences are typically used to represent a segmented set of values to be reduced + * and a corresponding set of {0,1}-valued integer "head flags" demarcating the + * first value of each segment. + * + */ +template ///< Binary reduction operator to apply to values +struct ReduceBySegmentOp +{ + /// Wrapped reduction operator + ReductionOpT op; + + /// Constructor + __host__ __device__ __forceinline__ ReduceBySegmentOp() {} + + /// Constructor + __host__ __device__ __forceinline__ ReduceBySegmentOp(ReductionOpT op) : op(op) {} + + /// Scan operator + template ///< KeyValuePair pairing of T (value) and OffsetT (head flag) + __host__ __device__ __forceinline__ KeyValuePairT operator()( + const KeyValuePairT &first, ///< First partial reduction + const KeyValuePairT &second) ///< Second partial reduction + { + KeyValuePairT retval; + retval.key = first.key + second.key; + retval.value = (second.key) ? 
+ second.value : // The second partial reduction spans a segment reset, so it's value aggregate becomes the running aggregate + op(first.value, second.value); // The second partial reduction does not span a reset, so accumulate both into the running aggregate + return retval; + } +}; + + + +template ///< Binary reduction operator to apply to values +struct ReduceByKeyOp +{ + /// Wrapped reduction operator + ReductionOpT op; + + /// Constructor + __host__ __device__ __forceinline__ ReduceByKeyOp() {} + + /// Constructor + __host__ __device__ __forceinline__ ReduceByKeyOp(ReductionOpT op) : op(op) {} + + /// Scan operator + template + __host__ __device__ __forceinline__ KeyValuePairT operator()( + const KeyValuePairT &first, ///< First partial reduction + const KeyValuePairT &second) ///< Second partial reduction + { + KeyValuePairT retval = second; + + if (first.key == second.key) + retval.value = op(first.value, retval.value); + + return retval; + } +}; + + + + + + + +/** @} */ // end group UtilModule + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/thread/thread_reduce.cuh b/fastertransformer/cuda/cub/thread/thread_reduce.cuh new file mode 100644 index 000000000..4c13688f3 --- /dev/null +++ b/fastertransformer/cuda/cub/thread/thread_reduce.cuh @@ -0,0 +1,152 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * \file + * Thread utilities for sequential reduction over statically-sized array types + */ + +#pragma once + +#include "../thread/thread_operators.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/// Internal namespace (to prevent ADL mishaps between static functions when mixing different CUB installations) +namespace internal { + +/** + * Sequential reduction over statically-sized array types + */ +template < + int LENGTH, + typename T, + typename ReductionOp> +__device__ __forceinline__ T ThreadReduce( + T* input, ///< [in] Input array + ReductionOp reduction_op, ///< [in] Binary reduction operator + T prefix, ///< [in] Prefix to seed reduction with + Int2Type /*length*/) +{ + T retval = prefix; + + #pragma unroll + for (int i = 0; i < LENGTH; ++i) + retval = reduction_op(retval, input[i]); + + return retval; +} + + +/** + * \brief Perform a sequential reduction over \p LENGTH elements of the \p input array, seeded with the specified \p prefix. The aggregate is returned. + * + * \tparam LENGTH LengthT of input array + * \tparam T [inferred] The data type to be reduced. + * \tparam ScanOp [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b) + */ +template < + int LENGTH, + typename T, + typename ReductionOp> +__device__ __forceinline__ T ThreadReduce( + T* input, ///< [in] Input array + ReductionOp reduction_op, ///< [in] Binary reduction operator + T prefix) ///< [in] Prefix to seed reduction with +{ + return ThreadReduce(input, reduction_op, prefix, Int2Type()); +} + + +/** + * \brief Perform a sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned. + * + * \tparam LENGTH LengthT of input array + * \tparam T [inferred] The data type to be reduced. + * \tparam ScanOp [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b) + */ +template < + int LENGTH, + typename T, + typename ReductionOp> +__device__ __forceinline__ T ThreadReduce( + T* input, ///< [in] Input array + ReductionOp reduction_op) ///< [in] Binary reduction operator +{ + T prefix = input[0]; + return ThreadReduce(input + 1, reduction_op, prefix); +} + + +/** + * \brief Perform a sequential reduction over the statically-sized \p input array, seeded with the specified \p prefix. The aggregate is returned. + * + * \tparam LENGTH [inferred] LengthT of \p input array + * \tparam T [inferred] The data type to be reduced. + * \tparam ScanOp [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b) + */ +template < + int LENGTH, + typename T, + typename ReductionOp> +__device__ __forceinline__ T ThreadReduce( + T (&input)[LENGTH], ///< [in] Input array + ReductionOp reduction_op, ///< [in] Binary reduction operator + T prefix) ///< [in] Prefix to seed reduction with +{ + return ThreadReduce(input, reduction_op, prefix, Int2Type()); +} + + +/** + * \brief Serial reduction with the specified operator + * + * \tparam LENGTH [inferred] LengthT of \p input array + * \tparam T [inferred] The data type to be reduced. 
+ * \tparam ScanOp [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b) + */ +template < + int LENGTH, + typename T, + typename ReductionOp> +__device__ __forceinline__ T ThreadReduce( + T (&input)[LENGTH], ///< [in] Input array + ReductionOp reduction_op) ///< [in] Binary reduction operator +{ + return ThreadReduce((T*) input, reduction_op); +} + + +} // internal namespace +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/thread/thread_scan.cuh b/fastertransformer/cuda/cub/thread/thread_scan.cuh new file mode 100644 index 000000000..8d67549ae --- /dev/null +++ b/fastertransformer/cuda/cub/thread/thread_scan.cuh @@ -0,0 +1,268 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * \file + * Thread utilities for sequential prefix scan over statically-sized array types + */ + +#pragma once + +#include "../thread/thread_operators.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/// Internal namespace (to prevent ADL mishaps between static functions when mixing different CUB installations) +namespace internal { + + +/** + * \addtogroup UtilModule + * @{ + */ + +/** + * \name Sequential prefix scan over statically-sized array types + * @{ + */ + +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadScanExclusive( + T inclusive, + T exclusive, + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + Int2Type /*length*/) +{ + #pragma unroll + for (int i = 0; i < LENGTH; ++i) + { + inclusive = scan_op(exclusive, input[i]); + output[i] = exclusive; + exclusive = inclusive; + } + + return inclusive; +} + + + +/** + * \brief Perform a sequential exclusive prefix scan over \p LENGTH elements of the \p input array, seeded with the specified \p prefix. The aggregate is returned. + * + * \tparam LENGTH LengthT of \p input and \p output arrays + * \tparam T [inferred] The data type to be scanned. + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadScanExclusive( + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T prefix, ///< [in] Prefix to seed scan with + bool apply_prefix = true) ///< [in] Whether or not the calling thread should apply its prefix. If not, the first output element is undefined. (Handy for preventing thread-0 from applying a prefix.) +{ + T inclusive = input[0]; + if (apply_prefix) + { + inclusive = scan_op(prefix, inclusive); + } + output[0] = prefix; + T exclusive = inclusive; + + return ThreadScanExclusive(inclusive, exclusive, input + 1, output + 1, scan_op, Int2Type()); +} + + +/** + * \brief Perform a sequential exclusive prefix scan over the statically-sized \p input array, seeded with the specified \p prefix. The aggregate is returned. + * + * \tparam LENGTH [inferred] LengthT of \p input and \p output arrays + * \tparam T [inferred] The data type to be scanned. + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadScanExclusive( + T (&input)[LENGTH], ///< [in] Input array + T (&output)[LENGTH], ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T prefix, ///< [in] Prefix to seed scan with + bool apply_prefix = true) ///< [in] Whether or not the calling thread should apply its prefix. (Handy for preventing thread-0 from applying a prefix.) 
+{ + return ThreadScanExclusive((T*) input, (T*) output, scan_op, prefix, apply_prefix); +} + + + + + + + + + +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadScanInclusive( + T inclusive, + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + Int2Type /*length*/) +{ + #pragma unroll + for (int i = 0; i < LENGTH; ++i) + { + inclusive = scan_op(inclusive, input[i]); + output[i] = inclusive; + } + + return inclusive; +} + + +/** + * \brief Perform a sequential inclusive prefix scan over \p LENGTH elements of the \p input array. The aggregate is returned. + * + * \tparam LENGTH LengthT of \p input and \p output arrays + * \tparam T [inferred] The data type to be scanned. + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadScanInclusive( + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator +{ + T inclusive = input[0]; + output[0] = inclusive; + + // Continue scan + return ThreadScanInclusive(inclusive, input + 1, output + 1, scan_op, Int2Type()); +} + + +/** + * \brief Perform a sequential inclusive prefix scan over the statically-sized \p input array. The aggregate is returned. + * + * \tparam LENGTH [inferred] LengthT of \p input and \p output arrays + * \tparam T [inferred] The data type to be scanned. + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadScanInclusive( + T (&input)[LENGTH], ///< [in] Input array + T (&output)[LENGTH], ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator +{ + return ThreadScanInclusive((T*) input, (T*) output, scan_op); +} + + +/** + * \brief Perform a sequential inclusive prefix scan over \p LENGTH elements of the \p input array, seeded with the specified \p prefix. The aggregate is returned. + * + * \tparam LENGTH LengthT of \p input and \p output arrays + * \tparam T [inferred] The data type to be scanned. + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadScanInclusive( + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T prefix, ///< [in] Prefix to seed scan with + bool apply_prefix = true) ///< [in] Whether or not the calling thread should apply its prefix. (Handy for preventing thread-0 from applying a prefix.) +{ + T inclusive = input[0]; + if (apply_prefix) + { + inclusive = scan_op(prefix, inclusive); + } + output[0] = inclusive; + + // Continue scan + return ThreadScanInclusive(inclusive, input + 1, output + 1, scan_op, Int2Type()); +} + + +/** + * \brief Perform a sequential inclusive prefix scan over the statically-sized \p input array, seeded with the specified \p prefix. The aggregate is returned. + * + * \tparam LENGTH [inferred] LengthT of \p input and \p output arrays + * \tparam T [inferred] The data type to be scanned. 
+ * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadScanInclusive( + T (&input)[LENGTH], ///< [in] Input array + T (&output)[LENGTH], ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T prefix, ///< [in] Prefix to seed scan with + bool apply_prefix = true) ///< [in] Whether or not the calling thread should apply its prefix. (Handy for preventing thread-0 from applying a prefix.) +{ + return ThreadScanInclusive((T*) input, (T*) output, scan_op, prefix, apply_prefix); +} + + +//@} end member group + +/** @} */ // end group UtilModule + + +} // internal namespace +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/thread/thread_search.cuh b/fastertransformer/cuda/cub/thread/thread_search.cuh new file mode 100644 index 000000000..3099080a3 --- /dev/null +++ b/fastertransformer/cuda/cub/thread/thread_search.cuh @@ -0,0 +1,154 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * \file + * Thread utilities for sequential search + */ + +#pragma once + +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * Computes the begin offsets into A and B for the specific diagonal + */ +template < + typename AIteratorT, + typename BIteratorT, + typename OffsetT, + typename CoordinateT> +__host__ __device__ __forceinline__ void MergePathSearch( + OffsetT diagonal, + AIteratorT a, + BIteratorT b, + OffsetT a_len, + OffsetT b_len, + CoordinateT& path_coordinate) +{ + /// The value type of the input iterator + typedef typename std::iterator_traits::value_type T; + + OffsetT split_min = CUB_MAX(diagonal - b_len, 0); + OffsetT split_max = CUB_MIN(diagonal, a_len); + + while (split_min < split_max) + { + OffsetT split_pivot = (split_min + split_max) >> 1; + if (a[split_pivot] <= b[diagonal - split_pivot - 1]) + { + // Move candidate split range up A, down B + split_min = split_pivot + 1; + } + else + { + // Move candidate split range up B, down A + split_max = split_pivot; + } + } + + path_coordinate.x = CUB_MIN(split_min, a_len); + path_coordinate.y = diagonal - split_min; +} + + + +/** + * \brief Returns the offset of the first value within \p input which does not compare less than \p val + */ +template < + typename InputIteratorT, + typename OffsetT, + typename T> +__device__ __forceinline__ OffsetT LowerBound( + InputIteratorT input, ///< [in] Input sequence + OffsetT num_items, ///< [in] Input sequence length + T val) ///< [in] Search key +{ + OffsetT retval = 0; + while (num_items > 0) + { + OffsetT half = num_items >> 1; + if (input[retval + half] < val) + { + retval = retval + (half + 1); + num_items = num_items - (half + 1); + } + else + { + num_items = half; + } + } + + return retval; +} + + +/** + * \brief Returns the offset of the first value within \p input which compares greater than \p val + */ +template < + typename InputIteratorT, + typename OffsetT, + typename T> +__device__ __forceinline__ OffsetT UpperBound( + InputIteratorT input, ///< [in] Input sequence + OffsetT num_items, ///< [in] Input sequence length + T val) ///< [in] Search key +{ + OffsetT retval = 0; + while (num_items > 0) + { + OffsetT half = num_items >> 1; + if (val < input[retval + half]) + { + num_items = half; + } + else + { + retval = retval + (half + 1); + num_items = num_items - (half + 1); + } + } + + return retval; +} + + + + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/thread/thread_store.cuh b/fastertransformer/cuda/cub/thread/thread_store.cuh new file mode 100644 index 000000000..ec20b36f4 --- /dev/null +++ b/fastertransformer/cuda/cub/thread/thread_store.cuh @@ -0,0 +1,422 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Thread utilities for writing memory using PTX cache modifiers. + */ + +#pragma once + +#include + +#include "../util_ptx.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \addtogroup UtilIo + * @{ + */ + + +//----------------------------------------------------------------------------- +// Tags and constants +//----------------------------------------------------------------------------- + +/** + * \brief Enumeration of cache modifiers for memory store operations. + */ +enum CacheStoreModifier +{ + STORE_DEFAULT, ///< Default (no modifier) + STORE_WB, ///< Cache write-back all coherent levels + STORE_CG, ///< Cache at global level + STORE_CS, ///< Cache streaming (likely to be accessed once) + STORE_WT, ///< Cache write-through (to system memory) + STORE_VOLATILE, ///< Volatile shared (any memory space) +}; + + +/** + * \name Thread I/O (cache modified) + * @{ + */ + +/** + * \brief Thread utility for writing memory using cub::CacheStoreModifier cache modifiers. Can be used to store any data type. 
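As with the load header, a compact sketch of the calling convention; the kernel below is our own, and only cub::ThreadStore and the STORE_* enumerators defined above come from this header.

    #include "fastertransformer/cuda/cub/thread/thread_store.cuh"

    // One element per thread: write with the cache-streaming (.cs) modifier,
    // which hints that the stored data is unlikely to be accessed again soon.
    __global__ void FillWithStoreCS(int *d_out, int value, int n)
    {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx < n)
            cub::ThreadStore<cub::STORE_CS>(d_out + idx, value);
    }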
+ * + * \par Example + * \code + * #include // or equivalently + * + * // 32-bit store using cache-global modifier: + * int *d_out; + * int val; + * cub::ThreadStore(d_out + threadIdx.x, val); + * + * // 16-bit store using default modifier + * short *d_out; + * short val; + * cub::ThreadStore(d_out + threadIdx.x, val); + * + * // 256-bit store using write-through modifier + * double4 *d_out; + * double4 val; + * cub::ThreadStore(d_out + threadIdx.x, val); + * + * // 96-bit store using cache-streaming cache modifier + * struct TestFoo { bool a; short b; }; + * TestFoo *d_struct; + * TestFoo val; + * cub::ThreadStore(d_out + threadIdx.x, val); + * \endcode + * + * \tparam MODIFIER [inferred] CacheStoreModifier enumeration + * \tparam InputIteratorT [inferred] Output iterator type \iterator + * \tparam T [inferred] Data type of output value + */ +template < + CacheStoreModifier MODIFIER, + typename OutputIteratorT, + typename T> +__device__ __forceinline__ void ThreadStore(OutputIteratorT itr, T val); + + +//@} end member group + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + +/// Helper structure for templated store iteration (inductive case) +template +struct IterateThreadStore +{ + template + static __device__ __forceinline__ void Store(T *ptr, T *vals) + { + ThreadStore(ptr + COUNT, vals[COUNT]); + IterateThreadStore::template Store(ptr, vals); + } + + template + static __device__ __forceinline__ void Dereference(OutputIteratorT ptr, T *vals) + { + ptr[COUNT] = vals[COUNT]; + IterateThreadStore::Dereference(ptr, vals); + } + +}; + +/// Helper structure for templated store iteration (termination case) +template +struct IterateThreadStore +{ + template + static __device__ __forceinline__ void Store(T * /*ptr*/, T * /*vals*/) {} + + template + static __device__ __forceinline__ void Dereference(OutputIteratorT /*ptr*/, T * /*vals*/) {} +}; + + +/** + * Define a uint4 (16B) ThreadStore specialization for the given Cache load modifier + */ +#define _CUB_STORE_16(cub_modifier, ptx_modifier) \ + template<> \ + __device__ __forceinline__ void ThreadStore(uint4* ptr, uint4 val) \ + { \ + asm volatile ("st."#ptx_modifier".v4.u32 [%0], {%1, %2, %3, %4};" : : \ + _CUB_ASM_PTR_(ptr), \ + "r"(val.x), \ + "r"(val.y), \ + "r"(val.z), \ + "r"(val.w)); \ + } \ + template<> \ + __device__ __forceinline__ void ThreadStore(ulonglong2* ptr, ulonglong2 val) \ + { \ + asm volatile ("st."#ptx_modifier".v2.u64 [%0], {%1, %2};" : : \ + _CUB_ASM_PTR_(ptr), \ + "l"(val.x), \ + "l"(val.y)); \ + } + + +/** + * Define a uint2 (8B) ThreadStore specialization for the given Cache load modifier + */ +#define _CUB_STORE_8(cub_modifier, ptx_modifier) \ + template<> \ + __device__ __forceinline__ void ThreadStore(ushort4* ptr, ushort4 val) \ + { \ + asm volatile ("st."#ptx_modifier".v4.u16 [%0], {%1, %2, %3, %4};" : : \ + _CUB_ASM_PTR_(ptr), \ + "h"(val.x), \ + "h"(val.y), \ + "h"(val.z), \ + "h"(val.w)); \ + } \ + template<> \ + __device__ __forceinline__ void ThreadStore(uint2* ptr, uint2 val) \ + { \ + asm volatile ("st."#ptx_modifier".v2.u32 [%0], {%1, %2};" : : \ + _CUB_ASM_PTR_(ptr), \ + "r"(val.x), \ + "r"(val.y)); \ + } \ + template<> \ + __device__ __forceinline__ void ThreadStore(unsigned long long* ptr, unsigned long long val) \ + { \ + asm volatile ("st."#ptx_modifier".u64 [%0], %1;" : : \ + _CUB_ASM_PTR_(ptr), \ + "l"(val)); \ + } + +/** + * Define a unsigned int (4B) ThreadStore specialization for the given Cache load modifier + */ +#define _CUB_STORE_4(cub_modifier, ptx_modifier) \ + template<> \ + 
__device__ __forceinline__ void ThreadStore(unsigned int* ptr, unsigned int val) \ + { \ + asm volatile ("st."#ptx_modifier".u32 [%0], %1;" : : \ + _CUB_ASM_PTR_(ptr), \ + "r"(val)); \ + } + + +/** + * Define a unsigned short (2B) ThreadStore specialization for the given Cache load modifier + */ +#define _CUB_STORE_2(cub_modifier, ptx_modifier) \ + template<> \ + __device__ __forceinline__ void ThreadStore(unsigned short* ptr, unsigned short val) \ + { \ + asm volatile ("st."#ptx_modifier".u16 [%0], %1;" : : \ + _CUB_ASM_PTR_(ptr), \ + "h"(val)); \ + } + + +/** + * Define a unsigned char (1B) ThreadStore specialization for the given Cache load modifier + */ +#define _CUB_STORE_1(cub_modifier, ptx_modifier) \ + template<> \ + __device__ __forceinline__ void ThreadStore(unsigned char* ptr, unsigned char val) \ + { \ + asm volatile ( \ + "{" \ + " .reg .u8 datum;" \ + " cvt.u8.u16 datum, %1;" \ + " st."#ptx_modifier".u8 [%0], datum;" \ + "}" : : \ + _CUB_ASM_PTR_(ptr), \ + "h"((unsigned short) val)); \ + } + +/** + * Define powers-of-two ThreadStore specializations for the given Cache load modifier + */ +#define _CUB_STORE_ALL(cub_modifier, ptx_modifier) \ + _CUB_STORE_16(cub_modifier, ptx_modifier) \ + _CUB_STORE_8(cub_modifier, ptx_modifier) \ + _CUB_STORE_4(cub_modifier, ptx_modifier) \ + _CUB_STORE_2(cub_modifier, ptx_modifier) \ + _CUB_STORE_1(cub_modifier, ptx_modifier) \ + + +/** + * Define ThreadStore specializations for the various Cache load modifiers + */ +#if CUB_PTX_ARCH >= 200 + _CUB_STORE_ALL(STORE_WB, wb) + _CUB_STORE_ALL(STORE_CG, cg) + _CUB_STORE_ALL(STORE_CS, cs) + _CUB_STORE_ALL(STORE_WT, wt) +#else + _CUB_STORE_ALL(STORE_WB, global) + _CUB_STORE_ALL(STORE_CG, global) + _CUB_STORE_ALL(STORE_CS, global) + _CUB_STORE_ALL(STORE_WT, volatile.global) +#endif + + +// Macro cleanup +#undef _CUB_STORE_ALL +#undef _CUB_STORE_1 +#undef _CUB_STORE_2 +#undef _CUB_STORE_4 +#undef _CUB_STORE_8 +#undef _CUB_STORE_16 + + +/** + * ThreadStore definition for STORE_DEFAULT modifier on iterator types + */ +template +__device__ __forceinline__ void ThreadStore( + OutputIteratorT itr, + T val, + Int2Type /*modifier*/, + Int2Type /*is_pointer*/) +{ + *itr = val; +} + + +/** + * ThreadStore definition for STORE_DEFAULT modifier on pointer types + */ +template +__device__ __forceinline__ void ThreadStore( + T *ptr, + T val, + Int2Type /*modifier*/, + Int2Type /*is_pointer*/) +{ + *ptr = val; +} + + +/** + * ThreadStore definition for STORE_VOLATILE modifier on primitive pointer types + */ +template +__device__ __forceinline__ void ThreadStoreVolatilePtr( + T *ptr, + T val, + Int2Type /*is_primitive*/) +{ + *reinterpret_cast(ptr) = val; +} + + +/** + * ThreadStore definition for STORE_VOLATILE modifier on non-primitive pointer types + */ +template +__device__ __forceinline__ void ThreadStoreVolatilePtr( + T *ptr, + T val, + Int2Type /*is_primitive*/) +{ + // Create a temporary using shuffle-words, then store using volatile-words + typedef typename UnitWord::VolatileWord VolatileWord; + typedef typename UnitWord::ShuffleWord ShuffleWord; + + const int VOLATILE_MULTIPLE = sizeof(T) / sizeof(VolatileWord); + const int SHUFFLE_MULTIPLE = sizeof(T) / sizeof(ShuffleWord); + + VolatileWord words[VOLATILE_MULTIPLE]; + + #pragma unroll + for (int i = 0; i < SHUFFLE_MULTIPLE; ++i) + reinterpret_cast(words)[i] = reinterpret_cast(&val)[i]; + + IterateThreadStore<0, VOLATILE_MULTIPLE>::template Dereference( + reinterpret_cast(ptr), + words); +} + + +/** + * ThreadStore definition for STORE_VOLATILE modifier on 
pointer types + */ +template +__device__ __forceinline__ void ThreadStore( + T *ptr, + T val, + Int2Type /*modifier*/, + Int2Type /*is_pointer*/) +{ + ThreadStoreVolatilePtr(ptr, val, Int2Type::PRIMITIVE>()); +} + + +/** + * ThreadStore definition for generic modifiers on pointer types + */ +template +__device__ __forceinline__ void ThreadStore( + T *ptr, + T val, + Int2Type /*modifier*/, + Int2Type /*is_pointer*/) +{ + // Create a temporary using shuffle-words, then store using device-words + typedef typename UnitWord::DeviceWord DeviceWord; + typedef typename UnitWord::ShuffleWord ShuffleWord; + + const int DEVICE_MULTIPLE = sizeof(T) / sizeof(DeviceWord); + const int SHUFFLE_MULTIPLE = sizeof(T) / sizeof(ShuffleWord); + + DeviceWord words[DEVICE_MULTIPLE]; + + #pragma unroll + for (int i = 0; i < SHUFFLE_MULTIPLE; ++i) + reinterpret_cast(words)[i] = reinterpret_cast(&val)[i]; + + IterateThreadStore<0, DEVICE_MULTIPLE>::template Store( + reinterpret_cast(ptr), + words); +} + + +/** + * ThreadStore definition for generic modifiers + */ +template +__device__ __forceinline__ void ThreadStore(OutputIteratorT itr, T val) +{ + ThreadStore( + itr, + val, + Int2Type(), + Int2Type::VALUE>()); +} + + + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + +/** @} */ // end group UtilIo + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/util_allocator.cuh b/fastertransformer/cuda/cub/util_allocator.cuh new file mode 100644 index 000000000..0e6dd0486 --- /dev/null +++ b/fastertransformer/cuda/cub/util_allocator.cuh @@ -0,0 +1,708 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/****************************************************************************** + * Simple caching allocator for device memory allocations. 
The allocator is + * thread-safe and capable of managing device allocations on multiple devices. + ******************************************************************************/ + +#pragma once + +#include "util_namespace.cuh" +#include "util_debug.cuh" + +#include +#include + +#include "host/mutex.cuh" +#include + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup UtilMgmt + * @{ + */ + + +/****************************************************************************** + * CachingDeviceAllocator (host use) + ******************************************************************************/ + +/** + * \brief A simple caching allocator for device memory allocations. + * + * \par Overview + * The allocator is thread-safe and stream-safe and is capable of managing cached + * device allocations on multiple devices. It behaves as follows: + * + * \par + * - Allocations from the allocator are associated with an \p active_stream. Once freed, + * the allocation becomes available immediately for reuse within the \p active_stream + * with which it was associated with during allocation, and it becomes available for + * reuse within other streams when all prior work submitted to \p active_stream has completed. + * - Allocations are categorized and cached by bin size. A new allocation request of + * a given size will only consider cached allocations within the corresponding bin. + * - Bin limits progress geometrically in accordance with the growth factor + * \p bin_growth provided during construction. Unused device allocations within + * a larger bin cache are not reused for allocation requests that categorize to + * smaller bin sizes. + * - Allocation requests below (\p bin_growth ^ \p min_bin) are rounded up to + * (\p bin_growth ^ \p min_bin). + * - Allocations above (\p bin_growth ^ \p max_bin) are not rounded up to the nearest + * bin and are simply freed when they are deallocated instead of being returned + * to a bin-cache. + * - %If the total storage of cached allocations on a given device will exceed + * \p max_cached_bytes, allocations for that device are simply freed when they are + * deallocated instead of being returned to their bin-cache. 
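+ *
+ * \par Snippet
+ * A minimal usage sketch; the allocation size and stream below are
+ * illustrative only:
+ * \code
+ * cub::CachingDeviceAllocator allocator;       // default configuration
+ *
+ * cudaStream_t stream = 0;
+ * void *d_scratch = NULL;
+ *
+ * // Request is rounded up to the nearest bin size and associated with stream
+ * allocator.DeviceAllocate(&d_scratch, 1024 * 1024, stream);
+ *
+ * // ... launch kernels on stream that use d_scratch ...
+ *
+ * // The block is returned to the allocator's bin-cache for reuse
+ * allocator.DeviceFree(d_scratch);
+ * \endcode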
+ * + * \par + * For example, the default-constructed CachingDeviceAllocator is configured with: + * - \p bin_growth = 8 + * - \p min_bin = 3 + * - \p max_bin = 7 + * - \p max_cached_bytes = 6MB - 1B + * + * \par + * which delineates five bin-sizes: 512B, 4KB, 32KB, 256KB, and 2MB + * and sets a maximum of 6,291,455 cached bytes per device + * + */ +struct CachingDeviceAllocator +{ + + //--------------------------------------------------------------------- + // Constants + //--------------------------------------------------------------------- + + /// Out-of-bounds bin + static const unsigned int INVALID_BIN = (unsigned int) -1; + + /// Invalid size + static const size_t INVALID_SIZE = (size_t) -1; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + /// Invalid device ordinal + static const int INVALID_DEVICE_ORDINAL = -1; + + //--------------------------------------------------------------------- + // Type definitions and helper types + //--------------------------------------------------------------------- + + /** + * Descriptor for device memory allocations + */ + struct BlockDescriptor + { + void* d_ptr; // Device pointer + size_t bytes; // Size of allocation in bytes + unsigned int bin; // Bin enumeration + int device; // device ordinal + cudaStream_t associated_stream; // Associated associated_stream + cudaEvent_t ready_event; // Signal when associated stream has run to the point at which this block was freed + + // Constructor (suitable for searching maps for a specific block, given its pointer and device) + BlockDescriptor(void *d_ptr, int device) : + d_ptr(d_ptr), + bytes(0), + bin(INVALID_BIN), + device(device), + associated_stream(0), + ready_event(0) + {} + + // Constructor (suitable for searching maps for a range of suitable blocks, given a device) + BlockDescriptor(int device) : + d_ptr(NULL), + bytes(0), + bin(INVALID_BIN), + device(device), + associated_stream(0), + ready_event(0) + {} + + // Comparison functor for comparing device pointers + static bool PtrCompare(const BlockDescriptor &a, const BlockDescriptor &b) + { + if (a.device == b.device) + return (a.d_ptr < b.d_ptr); + else + return (a.device < b.device); + } + + // Comparison functor for comparing allocation sizes + static bool SizeCompare(const BlockDescriptor &a, const BlockDescriptor &b) + { + if (a.device == b.device) + return (a.bytes < b.bytes); + else + return (a.device < b.device); + } + }; + + /// BlockDescriptor comparator function interface + typedef bool (*Compare)(const BlockDescriptor &, const BlockDescriptor &); + + class TotalBytes { + public: + size_t free; + size_t live; + TotalBytes() { free = live = 0; } + }; + + /// Set type for cached blocks (ordered by size) + typedef std::multiset CachedBlocks; + + /// Set type for live blocks (ordered by ptr) + typedef std::multiset BusyBlocks; + + /// Map type of device ordinals to the number of cached bytes cached by each device + typedef std::map GpuCachedBytes; + + + //--------------------------------------------------------------------- + // Utility functions + //--------------------------------------------------------------------- + + /** + * Integer pow function for unsigned base and exponent + */ + static unsigned int IntPow( + unsigned int base, + unsigned int exp) + { + unsigned int retval = 1; + while (exp > 0) + { + if (exp & 1) { + retval = retval * base; // multiply the result by the current base + } + base = base * base; // square the base + exp = exp >> 1; // divide the exponent in half + } + return retval; + } + + + /** + * Round 
up to the nearest power-of + */ + void NearestPowerOf( + unsigned int &power, + size_t &rounded_bytes, + unsigned int base, + size_t value) + { + power = 0; + rounded_bytes = 1; + + if (value * base < value) + { + // Overflow + power = sizeof(size_t) * 8; + rounded_bytes = size_t(0) - 1; + return; + } + + while (rounded_bytes < value) + { + rounded_bytes *= base; + power++; + } + } + + + //--------------------------------------------------------------------- + // Fields + //--------------------------------------------------------------------- + + cub::Mutex mutex; /// Mutex for thread-safety + + unsigned int bin_growth; /// Geometric growth factor for bin-sizes + unsigned int min_bin; /// Minimum bin enumeration + unsigned int max_bin; /// Maximum bin enumeration + + size_t min_bin_bytes; /// Minimum bin size + size_t max_bin_bytes; /// Maximum bin size + size_t max_cached_bytes; /// Maximum aggregate cached bytes per device + + const bool skip_cleanup; /// Whether or not to skip a call to FreeAllCached() when destructor is called. (The CUDA runtime may have already shut down for statically declared allocators) + bool debug; /// Whether or not to print (de)allocation events to stdout + + GpuCachedBytes cached_bytes; /// Map of device ordinal to aggregate cached bytes on that device + CachedBlocks cached_blocks; /// Set of cached device allocations available for reuse + BusyBlocks live_blocks; /// Set of live device allocations currently in use + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + //--------------------------------------------------------------------- + // Methods + //--------------------------------------------------------------------- + + /** + * \brief Constructor. + */ + CachingDeviceAllocator( + unsigned int bin_growth, ///< Geometric growth factor for bin-sizes + unsigned int min_bin = 1, ///< Minimum bin (default is bin_growth ^ 1) + unsigned int max_bin = INVALID_BIN, ///< Maximum bin (default is no max bin) + size_t max_cached_bytes = INVALID_SIZE, ///< Maximum aggregate cached bytes per device (default is no limit) + bool skip_cleanup = false, ///< Whether or not to skip a call to \p FreeAllCached() when the destructor is called (default is to deallocate) + bool debug = false) ///< Whether or not to print (de)allocation events to stdout (default is no stderr output) + : + bin_growth(bin_growth), + min_bin(min_bin), + max_bin(max_bin), + min_bin_bytes(IntPow(bin_growth, min_bin)), + max_bin_bytes(IntPow(bin_growth, max_bin)), + max_cached_bytes(max_cached_bytes), + skip_cleanup(skip_cleanup), + debug(debug), + cached_blocks(BlockDescriptor::SizeCompare), + live_blocks(BlockDescriptor::PtrCompare) + {} + + + /** + * \brief Default constructor. + * + * Configured with: + * \par + * - \p bin_growth = 8 + * - \p min_bin = 3 + * - \p max_bin = 7 + * - \p max_cached_bytes = (\p bin_growth ^ \p max_bin) * 3) - 1 = 6,291,455 bytes + * + * which delineates five bin-sizes: 512B, 4KB, 32KB, 256KB, and 2MB and + * sets a maximum of 6,291,455 cached bytes per device + */ + CachingDeviceAllocator( + bool skip_cleanup = false, + bool debug = false) + : + bin_growth(8), + min_bin(3), + max_bin(7), + min_bin_bytes(IntPow(bin_growth, min_bin)), + max_bin_bytes(IntPow(bin_growth, max_bin)), + max_cached_bytes((max_bin_bytes * 3) - 1), + skip_cleanup(skip_cleanup), + debug(debug), + cached_blocks(BlockDescriptor::SizeCompare), + live_blocks(BlockDescriptor::PtrCompare) + {} + + + /** + * \brief Sets the limit on the number bytes this allocator is allowed to cache per device. 
+ * + * Changing the ceiling of cached bytes does not cause any allocations (in-use or + * cached-in-reserve) to be freed. See \p FreeAllCached(). + */ + cudaError_t SetMaxCachedBytes( + size_t max_cached_bytes) + { + // Lock + mutex.Lock(); + + if (debug) _CubLog("Changing max_cached_bytes (%lld -> %lld)\n", (long long) this->max_cached_bytes, (long long) max_cached_bytes); + + this->max_cached_bytes = max_cached_bytes; + + // Unlock + mutex.Unlock(); + + return cudaSuccess; + } + + + /** + * \brief Provides a suitable allocation of device memory for the given size on the specified device. + * + * Once freed, the allocation becomes available immediately for reuse within the \p active_stream + * with which it was associated with during allocation, and it becomes available for reuse within other + * streams when all prior work submitted to \p active_stream has completed. + */ + cudaError_t DeviceAllocate( + int device, ///< [in] Device on which to place the allocation + void **d_ptr, ///< [out] Reference to pointer to the allocation + size_t bytes, ///< [in] Minimum number of bytes for the allocation + cudaStream_t active_stream = 0) ///< [in] The stream to be associated with this allocation + { + *d_ptr = NULL; + int entrypoint_device = INVALID_DEVICE_ORDINAL; + cudaError_t error = cudaSuccess; + + if (device == INVALID_DEVICE_ORDINAL) + { + if (CubDebug(error = cudaGetDevice(&entrypoint_device))) return error; + device = entrypoint_device; + } + + // Create a block descriptor for the requested allocation + bool found = false; + BlockDescriptor search_key(device); + search_key.associated_stream = active_stream; + NearestPowerOf(search_key.bin, search_key.bytes, bin_growth, bytes); + + if (search_key.bin > max_bin) + { + // Bin is greater than our maximum bin: allocate the request + // exactly and give out-of-bounds bin. It will not be cached + // for reuse when returned. + search_key.bin = INVALID_BIN; + search_key.bytes = bytes; + } + else + { + // Search for a suitable cached allocation: lock + mutex.Lock(); + + if (search_key.bin < min_bin) + { + // Bin is less than minimum bin: round up + search_key.bin = min_bin; + search_key.bytes = min_bin_bytes; + } + + // Iterate through the range of cached blocks on the same device in the same bin + CachedBlocks::iterator block_itr = cached_blocks.lower_bound(search_key); + while ((block_itr != cached_blocks.end()) + && (block_itr->device == device) + && (block_itr->bin == search_key.bin)) + { + // To prevent races with reusing blocks returned by the host but still + // in use by the device, only consider cached blocks that are + // either (from the active stream) or (from an idle stream) + if ((active_stream == block_itr->associated_stream) || + (cudaEventQuery(block_itr->ready_event) != cudaErrorNotReady)) + { + // Reuse existing cache block. Insert into live blocks. 
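+                    // Bookkeeping below: the block moves from the cached set to the
+                    // live set, and the device's free/live byte counters are updated.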
+ found = true; + search_key = *block_itr; + search_key.associated_stream = active_stream; + live_blocks.insert(search_key); + + // Remove from free blocks + cached_bytes[device].free -= search_key.bytes; + cached_bytes[device].live += search_key.bytes; + + if (debug) _CubLog("\tDevice %d reused cached block at %p (%lld bytes) for stream %lld (previously associated with stream %lld).\n", + device, search_key.d_ptr, (long long) search_key.bytes, (long long) search_key.associated_stream, (long long) block_itr->associated_stream); + + cached_blocks.erase(block_itr); + + break; + } + block_itr++; + } + + // Done searching: unlock + mutex.Unlock(); + } + + // Allocate the block if necessary + if (!found) + { + // Set runtime's current device to specified device (entrypoint may not be set) + if (device != entrypoint_device) + { + if (CubDebug(error = cudaGetDevice(&entrypoint_device))) return error; + if (CubDebug(error = cudaSetDevice(device))) return error; + } + + // Attempt to allocate + if (CubDebug(error = cudaMalloc(&search_key.d_ptr, search_key.bytes)) == cudaErrorMemoryAllocation) + { + // The allocation attempt failed: free all cached blocks on device and retry + if (debug) _CubLog("\tDevice %d failed to allocate %lld bytes for stream %lld, retrying after freeing cached allocations", + device, (long long) search_key.bytes, (long long) search_key.associated_stream); + + error = cudaSuccess; // Reset the error we will return + cudaGetLastError(); // Reset CUDART's error + + // Lock + mutex.Lock(); + + // Iterate the range of free blocks on the same device + BlockDescriptor free_key(device); + CachedBlocks::iterator block_itr = cached_blocks.lower_bound(free_key); + + while ((block_itr != cached_blocks.end()) && (block_itr->device == device)) + { + // No need to worry about synchronization with the device: cudaFree is + // blocking and will synchronize across all kernels executing + // on the current device + + // Free device memory and destroy stream event. 
+ if (CubDebug(error = cudaFree(block_itr->d_ptr))) break; + if (CubDebug(error = cudaEventDestroy(block_itr->ready_event))) break; + + // Reduce balance and erase entry + cached_bytes[device].free -= block_itr->bytes; + + if (debug) _CubLog("\tDevice %d freed %lld bytes.\n\t\t %lld available blocks cached (%lld bytes), %lld live blocks (%lld bytes) outstanding.\n", + device, (long long) block_itr->bytes, (long long) cached_blocks.size(), (long long) cached_bytes[device].free, (long long) live_blocks.size(), (long long) cached_bytes[device].live); + + cached_blocks.erase(block_itr); + + block_itr++; + } + + // Unlock + mutex.Unlock(); + + // Return under error + if (error) return error; + + // Try to allocate again + if (CubDebug(error = cudaMalloc(&search_key.d_ptr, search_key.bytes))) return error; + } + + // Create ready event + if (CubDebug(error = cudaEventCreateWithFlags(&search_key.ready_event, cudaEventDisableTiming))) + return error; + + // Insert into live blocks + mutex.Lock(); + live_blocks.insert(search_key); + cached_bytes[device].live += search_key.bytes; + mutex.Unlock(); + + if (debug) _CubLog("\tDevice %d allocated new device block at %p (%lld bytes associated with stream %lld).\n", + device, search_key.d_ptr, (long long) search_key.bytes, (long long) search_key.associated_stream); + + // Attempt to revert back to previous device if necessary + if ((entrypoint_device != INVALID_DEVICE_ORDINAL) && (entrypoint_device != device)) + { + if (CubDebug(error = cudaSetDevice(entrypoint_device))) return error; + } + } + + // Copy device pointer to output parameter + *d_ptr = search_key.d_ptr; + + if (debug) _CubLog("\t\t%lld available blocks cached (%lld bytes), %lld live blocks outstanding(%lld bytes).\n", + (long long) cached_blocks.size(), (long long) cached_bytes[device].free, (long long) live_blocks.size(), (long long) cached_bytes[device].live); + + return error; + } + + + /** + * \brief Provides a suitable allocation of device memory for the given size on the current device. + * + * Once freed, the allocation becomes available immediately for reuse within the \p active_stream + * with which it was associated with during allocation, and it becomes available for reuse within other + * streams when all prior work submitted to \p active_stream has completed. + */ + cudaError_t DeviceAllocate( + void **d_ptr, ///< [out] Reference to pointer to the allocation + size_t bytes, ///< [in] Minimum number of bytes for the allocation + cudaStream_t active_stream = 0) ///< [in] The stream to be associated with this allocation + { + return DeviceAllocate(INVALID_DEVICE_ORDINAL, d_ptr, bytes, active_stream); + } + + + /** + * \brief Frees a live allocation of device memory on the specified device, returning it to the allocator. + * + * Once freed, the allocation becomes available immediately for reuse within the \p active_stream + * with which it was associated with during allocation, and it becomes available for reuse within other + * streams when all prior work submitted to \p active_stream has completed. 
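+ *
+ * \par
+ * If the block's bin is valid and re-caching it would not exceed the device's
+ * \p max_cached_bytes limit, the block is returned to the bin-cache; otherwise
+ * it is freed immediately.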
+ */ + cudaError_t DeviceFree( + int device, + void* d_ptr) + { + int entrypoint_device = INVALID_DEVICE_ORDINAL; + cudaError_t error = cudaSuccess; + + if (device == INVALID_DEVICE_ORDINAL) + { + if (CubDebug(error = cudaGetDevice(&entrypoint_device))) + return error; + device = entrypoint_device; + } + + // Lock + mutex.Lock(); + + // Find corresponding block descriptor + bool recached = false; + BlockDescriptor search_key(d_ptr, device); + BusyBlocks::iterator block_itr = live_blocks.find(search_key); + if (block_itr != live_blocks.end()) + { + // Remove from live blocks + search_key = *block_itr; + live_blocks.erase(block_itr); + cached_bytes[device].live -= search_key.bytes; + + // Keep the returned allocation if bin is valid and we won't exceed the max cached threshold + if ((search_key.bin != INVALID_BIN) && (cached_bytes[device].free + search_key.bytes <= max_cached_bytes)) + { + // Insert returned allocation into free blocks + recached = true; + cached_blocks.insert(search_key); + cached_bytes[device].free += search_key.bytes; + + if (debug) _CubLog("\tDevice %d returned %lld bytes from associated stream %lld.\n\t\t %lld available blocks cached (%lld bytes), %lld live blocks outstanding. (%lld bytes)\n", + device, (long long) search_key.bytes, (long long) search_key.associated_stream, (long long) cached_blocks.size(), + (long long) cached_bytes[device].free, (long long) live_blocks.size(), (long long) cached_bytes[device].live); + } + } + + // Unlock + mutex.Unlock(); + + // First set to specified device (entrypoint may not be set) + if (device != entrypoint_device) + { + if (CubDebug(error = cudaGetDevice(&entrypoint_device))) return error; + if (CubDebug(error = cudaSetDevice(device))) return error; + } + + if (recached) + { + // Insert the ready event in the associated stream (must have current device set properly) + if (CubDebug(error = cudaEventRecord(search_key.ready_event, search_key.associated_stream))) return error; + } + else + { + // Free the allocation from the runtime and cleanup the event. + if (CubDebug(error = cudaFree(d_ptr))) return error; + if (CubDebug(error = cudaEventDestroy(search_key.ready_event))) return error; + + if (debug) _CubLog("\tDevice %d freed %lld bytes from associated stream %lld.\n\t\t %lld available blocks cached (%lld bytes), %lld live blocks (%lld bytes) outstanding.\n", + device, (long long) search_key.bytes, (long long) search_key.associated_stream, (long long) cached_blocks.size(), (long long) cached_bytes[device].free, (long long) live_blocks.size(), (long long) cached_bytes[device].live); + } + + // Reset device + if ((entrypoint_device != INVALID_DEVICE_ORDINAL) && (entrypoint_device != device)) + { + if (CubDebug(error = cudaSetDevice(entrypoint_device))) return error; + } + + return error; + } + + + /** + * \brief Frees a live allocation of device memory on the current device, returning it to the allocator. + * + * Once freed, the allocation becomes available immediately for reuse within the \p active_stream + * with which it was associated with during allocation, and it becomes available for reuse within other + * streams when all prior work submitted to \p active_stream has completed. 
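+ *
+ * \par
+ * Equivalent to calling DeviceFree() with the device ordinal of the current
+ * CUDA context.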
+ */ + cudaError_t DeviceFree( + void* d_ptr) + { + return DeviceFree(INVALID_DEVICE_ORDINAL, d_ptr); + } + + + /** + * \brief Frees all cached device allocations on all devices + */ + cudaError_t FreeAllCached() + { + cudaError_t error = cudaSuccess; + int entrypoint_device = INVALID_DEVICE_ORDINAL; + int current_device = INVALID_DEVICE_ORDINAL; + + mutex.Lock(); + + while (!cached_blocks.empty()) + { + // Get first block + CachedBlocks::iterator begin = cached_blocks.begin(); + + // Get entry-point device ordinal if necessary + if (entrypoint_device == INVALID_DEVICE_ORDINAL) + { + if (CubDebug(error = cudaGetDevice(&entrypoint_device))) break; + } + + // Set current device ordinal if necessary + if (begin->device != current_device) + { + if (CubDebug(error = cudaSetDevice(begin->device))) break; + current_device = begin->device; + } + + // Free device memory + if (CubDebug(error = cudaFree(begin->d_ptr))) break; + if (CubDebug(error = cudaEventDestroy(begin->ready_event))) break; + + // Reduce balance and erase entry + cached_bytes[current_device].free -= begin->bytes; + + if (debug) _CubLog("\tDevice %d freed %lld bytes.\n\t\t %lld available blocks cached (%lld bytes), %lld live blocks (%lld bytes) outstanding.\n", + current_device, (long long) begin->bytes, (long long) cached_blocks.size(), (long long) cached_bytes[current_device].free, (long long) live_blocks.size(), (long long) cached_bytes[current_device].live); + + cached_blocks.erase(begin); + } + + mutex.Unlock(); + + // Attempt to revert back to entry-point device if necessary + if (entrypoint_device != INVALID_DEVICE_ORDINAL) + { + if (CubDebug(error = cudaSetDevice(entrypoint_device))) return error; + } + + return error; + } + + + /** + * \brief Destructor + */ + virtual ~CachingDeviceAllocator() + { + if (!skip_cleanup) + FreeAllCached(); + } + +}; + + + + +/** @} */ // end group UtilMgmt + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/util_arch.cuh b/fastertransformer/cuda/cub/util_arch.cuh new file mode 100644 index 000000000..28d81e7cd --- /dev/null +++ b/fastertransformer/cuda/cub/util_arch.cuh @@ -0,0 +1,151 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Static architectural properties by SM version. + */ + +#pragma once + +#include "util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +#if (__CUDACC_VER_MAJOR__ >= 9) && !defined(CUB_USE_COOPERATIVE_GROUPS) + #define CUB_USE_COOPERATIVE_GROUPS +#endif + +/// CUB_PTX_ARCH reflects the PTX version targeted by the active compiler pass (or zero during the host pass). +#ifndef CUB_PTX_ARCH + #ifndef __CUDA_ARCH__ + #define CUB_PTX_ARCH 0 + #else + #define CUB_PTX_ARCH __CUDA_ARCH__ + #endif +#endif + + +/// Whether or not the source targeted by the active compiler pass is allowed to invoke device kernels or methods from the CUDA runtime API. +#ifndef CUB_RUNTIME_FUNCTION + #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__>= 350 && defined(__CUDACC_RDC__)) + #define CUB_RUNTIME_ENABLED + #define CUB_RUNTIME_FUNCTION __host__ __device__ + #else + #define CUB_RUNTIME_FUNCTION __host__ + #endif +#endif + + +/// Number of threads per warp +#ifndef CUB_LOG_WARP_THREADS + #define CUB_LOG_WARP_THREADS(arch) \ + (5) + #define CUB_WARP_THREADS(arch) \ + (1 << CUB_LOG_WARP_THREADS(arch)) + + #define CUB_PTX_WARP_THREADS CUB_WARP_THREADS(CUB_PTX_ARCH) + #define CUB_PTX_LOG_WARP_THREADS CUB_LOG_WARP_THREADS(CUB_PTX_ARCH) +#endif + + +/// Number of smem banks +#ifndef CUB_LOG_SMEM_BANKS + #define CUB_LOG_SMEM_BANKS(arch) \ + ((arch >= 200) ? \ + (5) : \ + (4)) + #define CUB_SMEM_BANKS(arch) \ + (1 << CUB_LOG_SMEM_BANKS(arch)) + + #define CUB_PTX_LOG_SMEM_BANKS CUB_LOG_SMEM_BANKS(CUB_PTX_ARCH) + #define CUB_PTX_SMEM_BANKS CUB_SMEM_BANKS(CUB_PTX_ARCH) +#endif + + +/// Oversubscription factor +#ifndef CUB_SUBSCRIPTION_FACTOR + #define CUB_SUBSCRIPTION_FACTOR(arch) \ + ((arch >= 300) ? \ + (5) : \ + ((arch >= 200) ? \ + (3) : \ + (10))) + #define CUB_PTX_SUBSCRIPTION_FACTOR CUB_SUBSCRIPTION_FACTOR(CUB_PTX_ARCH) +#endif + + +/// Prefer padding overhead vs X-way conflicts greater than this threshold +#ifndef CUB_PREFER_CONFLICT_OVER_PADDING + #define CUB_PREFER_CONFLICT_OVER_PADDING(arch) \ + ((arch >= 300) ? \ + (1) : \ + (4)) + #define CUB_PTX_PREFER_CONFLICT_OVER_PADDING CUB_PREFER_CONFLICT_OVER_PADDING(CUB_PTX_ARCH) +#endif + + +/// Scale down the number of threads to keep same amount of scratch storage as the nominal configuration for 4B data. Minimum of two warps. +#ifndef CUB_SCALED_BLOCK_THREADS + #define CUB_SCALED_BLOCK_THREADS(NOMINAL_4B_BLOCK_THREADS, T, PTX_ARCH) \ + (CUB_MIN( \ + NOMINAL_4B_BLOCK_THREADS, \ + CUB_WARP_THREADS(PTX_ARCH) * CUB_MAX( \ + 2, \ + (NOMINAL_4B_BLOCK_THREADS / CUB_WARP_THREADS(PTX_ARCH)) * 4 / sizeof(T)))) +#endif + +/// Scale down number of items per thread to keep the same amount of register storage as the nominal configuration for 4B data. 
Minimum 1 item per thread +#ifndef CUB_SCALED_ITEMS_PER_THREAD + #define CUB_SCALED_ITEMS_PER_THREAD(NOMINAL_4B_ITEMS_PER_THREAD, NOMINAL_4B_BLOCK_THREADS, T, PTX_ARCH) \ + CUB_MAX( \ + 1, \ + (sizeof(T) < 4) ? \ + ((NOMINAL_4B_ITEMS_PER_THREAD * NOMINAL_4B_BLOCK_THREADS * 4) / CUB_MAX(4, sizeof(T))) / CUB_SCALED_BLOCK_THREADS(NOMINAL_4B_BLOCK_THREADS, T, PTX_ARCH) / 2 : \ + ((NOMINAL_4B_ITEMS_PER_THREAD * NOMINAL_4B_BLOCK_THREADS * 4) / CUB_MAX(4, sizeof(T))) / CUB_SCALED_BLOCK_THREADS(NOMINAL_4B_BLOCK_THREADS, T, PTX_ARCH)) +#endif + +/// Define both nominal threads-per-block and items-per-thread +#ifndef CUB_SCALED_GRANULARITIES + #define CUB_SCALED_GRANULARITIES(NOMINAL_4B_BLOCK_THREADS, NOMINAL_4B_ITEMS_PER_THREAD, T) \ + CUB_SCALED_BLOCK_THREADS(NOMINAL_4B_BLOCK_THREADS, T, 200), \ + CUB_SCALED_ITEMS_PER_THREAD(NOMINAL_4B_ITEMS_PER_THREAD, NOMINAL_4B_BLOCK_THREADS, T, 200) +#endif + + + +#endif // Do not document + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/util_debug.cuh b/fastertransformer/cuda/cub/util_debug.cuh new file mode 100644 index 000000000..3ad832e73 --- /dev/null +++ b/fastertransformer/cuda/cub/util_debug.cuh @@ -0,0 +1,145 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Error and event logging routines. + * + * The following macros definitions are supported: + * - \p CUB_LOG. Simple event messages are printed to \p stdout. 
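+ * - \p CubDebug / \p CubDebugExit. CUDA error codes are checked and, when
+ *   \p CUB_STDERR is defined, reported to \p stderr with source file and line
+ *   context, e.g. (illustrative):
+ *   \code
+ *   cudaError_t error = CubDebug(cudaMalloc(&d_ptr, bytes));
+ *   \endcode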
+ */ + +#pragma once + +#include +#include "util_namespace.cuh" +#include "util_arch.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup UtilMgmt + * @{ + */ + + +/// CUB error reporting macro (prints error messages to stderr) +#if (defined(DEBUG) || defined(_DEBUG)) && !defined(CUB_STDERR) + #define CUB_STDERR +#endif + + + +/** + * \brief %If \p CUB_STDERR is defined and \p error is not \p cudaSuccess, the corresponding error message is printed to \p stderr (or \p stdout in device code) along with the supplied source context. + * + * \return The CUDA error. + */ +__host__ __device__ __forceinline__ cudaError_t Debug( + cudaError_t error, + const char* filename, + int line) +{ + (void)filename; + (void)line; +#ifdef CUB_STDERR + if (error) + { + #if (CUB_PTX_ARCH == 0) + fprintf(stderr, "CUDA error %d [%s, %d]: %s\n", error, filename, line, cudaGetErrorString(error)); + fflush(stderr); + #elif (CUB_PTX_ARCH >= 200) + printf("CUDA error %d [block (%d,%d,%d) thread (%d,%d,%d), %s, %d]\n", error, blockIdx.z, blockIdx.y, blockIdx.x, threadIdx.z, threadIdx.y, threadIdx.x, filename, line); + #endif + } +#endif + return error; +} + + +/** + * \brief Debug macro + */ +#ifndef CubDebug + #define CubDebug(e) cub::Debug((cudaError_t) (e), __FILE__, __LINE__) +#endif + + +/** + * \brief Debug macro with exit + */ +#ifndef CubDebugExit + #define CubDebugExit(e) if (cub::Debug((cudaError_t) (e), __FILE__, __LINE__)) { exit(1); } +#endif + + +/** + * \brief Log macro for printf statements. + */ +#if !defined(_CubLog) + #if !(defined(__clang__) && defined(__CUDA__)) + #if (CUB_PTX_ARCH == 0) + #define _CubLog(format, ...) printf(format,__VA_ARGS__); + #elif (CUB_PTX_ARCH >= 200) + #define _CubLog(format, ...) printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, blockIdx.z, blockIdx.y, blockIdx.x, threadIdx.z, threadIdx.y, threadIdx.x, __VA_ARGS__); + #endif + #else + // XXX shameless hack for clang around variadic printf... + // Compilies w/o supplying -std=c++11 but shows warning, + // so we sielence them :) + #pragma clang diagnostic ignored "-Wc++11-extensions" + #pragma clang diagnostic ignored "-Wunnamed-type-template-args" + template + inline __host__ __device__ void va_printf(char const* format, Args const&... args) + { + #ifdef __CUDA_ARCH__ + printf(format, blockIdx.z, blockIdx.y, blockIdx.x, threadIdx.z, threadIdx.y, threadIdx.x, args...); + #else + printf(format, args...); + #endif + } + #ifndef __CUDA_ARCH__ + #define _CubLog(format, ...) va_printf(format,__VA_ARGS__); + #else + #define _CubLog(format, ...) va_printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, __VA_ARGS__); + #endif + #endif +#endif + + + + +/** @} */ // end group UtilMgmt + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/util_device.cuh b/fastertransformer/cuda/cub/util_device.cuh new file mode 100644 index 000000000..a5f3b6144 --- /dev/null +++ b/fastertransformer/cuda/cub/util_device.cuh @@ -0,0 +1,347 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Properties of a given CUDA device and the corresponding PTX bundle + */ + +#pragma once + +#include "util_type.cuh" +#include "util_arch.cuh" +#include "util_debug.cuh" +#include "util_namespace.cuh" +#include "util_macro.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup UtilMgmt + * @{ + */ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + +/** + * Alias temporaries to externally-allocated device storage (or simply return the amount of storage needed). + */ +template +__host__ __device__ __forceinline__ +cudaError_t AliasTemporaries( + void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
+ size_t &temp_storage_bytes, ///< [in,out] Size in bytes of \t d_temp_storage allocation + void* (&allocations)[ALLOCATIONS], ///< [in,out] Pointers to device allocations needed + size_t (&allocation_sizes)[ALLOCATIONS]) ///< [in] Sizes in bytes of device allocations needed +{ + const int ALIGN_BYTES = 256; + const int ALIGN_MASK = ~(ALIGN_BYTES - 1); + + // Compute exclusive prefix sum over allocation requests + size_t allocation_offsets[ALLOCATIONS]; + size_t bytes_needed = 0; + for (int i = 0; i < ALLOCATIONS; ++i) + { + size_t allocation_bytes = (allocation_sizes[i] + ALIGN_BYTES - 1) & ALIGN_MASK; + allocation_offsets[i] = bytes_needed; + bytes_needed += allocation_bytes; + } + bytes_needed += ALIGN_BYTES - 1; + + // Check if the caller is simply requesting the size of the storage allocation + if (!d_temp_storage) + { + temp_storage_bytes = bytes_needed; + return cudaSuccess; + } + + // Check if enough storage provided + if (temp_storage_bytes < bytes_needed) + { + return CubDebug(cudaErrorInvalidValue); + } + + // Alias + d_temp_storage = (void *) ((size_t(d_temp_storage) + ALIGN_BYTES - 1) & ALIGN_MASK); + for (int i = 0; i < ALLOCATIONS; ++i) + { + allocations[i] = static_cast(d_temp_storage) + allocation_offsets[i]; + } + + return cudaSuccess; +} + + +/** + * Empty kernel for querying PTX manifest metadata (e.g., version) for the current device + */ +template +__global__ void EmptyKernel(void) { } + + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +/** + * \brief Retrieves the PTX version that will be used on the current device (major * 100 + minor * 10) + */ +CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t PtxVersion(int &ptx_version) +{ + struct Dummy + { + /// Type definition of the EmptyKernel kernel entry point + typedef void (*EmptyKernelPtr)(); + + /// Force EmptyKernel to be generated if this class is used + CUB_RUNTIME_FUNCTION __forceinline__ + EmptyKernelPtr Empty() + { + return EmptyKernel; + } + }; + + +#ifndef CUB_RUNTIME_ENABLED + (void)ptx_version; + + // CUDA API calls not supported from this device + return cudaErrorInvalidConfiguration; + +#elif (CUB_PTX_ARCH > 0) + + ptx_version = CUB_PTX_ARCH; + return cudaSuccess; + +#else + + cudaError_t error = cudaSuccess; + do + { + cudaFuncAttributes empty_kernel_attrs; + if (CubDebug(error = cudaFuncGetAttributes(&empty_kernel_attrs, EmptyKernel))) break; + ptx_version = empty_kernel_attrs.ptxVersion * 10; + } + while (0); + + return error; + +#endif +} + + +/** + * \brief Retrieves the SM version (major * 100 + minor * 10) + */ +CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t SmVersion(int &sm_version, int device_ordinal) +{ +#ifndef CUB_RUNTIME_ENABLED + (void)sm_version; + (void)device_ordinal; + + // CUDA API calls not supported from this device + return cudaErrorInvalidConfiguration; + +#else + + cudaError_t error = cudaSuccess; + do + { + // Fill in SM version + int major, minor; + if (CubDebug(error = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_ordinal))) break; + if (CubDebug(error = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_ordinal))) break; + sm_version = major * 100 + minor * 10; + } + while (0); + + return error; + +#endif +} + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +/** + * Synchronize the stream if specified + */ +CUB_RUNTIME_FUNCTION __forceinline__ +static cudaError_t SyncStream(cudaStream_t stream) +{ +#if (CUB_PTX_ARCH == 0) + return cudaStreamSynchronize(stream); +#else + (void)stream; + // Device can't yet sync on a specific 
stream + return cudaDeviceSynchronize(); +#endif +} + + +/** + * \brief Computes maximum SM occupancy in thread blocks for executing the given kernel function pointer \p kernel_ptr on the current device with \p block_threads per thread block. + * + * \par Snippet + * The code snippet below illustrates the use of the MaxSmOccupancy function. + * \par + * \code + * #include // or equivalently + * + * template + * __global__ void ExampleKernel() + * { + * // Allocate shared memory for BlockScan + * __shared__ volatile T buffer[4096]; + * + * ... + * } + * + * ... + * + * // Determine SM occupancy for ExampleKernel specialized for unsigned char + * int max_sm_occupancy; + * MaxSmOccupancy(max_sm_occupancy, ExampleKernel, 64); + * + * // max_sm_occupancy <-- 4 on SM10 + * // max_sm_occupancy <-- 8 on SM20 + * // max_sm_occupancy <-- 12 on SM35 + * + * \endcode + * + */ +template +CUB_RUNTIME_FUNCTION __forceinline__ +cudaError_t MaxSmOccupancy( + int &max_sm_occupancy, ///< [out] maximum number of thread blocks that can reside on a single SM + KernelPtr kernel_ptr, ///< [in] Kernel pointer for which to compute SM occupancy + int block_threads, ///< [in] Number of threads per thread block + int dynamic_smem_bytes = 0) +{ +#ifndef CUB_RUNTIME_ENABLED + (void)dynamic_smem_bytes; + (void)block_threads; + (void)kernel_ptr; + (void)max_sm_occupancy; + + // CUDA API calls not supported from this device + return CubDebug(cudaErrorInvalidConfiguration); + +#else + + return cudaOccupancyMaxActiveBlocksPerMultiprocessor ( + &max_sm_occupancy, + kernel_ptr, + block_threads, + dynamic_smem_bytes); + +#endif // CUB_RUNTIME_ENABLED +} + + +/****************************************************************************** + * Policy management + ******************************************************************************/ + +/** + * Kernel dispatch configuration + */ +struct KernelConfig +{ + int block_threads; + int items_per_thread; + int tile_size; + int sm_occupancy; + + CUB_RUNTIME_FUNCTION __forceinline__ + KernelConfig() : block_threads(0), items_per_thread(0), tile_size(0), sm_occupancy(0) {} + + template + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t Init(KernelPtrT kernel_ptr) + { + block_threads = AgentPolicyT::BLOCK_THREADS; + items_per_thread = AgentPolicyT::ITEMS_PER_THREAD; + tile_size = block_threads * items_per_thread; + cudaError_t retval = MaxSmOccupancy(sm_occupancy, kernel_ptr, block_threads); + return retval; + } +}; + + + +/// Helper for dispatching into a policy chain +template +struct ChainedPolicy +{ + /// The policy for the active compiler pass + typedef typename If<(CUB_PTX_ARCH < PTX_VERSION), typename PrevPolicyT::ActivePolicy, PolicyT>::Type ActivePolicy; + + /// Specializes and dispatches op in accordance to the first policy in the chain of adequate PTX version + template + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Invoke(int ptx_version, FunctorT &op) + { + if (ptx_version < PTX_VERSION) { + return PrevPolicyT::Invoke(ptx_version, op); + } + return op.template Invoke(); + } +}; + +/// Helper for dispatching into a policy chain (end-of-chain specialization) +template +struct ChainedPolicy +{ + /// The policy for the active compiler pass + typedef PolicyT ActivePolicy; + + /// Specializes and dispatches op in accordance to the first policy in the chain of adequate PTX version + template + CUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Invoke(int /*ptx_version*/, FunctorT &op) { + return op.template Invoke(); + } +}; + + + + +#endif // Do not 
document + + + + +/** @} */ // end group UtilMgmt + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/util_macro.cuh b/fastertransformer/cuda/cub/util_macro.cuh new file mode 100644 index 000000000..ff8636542 --- /dev/null +++ b/fastertransformer/cuda/cub/util_macro.cuh @@ -0,0 +1,103 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/****************************************************************************** + * Common C/C++ macro utilities + ******************************************************************************/ + +#pragma once + +#include "util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup UtilModule + * @{ + */ + +#ifndef CUB_ALIGN + #if defined(_WIN32) || defined(_WIN64) + /// Align struct + #define CUB_ALIGN(bytes) __declspec(align(32)) + #else + /// Align struct + #define CUB_ALIGN(bytes) __attribute__((aligned(bytes))) + #endif +#endif + +#ifndef CUB_MAX + /// Select maximum(a, b) + #define CUB_MAX(a, b) (((b) > (a)) ? (b) : (a)) +#endif + +#ifndef CUB_MIN + /// Select minimum(a, b) + #define CUB_MIN(a, b) (((b) < (a)) ? 
(b) : (a)) +#endif + +#ifndef CUB_QUOTIENT_FLOOR + /// Quotient of x/y rounded down to nearest integer + #define CUB_QUOTIENT_FLOOR(x, y) ((x) / (y)) +#endif + +#ifndef CUB_QUOTIENT_CEILING + /// Quotient of x/y rounded up to nearest integer + #define CUB_QUOTIENT_CEILING(x, y) (((x) + (y) - 1) / (y)) +#endif + +#ifndef CUB_ROUND_UP_NEAREST + /// x rounded up to the nearest multiple of y + #define CUB_ROUND_UP_NEAREST(x, y) ((((x) + (y) - 1) / (y)) * y) +#endif + +#ifndef CUB_ROUND_DOWN_NEAREST + /// x rounded down to the nearest multiple of y + #define CUB_ROUND_DOWN_NEAREST(x, y) (((x) / (y)) * y) +#endif + + +#ifndef CUB_STATIC_ASSERT + #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + #define CUB_CAT_(a, b) a ## b + #define CUB_CAT(a, b) CUB_CAT_(a, b) + #endif // DOXYGEN_SHOULD_SKIP_THIS + + /// Static assert + #define CUB_STATIC_ASSERT(cond, msg) typedef int CUB_CAT(cub_static_assert, __LINE__)[(cond) ? 1 : -1] +#endif + +/** @} */ // end group UtilModule + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/util_namespace.cuh b/fastertransformer/cuda/cub/util_namespace.cuh new file mode 100644 index 000000000..c8991d08f --- /dev/null +++ b/fastertransformer/cuda/cub/util_namespace.cuh @@ -0,0 +1,46 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * \file + * Place-holder for prefixing the cub namespace + */ + +#pragma once + +// For example: +//#define CUB_NS_PREFIX namespace thrust{ namespace detail { +//#define CUB_NS_POSTFIX } } + +#ifndef CUB_NS_PREFIX +#define CUB_NS_PREFIX +#endif + +#ifndef CUB_NS_POSTFIX +#define CUB_NS_POSTFIX +#endif diff --git a/fastertransformer/cuda/cub/util_ptx.cuh b/fastertransformer/cuda/cub/util_ptx.cuh new file mode 100644 index 000000000..582ca0d8b --- /dev/null +++ b/fastertransformer/cuda/cub/util_ptx.cuh @@ -0,0 +1,758 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * \file + * PTX intrinsics + */ + + +#pragma once + +#include "util_type.cuh" +#include "util_arch.cuh" +#include "util_namespace.cuh" +#include "util_debug.cuh" + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup UtilPtx + * @{ + */ + + +/****************************************************************************** + * PTX helper macros + ******************************************************************************/ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +/** + * Register modifier for pointer-types (for inlining PTX assembly) + */ +#if defined(_WIN64) || defined(__LP64__) + #define __CUB_LP64__ 1 + // 64-bit register modifier for inlined asm + #define _CUB_ASM_PTR_ "l" + #define _CUB_ASM_PTR_SIZE_ "u64" +#else + #define __CUB_LP64__ 0 + // 32-bit register modifier for inlined asm + #define _CUB_ASM_PTR_ "r" + #define _CUB_ASM_PTR_SIZE_ "u32" +#endif + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + +/****************************************************************************** + * Inlined PTX intrinsics + ******************************************************************************/ + +/** + * \brief Shift-right then add. Returns (\p x >> \p shift) + \p addend. + */ +__device__ __forceinline__ unsigned int SHR_ADD( + unsigned int x, + unsigned int shift, + unsigned int addend) +{ + unsigned int ret; +#if CUB_PTX_ARCH >= 200 + asm ("vshr.u32.u32.u32.clamp.add %0, %1, %2, %3;" : + "=r"(ret) : "r"(x), "r"(shift), "r"(addend)); +#else + ret = (x >> shift) + addend; +#endif + return ret; +} + + +/** + * \brief Shift-left then add. Returns (\p x << \p shift) + \p addend. + */ +__device__ __forceinline__ unsigned int SHL_ADD( + unsigned int x, + unsigned int shift, + unsigned int addend) +{ + unsigned int ret; +#if CUB_PTX_ARCH >= 200 + asm ("vshl.u32.u32.u32.clamp.add %0, %1, %2, %3;" : + "=r"(ret) : "r"(x), "r"(shift), "r"(addend)); +#else + ret = (x << shift) + addend; +#endif + return ret; +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +/** + * Bitfield-extract. + */ +template +__device__ __forceinline__ unsigned int BFE( + UnsignedBits source, + unsigned int bit_start, + unsigned int num_bits, + Int2Type /*byte_len*/) +{ + unsigned int bits; +#if CUB_PTX_ARCH >= 200 + asm ("bfe.u32 %0, %1, %2, %3;" : "=r"(bits) : "r"((unsigned int) source), "r"(bit_start), "r"(num_bits)); +#else + const unsigned int MASK = (1 << num_bits) - 1; + bits = (source >> bit_start) & MASK; +#endif + return bits; +} + + +/** + * Bitfield-extract for 64-bit types. + */ +template +__device__ __forceinline__ unsigned int BFE( + UnsignedBits source, + unsigned int bit_start, + unsigned int num_bits, + Int2Type<8> /*byte_len*/) +{ + const unsigned long long MASK = (1ull << num_bits) - 1; + return (source >> bit_start) & MASK; +} + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +/** + * \brief Bitfield-extract. Extracts \p num_bits from \p source starting at bit-offset \p bit_start. The input \p source may be an 8b, 16b, 32b, or 64b unsigned integer type. + */ +template +__device__ __forceinline__ unsigned int BFE( + UnsignedBits source, + unsigned int bit_start, + unsigned int num_bits) +{ + return BFE(source, bit_start, num_bits, Int2Type()); +} + + +/** + * \brief Bitfield insert. Inserts the \p num_bits least significant bits of \p y into \p x at bit-offset \p bit_start. 
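+ *
+ * \par Snippet
+ * An illustrative example, following the documented semantics above:
+ * \code
+ * unsigned int ret;
+ * BFI(ret, 0xAAAAAAAA, 0x000000BB, 8, 8);    // ret <-- 0xAAAABBAA
+ * \endcode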
+ */ +__device__ __forceinline__ void BFI( + unsigned int &ret, + unsigned int x, + unsigned int y, + unsigned int bit_start, + unsigned int num_bits) +{ +#if CUB_PTX_ARCH >= 200 + asm ("bfi.b32 %0, %1, %2, %3, %4;" : + "=r"(ret) : "r"(y), "r"(x), "r"(bit_start), "r"(num_bits)); +#else + x <<= bit_start; + unsigned int MASK_X = ((1 << num_bits) - 1) << bit_start; + unsigned int MASK_Y = ~MASK_X; + ret = (y & MASK_Y) | (x & MASK_X); +#endif +} + + +/** + * \brief Three-operand add. Returns \p x + \p y + \p z. + */ +__device__ __forceinline__ unsigned int IADD3(unsigned int x, unsigned int y, unsigned int z) +{ +#if CUB_PTX_ARCH >= 200 + asm ("vadd.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(x) : "r"(x), "r"(y), "r"(z)); +#else + x = x + y + z; +#endif + return x; +} + + +/** + * \brief Byte-permute. Pick four arbitrary bytes from two 32-bit registers, and reassemble them into a 32-bit destination register. For SM2.0 or later. + * + * \par + * The bytes in the two source registers \p a and \p b are numbered from 0 to 7: + * {\p b, \p a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}. For each of the four bytes + * {b3, b2, b1, b0} selected in the return value, a 4-bit selector is defined within + * the four lower "nibbles" of \p index: {\p index } = {n7, n6, n5, n4, n3, n2, n1, n0} + * + * \par Snippet + * The code snippet below illustrates byte-permute. + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * int a = 0x03020100; + * int b = 0x07060504; + * int index = 0x00007531; + * + * int selected = PRMT(a, b, index); // 0x07050301 + * + * \endcode + * + */ +__device__ __forceinline__ int PRMT(unsigned int a, unsigned int b, unsigned int index) +{ + int ret; + asm ("prmt.b32 %0, %1, %2, %3;" : "=r"(ret) : "r"(a), "r"(b), "r"(index)); + return ret; +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +/** + * Sync-threads barrier. 
+ */ +__device__ __forceinline__ void BAR(int count) +{ + asm volatile("bar.sync 1, %0;" : : "r"(count)); +} + +/** + * CTA barrier + */ +__device__ __forceinline__ void CTA_SYNC() +{ + __syncthreads(); +} + + +/** + * CTA barrier with predicate + */ +__device__ __forceinline__ int CTA_SYNC_AND(int p) +{ + return __syncthreads_and(p); +} + + +/** + * Warp barrier + */ +__device__ __forceinline__ void WARP_SYNC(unsigned int member_mask) +{ +#ifdef CUB_USE_COOPERATIVE_GROUPS + __syncwarp(member_mask); +#endif +} + + +/** + * Warp any + */ +__device__ __forceinline__ int WARP_ANY(int predicate, unsigned int member_mask) +{ +#ifdef CUB_USE_COOPERATIVE_GROUPS + return __any_sync(member_mask, predicate); +#else + return ::__any(predicate); +#endif +} + + +/** + * Warp any + */ +__device__ __forceinline__ int WARP_ALL(int predicate, unsigned int member_mask) +{ +#ifdef CUB_USE_COOPERATIVE_GROUPS + return __all_sync(member_mask, predicate); +#else + return ::__all(predicate); +#endif +} + + +/** + * Warp ballot + */ +__device__ __forceinline__ int WARP_BALLOT(int predicate, unsigned int member_mask) +{ +#ifdef CUB_USE_COOPERATIVE_GROUPS + return __ballot_sync(member_mask, predicate); +#else + return __ballot(predicate); +#endif +} + +/** + * Warp synchronous shfl_up + */ +__device__ __forceinline__ +unsigned int SHFL_UP_SYNC(unsigned int word, int src_offset, int flags, unsigned int member_mask) +{ +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile("shfl.sync.up.b32 %0, %1, %2, %3, %4;" + : "=r"(word) : "r"(word), "r"(src_offset), "r"(flags), "r"(member_mask)); +#else + asm volatile("shfl.up.b32 %0, %1, %2, %3;" + : "=r"(word) : "r"(word), "r"(src_offset), "r"(flags)); +#endif + return word; +} + +/** + * Warp synchronous shfl_down + */ +__device__ __forceinline__ +unsigned int SHFL_DOWN_SYNC(unsigned int word, int src_offset, int flags, unsigned int member_mask) +{ +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile("shfl.sync.down.b32 %0, %1, %2, %3, %4;" + : "=r"(word) : "r"(word), "r"(src_offset), "r"(flags), "r"(member_mask)); +#else + asm volatile("shfl.down.b32 %0, %1, %2, %3;" + : "=r"(word) : "r"(word), "r"(src_offset), "r"(flags)); +#endif + return word; +} + +/** + * Warp synchronous shfl_idx + */ +__device__ __forceinline__ +unsigned int SHFL_IDX_SYNC(unsigned int word, int src_lane, int flags, unsigned int member_mask) +{ +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile("shfl.sync.idx.b32 %0, %1, %2, %3, %4;" + : "=r"(word) : "r"(word), "r"(src_lane), "r"(flags), "r"(member_mask)); +#else + asm volatile("shfl.idx.b32 %0, %1, %2, %3;" + : "=r"(word) : "r"(word), "r"(src_lane), "r"(flags)); +#endif + return word; +} + +/** + * Floating point multiply. (Mantissa LSB rounds towards zero.) + */ +__device__ __forceinline__ float FMUL_RZ(float a, float b) +{ + float d; + asm ("mul.rz.f32 %0, %1, %2;" : "=f"(d) : "f"(a), "f"(b)); + return d; +} + + +/** + * Floating point multiply-add. (Mantissa LSB rounds towards zero.) 
+ */ +__device__ __forceinline__ float FFMA_RZ(float a, float b, float c) +{ + float d; + asm ("fma.rz.f32 %0, %1, %2, %3;" : "=f"(d) : "f"(a), "f"(b), "f"(c)); + return d; +} + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +/** + * \brief Terminates the calling thread + */ +__device__ __forceinline__ void ThreadExit() { + asm volatile("exit;"); +} + + +/** + * \brief Abort execution and generate an interrupt to the host CPU + */ +__device__ __forceinline__ void ThreadTrap() { + asm volatile("trap;"); +} + + +/** + * \brief Returns the row-major linear thread identifier for a multidimensional thread block + */ +__device__ __forceinline__ int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z) +{ + return ((block_dim_z == 1) ? 0 : (threadIdx.z * block_dim_x * block_dim_y)) + + ((block_dim_y == 1) ? 0 : (threadIdx.y * block_dim_x)) + + threadIdx.x; +} + + +/** + * \brief Returns the warp lane ID of the calling thread + */ +__device__ __forceinline__ unsigned int LaneId() +{ + unsigned int ret; + asm ("mov.u32 %0, %%laneid;" : "=r"(ret) ); + return ret; +} + + +/** + * \brief Returns the warp ID of the calling thread. Warp ID is guaranteed to be unique among warps, but may not correspond to a zero-based ranking within the thread block. + */ +__device__ __forceinline__ unsigned int WarpId() +{ + unsigned int ret; + asm ("mov.u32 %0, %%warpid;" : "=r"(ret) ); + return ret; +} + +/** + * \brief Returns the warp lane mask of all lanes less than the calling thread + */ +__device__ __forceinline__ unsigned int LaneMaskLt() +{ + unsigned int ret; + asm ("mov.u32 %0, %%lanemask_lt;" : "=r"(ret) ); + return ret; +} + +/** + * \brief Returns the warp lane mask of all lanes less than or equal to the calling thread + */ +__device__ __forceinline__ unsigned int LaneMaskLe() +{ + unsigned int ret; + asm ("mov.u32 %0, %%lanemask_le;" : "=r"(ret) ); + return ret; +} + +/** + * \brief Returns the warp lane mask of all lanes greater than the calling thread + */ +__device__ __forceinline__ unsigned int LaneMaskGt() +{ + unsigned int ret; + asm ("mov.u32 %0, %%lanemask_gt;" : "=r"(ret) ); + return ret; +} + +/** + * \brief Returns the warp lane mask of all lanes greater than or equal to the calling thread + */ +__device__ __forceinline__ unsigned int LaneMaskGe() +{ + unsigned int ret; + asm ("mov.u32 %0, %%lanemask_ge;" : "=r"(ret) ); + return ret; +} + +/** @} */ // end group UtilPtx + + + + +/** + * \brief Shuffle-up for any data type. Each warp-lanei obtains the value \p input contributed by warp-lanei-src_offset. For thread lanes \e i < src_offset, the thread's own \p input is returned to the thread. ![](shfl_up_logo.png) + * \ingroup WarpModule + * + * \tparam LOGICAL_WARP_THREADS The number of threads per "logical" warp. Must be a power-of-two <= 32. + * \tparam T [inferred] The input/output element type + * + * \par + * - Available only for SM3.0 or newer + * + * \par Snippet + * The code snippet below illustrates each thread obtaining a \p double value from the + * predecessor of its predecessor. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Obtain one input item per thread + * double thread_data = ... + * + * // Obtain item from two ranks below + * double peer_data = ShuffleUp<32>(thread_data, 2, 0, 0xffffffff); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the first warp of threads is {1.0, 2.0, 3.0, 4.0, 5.0, ..., 32.0}. + * The corresponding output \p peer_data will be {1.0, 2.0, 1.0, 2.0, 3.0, ..., 30.0}. 
+ * + */ +template < + int LOGICAL_WARP_THREADS, ///< Number of threads per logical warp + typename T> +__device__ __forceinline__ T ShuffleUp( + T input, ///< [in] The value to broadcast + int src_offset, ///< [in] The relative down-offset of the peer to read from + int first_thread, ///< [in] Index of first lane in logical warp (typically 0) + unsigned int member_mask) ///< [in] 32-bit mask of participating warp lanes +{ + /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up + enum { + SHFL_C = (32 - LOGICAL_WARP_THREADS) << 8 + }; + + typedef typename UnitWord::ShuffleWord ShuffleWord; + + const int WORDS = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord); + + T output; + ShuffleWord *output_alias = reinterpret_cast(&output); + ShuffleWord *input_alias = reinterpret_cast(&input); + + unsigned int shuffle_word; + shuffle_word = SHFL_UP_SYNC((unsigned int)input_alias[0], src_offset, first_thread | SHFL_C, member_mask); + output_alias[0] = shuffle_word; + + #pragma unroll + for (int WORD = 1; WORD < WORDS; ++WORD) + { + shuffle_word = SHFL_UP_SYNC((unsigned int)input_alias[WORD], src_offset, first_thread | SHFL_C, member_mask); + output_alias[WORD] = shuffle_word; + } + + return output; +} + + +/** + * \brief Shuffle-down for any data type. Each warp-lanei obtains the value \p input contributed by warp-lanei+src_offset. For thread lanes \e i >= WARP_THREADS, the thread's own \p input is returned to the thread. ![](shfl_down_logo.png) + * \ingroup WarpModule + * + * \tparam LOGICAL_WARP_THREADS The number of threads per "logical" warp. Must be a power-of-two <= 32. + * \tparam T [inferred] The input/output element type + * + * \par + * - Available only for SM3.0 or newer + * + * \par Snippet + * The code snippet below illustrates each thread obtaining a \p double value from the + * successor of its successor. + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Obtain one input item per thread + * double thread_data = ... + * + * // Obtain item from two ranks below + * double peer_data = ShuffleDown<32>(thread_data, 2, 31, 0xffffffff); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the first warp of threads is {1.0, 2.0, 3.0, 4.0, 5.0, ..., 32.0}. + * The corresponding output \p peer_data will be {3.0, 4.0, 5.0, 6.0, 7.0, ..., 32.0}. 
+ * + */ +template < + int LOGICAL_WARP_THREADS, ///< Number of threads per logical warp + typename T> +__device__ __forceinline__ T ShuffleDown( + T input, ///< [in] The value to broadcast + int src_offset, ///< [in] The relative up-offset of the peer to read from + int last_thread, ///< [in] Index of last thread in logical warp (typically 31 for a 32-thread warp) + unsigned int member_mask) ///< [in] 32-bit mask of participating warp lanes +{ + /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up + enum { + SHFL_C = (32 - LOGICAL_WARP_THREADS) << 8 + }; + + typedef typename UnitWord::ShuffleWord ShuffleWord; + + const int WORDS = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord); + + T output; + ShuffleWord *output_alias = reinterpret_cast(&output); + ShuffleWord *input_alias = reinterpret_cast(&input); + + unsigned int shuffle_word; + shuffle_word = SHFL_DOWN_SYNC((unsigned int)input_alias[0], src_offset, last_thread | SHFL_C, member_mask); + output_alias[0] = shuffle_word; + + #pragma unroll + for (int WORD = 1; WORD < WORDS; ++WORD) + { + shuffle_word = SHFL_DOWN_SYNC((unsigned int)input_alias[WORD], src_offset, last_thread | SHFL_C, member_mask); + output_alias[WORD] = shuffle_word; + } + + return output; +} + + +/** + * \brief Shuffle-broadcast for any data type. Each warp-lanei obtains the value \p input + * contributed by warp-lanesrc_lane. For \p src_lane < 0 or \p src_lane >= WARP_THREADS, + * then the thread's own \p input is returned to the thread. ![](shfl_broadcast_logo.png) + * + * \tparam LOGICAL_WARP_THREADS The number of threads per "logical" warp. Must be a power-of-two <= 32. + * \tparam T [inferred] The input/output element type + * + * \ingroup WarpModule + * + * \par + * - Available only for SM3.0 or newer + * + * \par Snippet + * The code snippet below illustrates each thread obtaining a \p double value from warp-lane0. + * + * \par + * \code + * #include // or equivalently + * + * __global__ void ExampleKernel(...) + * { + * // Obtain one input item per thread + * double thread_data = ... + * + * // Obtain item from thread 0 + * double peer_data = ShuffleIndex<32>(thread_data, 0, 0xffffffff); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the first warp of threads is {1.0, 2.0, 3.0, 4.0, 5.0, ..., 32.0}. + * The corresponding output \p peer_data will be {1.0, 1.0, 1.0, 1.0, 1.0, ..., 1.0}. 
+ * + */ +template < + int LOGICAL_WARP_THREADS, ///< Number of threads per logical warp + typename T> +__device__ __forceinline__ T ShuffleIndex( + T input, ///< [in] The value to broadcast + int src_lane, ///< [in] Which warp lane is to do the broadcasting + unsigned int member_mask) ///< [in] 32-bit mask of participating warp lanes +{ + /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up + enum { + SHFL_C = ((32 - LOGICAL_WARP_THREADS) << 8) | (LOGICAL_WARP_THREADS - 1) + }; + + typedef typename UnitWord::ShuffleWord ShuffleWord; + + const int WORDS = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord); + + T output; + ShuffleWord *output_alias = reinterpret_cast(&output); + ShuffleWord *input_alias = reinterpret_cast(&input); + + unsigned int shuffle_word; + shuffle_word = SHFL_IDX_SYNC((unsigned int)input_alias[0], + src_lane, + SHFL_C, + member_mask); + + output_alias[0] = shuffle_word; + + #pragma unroll + for (int WORD = 1; WORD < WORDS; ++WORD) + { + shuffle_word = SHFL_IDX_SYNC((unsigned int)input_alias[WORD], + src_lane, + SHFL_C, + member_mask); + + output_alias[WORD] = shuffle_word; + } + + return output; +} + + + +/** + * Compute a 32b mask of threads having the same least-significant + * LABEL_BITS of \p label as the calling thread. + */ +template +inline __device__ unsigned int MatchAny(unsigned int label) +{ + unsigned int retval; + + // Extract masks of common threads for each bit + #pragma unroll + for (int BIT = 0; BIT < LABEL_BITS; ++BIT) + { + unsigned int mask; + unsigned int current_bit = 1 << BIT; + asm ("{\n" + " .reg .pred p;\n" + " and.b32 %0, %1, %2;" + " setp.eq.u32 p, %0, %2;\n" +#ifdef CUB_USE_COOPERATIVE_GROUPS + " vote.ballot.sync.b32 %0, p, 0xffffffff;\n" +#else + " vote.ballot.b32 %0, p;\n" +#endif + " @!p not.b32 %0, %0;\n" + "}\n" : "=r"(mask) : "r"(label), "r"(current_bit)); + + // Remove peers who differ + retval = (BIT == 0) ? mask : retval & mask; + } + + return retval; + +// // VOLTA match +// unsigned int retval; +// asm ("{\n" +// " match.any.sync.b32 %0, %1, 0xffffffff;\n" +// "}\n" : "=r"(retval) : "r"(label)); +// return retval; + +} + + + + + + + + + + + + + + + + + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/util_type.cuh b/fastertransformer/cuda/cub/util_type.cuh new file mode 100644 index 000000000..0ba41e1ed --- /dev/null +++ b/fastertransformer/cuda/cub/util_type.cuh @@ -0,0 +1,1167 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * Common type manipulation (metaprogramming) utilities + */ + +#pragma once + +#include +#include +#include + +#if (__CUDACC_VER_MAJOR__ >= 9) + #include +#endif + +#include "util_macro.cuh" +#include "util_arch.cuh" +#include "util_namespace.cuh" + + + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup UtilModule + * @{ + */ + + + +/****************************************************************************** + * Type equality + ******************************************************************************/ + +/** + * \brief Type selection (IF ? ThenType : ElseType) + */ +template +struct If +{ + /// Conditional type result + typedef ThenType Type; // true +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template +struct If +{ + typedef ElseType Type; // false +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + + +/****************************************************************************** + * Conditional types + ******************************************************************************/ + +/** + * \brief Type equality test + */ +template +struct Equals +{ + enum { + VALUE = 0, + NEGATE = 1 + }; +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template +struct Equals +{ + enum { + VALUE = 1, + NEGATE = 0 + }; +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + +/****************************************************************************** + * Static math + ******************************************************************************/ + +/** + * \brief Statically determine log2(N), rounded up. + * + * For example: + * Log2<8>::VALUE // 3 + * Log2<3>::VALUE // 2 + */ +template +struct Log2 +{ + /// Static logarithm value + enum { VALUE = Log2> 1), COUNT + 1>::VALUE }; // Inductive case +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template +struct Log2 +{ + enum {VALUE = (1 << (COUNT - 1) < N) ? // Base case + COUNT : + COUNT - 1 }; +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + +/** + * \brief Statically determine if N is a power-of-two + */ +template +struct PowerOfTwo +{ + enum { VALUE = ((N & (N - 1)) == 0) }; +}; + + + +/****************************************************************************** + * Pointer vs. iterator detection + ******************************************************************************/ + +/** + * \brief Pointer vs. 
iterator + */ +template +struct IsPointer +{ + enum { VALUE = 0 }; +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template +struct IsPointer +{ + enum { VALUE = 1 }; +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + + +/****************************************************************************** + * Qualifier detection + ******************************************************************************/ + +/** + * \brief Volatile modifier test + */ +template +struct IsVolatile +{ + enum { VALUE = 0 }; +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template +struct IsVolatile +{ + enum { VALUE = 1 }; +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + +/****************************************************************************** + * Qualifier removal + ******************************************************************************/ + +/** + * \brief Removes \p const and \p volatile qualifiers from type \p Tp. + * + * For example: + * typename RemoveQualifiers::Type // int; + */ +template +struct RemoveQualifiers +{ + /// Type without \p const and \p volatile qualifiers + typedef Up Type; +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template +struct RemoveQualifiers +{ + typedef Up Type; +}; + +template +struct RemoveQualifiers +{ + typedef Up Type; +}; + +template +struct RemoveQualifiers +{ + typedef Up Type; +}; + + +/****************************************************************************** + * Marker types + ******************************************************************************/ + +/** + * \brief A simple "NULL" marker type + */ +struct NullType +{ +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + template + __host__ __device__ __forceinline__ NullType& operator =(const T&) { return *this; } + + __host__ __device__ __forceinline__ bool operator ==(const NullType&) { return true; } + + __host__ __device__ __forceinline__ bool operator !=(const NullType&) { return false; } + +#endif // DOXYGEN_SHOULD_SKIP_THIS +}; + + +/** + * \brief Allows for the treatment of an integral constant as a type at compile-time (e.g., to achieve static call dispatch based on constant integral values) + */ +template +struct Int2Type +{ + enum {VALUE = A}; +}; + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + +/****************************************************************************** + * Size and alignment + ******************************************************************************/ + +/// Structure alignment +template +struct AlignBytes +{ + struct Pad + { + T val; + char byte; + }; + + enum + { + /// The "true CUDA" alignment of T in bytes + ALIGN_BYTES = sizeof(Pad) - sizeof(T) + }; + + /// The "truly aligned" type + typedef T Type; +}; + +// Specializations where host C++ compilers (e.g., 32-bit Windows) may disagree +// with device C++ compilers (EDG) on types passed as template parameters through +// kernel functions + +#define __CUB_ALIGN_BYTES(t, b) \ + template <> struct AlignBytes \ + { enum { ALIGN_BYTES = b }; typedef __align__(b) t Type; }; + +__CUB_ALIGN_BYTES(short4, 8) +__CUB_ALIGN_BYTES(ushort4, 8) +__CUB_ALIGN_BYTES(int2, 8) +__CUB_ALIGN_BYTES(uint2, 8) +__CUB_ALIGN_BYTES(long long, 8) +__CUB_ALIGN_BYTES(unsigned long long, 8) +__CUB_ALIGN_BYTES(float2, 8) +__CUB_ALIGN_BYTES(double, 8) +#ifdef _WIN32 + __CUB_ALIGN_BYTES(long2, 8) + __CUB_ALIGN_BYTES(ulong2, 8) +#else + __CUB_ALIGN_BYTES(long2, 16) + __CUB_ALIGN_BYTES(ulong2, 16) +#endif +__CUB_ALIGN_BYTES(int4, 16) +__CUB_ALIGN_BYTES(uint4, 16) 
+__CUB_ALIGN_BYTES(float4, 16) +__CUB_ALIGN_BYTES(long4, 16) +__CUB_ALIGN_BYTES(ulong4, 16) +__CUB_ALIGN_BYTES(longlong2, 16) +__CUB_ALIGN_BYTES(ulonglong2, 16) +__CUB_ALIGN_BYTES(double2, 16) +__CUB_ALIGN_BYTES(longlong4, 16) +__CUB_ALIGN_BYTES(ulonglong4, 16) +__CUB_ALIGN_BYTES(double4, 16) + +template struct AlignBytes : AlignBytes {}; +template struct AlignBytes : AlignBytes {}; +template struct AlignBytes : AlignBytes {}; + + +/// Unit-words of data movement +template +struct UnitWord +{ + enum { + ALIGN_BYTES = AlignBytes::ALIGN_BYTES + }; + + template + struct IsMultiple + { + enum { + UNIT_ALIGN_BYTES = AlignBytes::ALIGN_BYTES, + IS_MULTIPLE = (sizeof(T) % sizeof(Unit) == 0) && (ALIGN_BYTES % UNIT_ALIGN_BYTES == 0) + }; + }; + + /// Biggest shuffle word that T is a whole multiple of and is not larger than the alignment of T + typedef typename If::IS_MULTIPLE, + unsigned int, + typename If::IS_MULTIPLE, + unsigned short, + unsigned char>::Type>::Type ShuffleWord; + + /// Biggest volatile word that T is a whole multiple of and is not larger than the alignment of T + typedef typename If::IS_MULTIPLE, + unsigned long long, + ShuffleWord>::Type VolatileWord; + + /// Biggest memory-access word that T is a whole multiple of and is not larger than the alignment of T + typedef typename If::IS_MULTIPLE, + ulonglong2, + VolatileWord>::Type DeviceWord; + + /// Biggest texture reference word that T is a whole multiple of and is not larger than the alignment of T + typedef typename If::IS_MULTIPLE, + uint4, + typename If::IS_MULTIPLE, + uint2, + ShuffleWord>::Type>::Type TextureWord; +}; + + +// float2 specialization workaround (for SM10-SM13) +template <> +struct UnitWord +{ + typedef int ShuffleWord; +#if (CUB_PTX_ARCH > 0) && (CUB_PTX_ARCH <= 130) + typedef float VolatileWord; + typedef uint2 DeviceWord; +#else + typedef unsigned long long VolatileWord; + typedef unsigned long long DeviceWord; +#endif + typedef float2 TextureWord; +}; + +// float4 specialization workaround (for SM10-SM13) +template <> +struct UnitWord +{ + typedef int ShuffleWord; +#if (CUB_PTX_ARCH > 0) && (CUB_PTX_ARCH <= 130) + typedef float VolatileWord; + typedef uint4 DeviceWord; +#else + typedef unsigned long long VolatileWord; + typedef ulonglong2 DeviceWord; +#endif + typedef float4 TextureWord; +}; + + +// char2 specialization workaround (for SM10-SM13) +template <> +struct UnitWord +{ + typedef unsigned short ShuffleWord; +#if (CUB_PTX_ARCH > 0) && (CUB_PTX_ARCH <= 130) + typedef unsigned short VolatileWord; + typedef short DeviceWord; +#else + typedef unsigned short VolatileWord; + typedef unsigned short DeviceWord; +#endif + typedef unsigned short TextureWord; +}; + + +template struct UnitWord : UnitWord {}; +template struct UnitWord : UnitWord {}; +template struct UnitWord : UnitWord {}; + + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + + +/****************************************************************************** + * Vector type inference utilities. + ******************************************************************************/ + +/** + * \brief Exposes a member typedef \p Type that names the corresponding CUDA vector type if one exists. Otherwise \p Type refers to the CubVector structure itself, which will wrap the corresponding \p x, \p y, etc. vector fields. 
+ */ +template struct CubVector; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +enum +{ + /// The maximum number of elements in CUDA vector types + MAX_VEC_ELEMENTS = 4, +}; + + +/** + * Generic vector-1 type + */ +template +struct CubVector +{ + T x; + + typedef T BaseType; + typedef CubVector Type; +}; + +/** + * Generic vector-2 type + */ +template +struct CubVector +{ + T x; + T y; + + typedef T BaseType; + typedef CubVector Type; +}; + +/** + * Generic vector-3 type + */ +template +struct CubVector +{ + T x; + T y; + T z; + + typedef T BaseType; + typedef CubVector Type; +}; + +/** + * Generic vector-4 type + */ +template +struct CubVector +{ + T x; + T y; + T z; + T w; + + typedef T BaseType; + typedef CubVector Type; +}; + + +/** + * Macro for expanding partially-specialized built-in vector types + */ +#define CUB_DEFINE_VECTOR_TYPE(base_type,short_type) \ + \ + template<> struct CubVector : short_type##1 \ + { \ + typedef base_type BaseType; \ + typedef short_type##1 Type; \ + __host__ __device__ __forceinline__ CubVector operator+(const CubVector &other) const { \ + CubVector retval; \ + retval.x = x + other.x; \ + return retval; \ + } \ + __host__ __device__ __forceinline__ CubVector operator-(const CubVector &other) const { \ + CubVector retval; \ + retval.x = x - other.x; \ + return retval; \ + } \ + }; \ + \ + template<> struct CubVector : short_type##2 \ + { \ + typedef base_type BaseType; \ + typedef short_type##2 Type; \ + __host__ __device__ __forceinline__ CubVector operator+(const CubVector &other) const { \ + CubVector retval; \ + retval.x = x + other.x; \ + retval.y = y + other.y; \ + return retval; \ + } \ + __host__ __device__ __forceinline__ CubVector operator-(const CubVector &other) const { \ + CubVector retval; \ + retval.x = x - other.x; \ + retval.y = y - other.y; \ + return retval; \ + } \ + }; \ + \ + template<> struct CubVector : short_type##3 \ + { \ + typedef base_type BaseType; \ + typedef short_type##3 Type; \ + __host__ __device__ __forceinline__ CubVector operator+(const CubVector &other) const { \ + CubVector retval; \ + retval.x = x + other.x; \ + retval.y = y + other.y; \ + retval.z = z + other.z; \ + return retval; \ + } \ + __host__ __device__ __forceinline__ CubVector operator-(const CubVector &other) const { \ + CubVector retval; \ + retval.x = x - other.x; \ + retval.y = y - other.y; \ + retval.z = z - other.z; \ + return retval; \ + } \ + }; \ + \ + template<> struct CubVector : short_type##4 \ + { \ + typedef base_type BaseType; \ + typedef short_type##4 Type; \ + __host__ __device__ __forceinline__ CubVector operator+(const CubVector &other) const { \ + CubVector retval; \ + retval.x = x + other.x; \ + retval.y = y + other.y; \ + retval.z = z + other.z; \ + retval.w = w + other.w; \ + return retval; \ + } \ + __host__ __device__ __forceinline__ CubVector operator-(const CubVector &other) const { \ + CubVector retval; \ + retval.x = x - other.x; \ + retval.y = y - other.y; \ + retval.z = z - other.z; \ + retval.w = w - other.w; \ + return retval; \ + } \ + }; + + + +// Expand CUDA vector types for built-in primitives +CUB_DEFINE_VECTOR_TYPE(char, char) +CUB_DEFINE_VECTOR_TYPE(signed char, char) +CUB_DEFINE_VECTOR_TYPE(short, short) +CUB_DEFINE_VECTOR_TYPE(int, int) +CUB_DEFINE_VECTOR_TYPE(long, long) +CUB_DEFINE_VECTOR_TYPE(long long, longlong) +CUB_DEFINE_VECTOR_TYPE(unsigned char, uchar) +CUB_DEFINE_VECTOR_TYPE(unsigned short, ushort) +CUB_DEFINE_VECTOR_TYPE(unsigned int, uint) +CUB_DEFINE_VECTOR_TYPE(unsigned long, ulong) 
+CUB_DEFINE_VECTOR_TYPE(unsigned long long, ulonglong) +CUB_DEFINE_VECTOR_TYPE(float, float) +CUB_DEFINE_VECTOR_TYPE(double, double) +CUB_DEFINE_VECTOR_TYPE(bool, uchar) + +// Undefine macros +#undef CUB_DEFINE_VECTOR_TYPE + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + + +/****************************************************************************** + * Wrapper types + ******************************************************************************/ + +/** + * \brief A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions + */ +template +struct Uninitialized +{ + /// Biggest memory-access word that T is a whole multiple of and is not larger than the alignment of T + typedef typename UnitWord::DeviceWord DeviceWord; + + enum + { + WORDS = sizeof(T) / sizeof(DeviceWord) + }; + + /// Backing storage + DeviceWord storage[WORDS]; + + /// Alias + __host__ __device__ __forceinline__ T& Alias() + { + return reinterpret_cast(*this); + } +}; + + +/** + * \brief A key identifier paired with a corresponding value + */ +template < + typename _Key, + typename _Value +#if defined(_WIN32) && !defined(_WIN64) + , bool KeyIsLT = (AlignBytes<_Key>::ALIGN_BYTES < AlignBytes<_Value>::ALIGN_BYTES) + , bool ValIsLT = (AlignBytes<_Value>::ALIGN_BYTES < AlignBytes<_Key>::ALIGN_BYTES) +#endif // #if defined(_WIN32) && !defined(_WIN64) + > +struct KeyValuePair +{ + typedef _Key Key; ///< Key data type + typedef _Value Value; ///< Value data type + + Key key; ///< Item key + Value value; ///< Item value + + /// Constructor + __host__ __device__ __forceinline__ + KeyValuePair() {} + + /// Constructor + __host__ __device__ __forceinline__ + KeyValuePair(Key const& key, Value const& value) : key(key), value(value) {} + + /// Inequality operator + __host__ __device__ __forceinline__ bool operator !=(const KeyValuePair &b) + { + return (value != b.value) || (key != b.key); + } +}; + +#if defined(_WIN32) && !defined(_WIN64) + +/** + * Win32 won't do 16B alignment. 
This can present two problems for + * should-be-16B-aligned (but actually 8B aligned) built-in and intrinsics members: + * 1) If a smaller-aligned item were to be listed first, the host compiler places the + * should-be-16B item at too early an offset (and disagrees with device compiler) + * 2) Or, if a smaller-aligned item lists second, the host compiler gets the size + * of the struct wrong (and disagrees with device compiler) + * + * So we put the larger-should-be-aligned item first, and explicitly pad the + * end of the struct + */ + +/// Smaller key specialization +template +struct KeyValuePair +{ + typedef K Key; + typedef V Value; + + typedef char Pad[AlignBytes::ALIGN_BYTES - AlignBytes::ALIGN_BYTES]; + + Value value; // Value has larger would-be alignment and goes first + Key key; + Pad pad; + + /// Constructor + __host__ __device__ __forceinline__ + KeyValuePair() {} + + /// Constructor + __host__ __device__ __forceinline__ + KeyValuePair(Key const& key, Value const& value) : key(key), value(value) {} + + /// Inequality operator + __host__ __device__ __forceinline__ bool operator !=(const KeyValuePair &b) + { + return (value != b.value) || (key != b.key); + } +}; + + +/// Smaller value specialization +template +struct KeyValuePair +{ + typedef K Key; + typedef V Value; + + typedef char Pad[AlignBytes::ALIGN_BYTES - AlignBytes::ALIGN_BYTES]; + + Key key; // Key has larger would-be alignment and goes first + Value value; + Pad pad; + + /// Constructor + __host__ __device__ __forceinline__ + KeyValuePair() {} + + /// Constructor + __host__ __device__ __forceinline__ + KeyValuePair(Key const& key, Value const& value) : key(key), value(value) {} + + /// Inequality operator + __host__ __device__ __forceinline__ bool operator !=(const KeyValuePair &b) + { + return (value != b.value) || (key != b.key); + } +}; + +#endif // #if defined(_WIN32) && !defined(_WIN64) + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + +/** + * \brief A wrapper for passing simple static arrays as kernel parameters + */ +template +struct ArrayWrapper +{ + + /// Statically-sized array of type \p T + T array[COUNT]; + + /// Constructor + __host__ __device__ __forceinline__ ArrayWrapper() {} +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +/** + * \brief Double-buffer storage wrapper for multi-pass stream transformations that require more than one storage array for streaming intermediate results back and forth. + * + * Many multi-pass computations require a pair of "ping-pong" storage + * buffers (e.g., one for reading from and the other for writing to, and then + * vice-versa for the subsequent pass). This structure wraps a set of device + * buffers and a "selector" member to track which is "current". 
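// Editorial aside (not part of the vendored CUB header): the "ping-pong"
// pattern described above, for the DoubleBuffer defined just below, typically
// looks like the following sketch. d_in/d_out are assumed to be
// caller-allocated device buffers of equal size; the pass body is elided.
//
//   cub::DoubleBuffer<int> keys(d_in, d_out);
//   for (int pass = 0; pass < num_passes; ++pass)
//   {
//       // read from keys.Current(), write results into keys.Alternate() ...
//       keys.selector ^= 1;              // swap the roles for the next pass
//   }
//   int *d_result = keys.Current();      // final output lives here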
+ */ +template +struct DoubleBuffer +{ + /// Pair of device buffer pointers + T *d_buffers[2]; + + /// Selector into \p d_buffers (i.e., the active/valid buffer) + int selector; + + /// \brief Constructor + __host__ __device__ __forceinline__ DoubleBuffer() + { + selector = 0; + d_buffers[0] = NULL; + d_buffers[1] = NULL; + } + + /// \brief Constructor + __host__ __device__ __forceinline__ DoubleBuffer( + T *d_current, ///< The currently valid buffer + T *d_alternate) ///< Alternate storage buffer of the same size as \p d_current + { + selector = 0; + d_buffers[0] = d_current; + d_buffers[1] = d_alternate; + } + + /// \brief Return pointer to the currently valid buffer + __host__ __device__ __forceinline__ T* Current() { return d_buffers[selector]; } + + /// \brief Return pointer to the currently invalid buffer + __host__ __device__ __forceinline__ T* Alternate() { return d_buffers[selector ^ 1]; } + +}; + + + +/****************************************************************************** + * Typedef-detection + ******************************************************************************/ + + +/** + * \brief Defines a structure \p detector_name that is templated on type \p T. The \p detector_name struct exposes a constant member \p VALUE indicating whether or not parameter \p T exposes a nested type \p nested_type_name + */ +#define CUB_DEFINE_DETECT_NESTED_TYPE(detector_name, nested_type_name) \ + template \ + struct detector_name \ + { \ + template \ + static char& test(typename C::nested_type_name*); \ + template \ + static int& test(...); \ + enum \ + { \ + VALUE = sizeof(test(0)) < sizeof(int) \ + }; \ + }; + + + +/****************************************************************************** + * Simple enable-if (similar to Boost) + ******************************************************************************/ + +/** + * \brief Simple enable-if (similar to Boost) + */ +template +struct EnableIf +{ + /// Enable-if type for SFINAE dummy variables + typedef T Type; +}; + + +template +struct EnableIf {}; + + + +/****************************************************************************** + * Typedef-detection + ******************************************************************************/ + +/** + * \brief Determine whether or not BinaryOp's functor is of the form bool operator()(const T& a, const T&b) or bool operator()(const T& a, const T&b, unsigned int idx) + */ +template +struct BinaryOpHasIdxParam +{ +private: +/* + template struct SFINAE1 {}; + template struct SFINAE2 {}; + template struct SFINAE3 {}; + template struct SFINAE4 {}; +*/ + template struct SFINAE5 {}; + template struct SFINAE6 {}; + template struct SFINAE7 {}; + template struct SFINAE8 {}; +/* + template static char Test(SFINAE1 *); + template static char Test(SFINAE2 *); + template static char Test(SFINAE3 *); + template static char Test(SFINAE4 *); +*/ + template __host__ __device__ static char Test(SFINAE5 *); + template __host__ __device__ static char Test(SFINAE6 *); + template __host__ __device__ static char Test(SFINAE7 *); + template __host__ __device__ static char Test(SFINAE8 *); + + template static int Test(...); + +public: + + /// Whether the functor BinaryOp has a third unsigned int index param + static const bool HAS_PARAM = sizeof(Test(NULL)) == sizeof(char); +}; + + + + +/****************************************************************************** + * Simple type traits utilities. 
+ * + * For example: + * Traits::CATEGORY // SIGNED_INTEGER + * Traits::NULL_TYPE // true + * Traits::CATEGORY // NOT_A_NUMBER + * Traits::PRIMITIVE; // false + * + ******************************************************************************/ + +/** + * \brief Basic type traits categories + */ +enum Category +{ + NOT_A_NUMBER, + SIGNED_INTEGER, + UNSIGNED_INTEGER, + FLOATING_POINT +}; + + +/** + * \brief Basic type traits + */ +template +struct BaseTraits +{ + /// Category + static const Category CATEGORY = _CATEGORY; + enum + { + PRIMITIVE = _PRIMITIVE, + NULL_TYPE = _NULL_TYPE, + }; +}; + + +/** + * Basic type traits (unsigned primitive specialization) + */ +template +struct BaseTraits +{ + typedef _UnsignedBits UnsignedBits; + + static const Category CATEGORY = UNSIGNED_INTEGER; + static const UnsignedBits LOWEST_KEY = UnsignedBits(0); + static const UnsignedBits MAX_KEY = UnsignedBits(-1); + + enum + { + PRIMITIVE = true, + NULL_TYPE = false, + }; + + + static __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) + { + return key; + } + + static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) + { + return key; + } + + static __host__ __device__ __forceinline__ T Max() + { + UnsignedBits retval = MAX_KEY; + return reinterpret_cast(retval); + } + + static __host__ __device__ __forceinline__ T Lowest() + { + UnsignedBits retval = LOWEST_KEY; + return reinterpret_cast(retval); + } +}; + + +/** + * Basic type traits (signed primitive specialization) + */ +template +struct BaseTraits +{ + typedef _UnsignedBits UnsignedBits; + + static const Category CATEGORY = SIGNED_INTEGER; + static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); + static const UnsignedBits LOWEST_KEY = HIGH_BIT; + static const UnsignedBits MAX_KEY = UnsignedBits(-1) ^ HIGH_BIT; + + enum + { + PRIMITIVE = true, + NULL_TYPE = false, + }; + + static __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) + { + return key ^ HIGH_BIT; + }; + + static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) + { + return key ^ HIGH_BIT; + }; + + static __host__ __device__ __forceinline__ T Max() + { + UnsignedBits retval = MAX_KEY; + return reinterpret_cast(retval); + } + + static __host__ __device__ __forceinline__ T Lowest() + { + UnsignedBits retval = LOWEST_KEY; + return reinterpret_cast(retval); + } +}; + +template +struct FpLimits; + +template <> +struct FpLimits +{ + static __host__ __device__ __forceinline__ float Max() { + return FLT_MAX; + } + + static __host__ __device__ __forceinline__ float Lowest() { + return FLT_MAX * float(-1); + } +}; + +template <> +struct FpLimits +{ + static __host__ __device__ __forceinline__ double Max() { + return DBL_MAX; + } + + static __host__ __device__ __forceinline__ double Lowest() { + return DBL_MAX * double(-1); + } +}; + + +#if (__CUDACC_VER_MAJOR__ >= 9) +template <> +struct FpLimits<__half> +{ + static __host__ __device__ __forceinline__ __half Max() { + unsigned short max_word = 0x7BFF; + return reinterpret_cast<__half&>(max_word); + } + + static __host__ __device__ __forceinline__ __half Lowest() { + unsigned short lowest_word = 0xFBFF; + return reinterpret_cast<__half&>(lowest_word); + } +}; +#endif + + +/** + * Basic type traits (fp primitive specialization) + */ +template +struct BaseTraits +{ + typedef _UnsignedBits UnsignedBits; + + static const Category CATEGORY = FLOATING_POINT; + static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); + 
static const UnsignedBits LOWEST_KEY = UnsignedBits(-1); + static const UnsignedBits MAX_KEY = UnsignedBits(-1) ^ HIGH_BIT; + + enum + { + PRIMITIVE = true, + NULL_TYPE = false, + }; + + static __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) + { + UnsignedBits mask = (key & HIGH_BIT) ? UnsignedBits(-1) : HIGH_BIT; + return key ^ mask; + }; + + static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) + { + UnsignedBits mask = (key & HIGH_BIT) ? HIGH_BIT : UnsignedBits(-1); + return key ^ mask; + }; + + static __host__ __device__ __forceinline__ T Max() { + return FpLimits::Max(); + } + + static __host__ __device__ __forceinline__ T Lowest() { + return FpLimits::Lowest(); + } +}; + + +/** + * \brief Numeric type traits + */ +template struct NumericTraits : BaseTraits {}; + +template <> struct NumericTraits : BaseTraits {}; + +template <> struct NumericTraits : BaseTraits<(std::numeric_limits::is_signed) ? SIGNED_INTEGER : UNSIGNED_INTEGER, true, false, unsigned char, char> {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; + +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; + +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +#if (__CUDACC_VER_MAJOR__ >= 9) + template <> struct NumericTraits<__half> : BaseTraits {}; +#endif + +template <> struct NumericTraits : BaseTraits::VolatileWord, bool> {}; + + + +/** + * \brief Type traits + */ +template +struct Traits : NumericTraits::Type> {}; + + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + +/** @} */ // end group UtilModule + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/warp/specializations/warp_reduce_shfl.cuh b/fastertransformer/cuda/cub/warp/specializations/warp_reduce_shfl.cuh new file mode 100644 index 000000000..bbbf37e5c --- /dev/null +++ b/fastertransformer/cuda/cub/warp/specializations/warp_reduce_shfl.cuh @@ -0,0 +1,541 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::WarpReduceShfl provides SHFL-based variants of parallel reduction of items partitioned across a CUDA thread warp. + */ + +#pragma once + +#include "../../thread/thread_operators.cuh" +#include "../../util_ptx.cuh" +#include "../../util_type.cuh" +#include "../../util_macro.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \brief WarpReduceShfl provides SHFL-based variants of parallel reduction of items partitioned across a CUDA thread warp. + * + * LOGICAL_WARP_THREADS must be a power-of-two + */ +template < + typename T, ///< Data type being reduced + int LOGICAL_WARP_THREADS, ///< Number of threads per logical warp + int PTX_ARCH> ///< The PTX compute capability for which to to specialize this collective +struct WarpReduceShfl +{ + //--------------------------------------------------------------------- + // Constants and type definitions + //--------------------------------------------------------------------- + + enum + { + /// Whether the logical warp size and the PTX warp size coincide + IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)), + + /// The number of warp reduction steps + STEPS = Log2::VALUE, + + /// Number of logical warps in a PTX warp + LOGICAL_WARPS = CUB_WARP_THREADS(PTX_ARCH) / LOGICAL_WARP_THREADS, + + /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up + SHFL_C = (CUB_WARP_THREADS(PTX_ARCH) - LOGICAL_WARP_THREADS) << 8 + + }; + + template + struct IsInteger + { + enum { + ///Whether the data type is a small (32b or less) integer for which we can use a single SFHL instruction per exchange + IS_SMALL_UNSIGNED = (Traits::CATEGORY == UNSIGNED_INTEGER) && (sizeof(S) <= sizeof(unsigned int)) + }; + }; + + + /// Shared memory storage layout type + typedef NullType TempStorage; + + + //--------------------------------------------------------------------- + // Thread fields + //--------------------------------------------------------------------- + + /// Lane index in logical warp + unsigned int lane_id; + + /// Logical warp index in 32-thread physical warp + unsigned int warp_id; + + /// 32-thread physical warp member mask of logical warp + unsigned int member_mask; + + + //--------------------------------------------------------------------- + // Construction + //--------------------------------------------------------------------- + + /// Constructor + __device__ __forceinline__ WarpReduceShfl( + TempStorage &/*temp_storage*/) + { + lane_id = LaneId(); + warp_id = 0; + member_mask = 0xffffffffu >> (CUB_WARP_THREADS(PTX_ARCH) - LOGICAL_WARP_THREADS); + + if (!IS_ARCH_WARP) + { + warp_id = lane_id / LOGICAL_WARP_THREADS; + lane_id = lane_id % LOGICAL_WARP_THREADS; + member_mask = member_mask << (warp_id * LOGICAL_WARP_THREADS); + } + } + + + 
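    // Editorial aside (not part of the vendored CUB source): the constructor
    // above fully determines the logical-warp decomposition. As a worked
    // example, assume LOGICAL_WARP_THREADS = 8 on a 32-thread physical warp
    // and consider physical lane 21:
    //
    //   member_mask (before split) = 0xffffffffu >> (32 - 8)  = 0x000000ff
    //   warp_id                    = 21 / 8                   = 2
    //   lane_id                    = 21 % 8                   = 5
    //   member_mask (after shift)  = 0x000000ff << (2 * 8)    = 0x00ff0000
    //
    // i.e. the four logical warps own the byte-wide lane masks 0x000000ff,
    // 0x0000ff00, 0x00ff0000 and 0xff000000, and every shuffle below is issued
    // only against the owning mask.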
//--------------------------------------------------------------------- + // Reduction steps + //--------------------------------------------------------------------- + + /// Reduction (specialized for summation across uint32 types) + __device__ __forceinline__ unsigned int ReduceStep( + unsigned int input, ///< [in] Calling thread's input item. + cub::Sum /*reduction_op*/, ///< [in] Binary reduction operator + int last_lane, ///< [in] Index of last lane in segment + int offset) ///< [in] Up-offset to pull from + { + unsigned int output; + int shfl_c = last_lane | SHFL_C; // Shuffle control (mask and last_lane) + + // Use predicate set from SHFL to guard against invalid peers +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile( + "{" + " .reg .u32 r0;" + " .reg .pred p;" + " shfl.sync.down.b32 r0|p, %1, %2, %3, %5;" + " @p add.u32 r0, r0, %4;" + " mov.u32 %0, r0;" + "}" + : "=r"(output) : "r"(input), "r"(offset), "r"(shfl_c), "r"(input), "r"(member_mask)); +#else + asm volatile( + "{" + " .reg .u32 r0;" + " .reg .pred p;" + " shfl.down.b32 r0|p, %1, %2, %3;" + " @p add.u32 r0, r0, %4;" + " mov.u32 %0, r0;" + "}" + : "=r"(output) : "r"(input), "r"(offset), "r"(shfl_c), "r"(input)); +#endif + + return output; + } + + + /// Reduction (specialized for summation across fp32 types) + __device__ __forceinline__ float ReduceStep( + float input, ///< [in] Calling thread's input item. + cub::Sum /*reduction_op*/, ///< [in] Binary reduction operator + int last_lane, ///< [in] Index of last lane in segment + int offset) ///< [in] Up-offset to pull from + { + float output; + int shfl_c = last_lane | SHFL_C; // Shuffle control (mask and last_lane) + + // Use predicate set from SHFL to guard against invalid peers +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile( + "{" + " .reg .f32 r0;" + " .reg .pred p;" + " shfl.sync.down.b32 r0|p, %1, %2, %3, %5;" + " @p add.f32 r0, r0, %4;" + " mov.f32 %0, r0;" + "}" + : "=f"(output) : "f"(input), "r"(offset), "r"(shfl_c), "f"(input), "r"(member_mask)); +#else + asm volatile( + "{" + " .reg .f32 r0;" + " .reg .pred p;" + " shfl.down.b32 r0|p, %1, %2, %3;" + " @p add.f32 r0, r0, %4;" + " mov.f32 %0, r0;" + "}" + : "=f"(output) : "f"(input), "r"(offset), "r"(shfl_c), "f"(input)); +#endif + + return output; + } + + + /// Reduction (specialized for summation across unsigned long long types) + __device__ __forceinline__ unsigned long long ReduceStep( + unsigned long long input, ///< [in] Calling thread's input item. 
+ cub::Sum /*reduction_op*/, ///< [in] Binary reduction operator + int last_lane, ///< [in] Index of last lane in segment + int offset) ///< [in] Up-offset to pull from + { + unsigned long long output; + int shfl_c = last_lane | SHFL_C; // Shuffle control (mask and last_lane) + +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile( + "{" + " .reg .u32 lo;" + " .reg .u32 hi;" + " .reg .pred p;" + " mov.b64 {lo, hi}, %1;" + " shfl.sync.down.b32 lo|p, lo, %2, %3, %4;" + " shfl.sync.down.b32 hi|p, hi, %2, %3, %4;" + " mov.b64 %0, {lo, hi};" + " @p add.u64 %0, %0, %1;" + "}" + : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c), "r"(member_mask)); +#else + asm volatile( + "{" + " .reg .u32 lo;" + " .reg .u32 hi;" + " .reg .pred p;" + " mov.b64 {lo, hi}, %1;" + " shfl.down.b32 lo|p, lo, %2, %3;" + " shfl.down.b32 hi|p, hi, %2, %3;" + " mov.b64 %0, {lo, hi};" + " @p add.u64 %0, %0, %1;" + "}" + : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c)); +#endif + + return output; + } + + + /// Reduction (specialized for summation across long long types) + __device__ __forceinline__ long long ReduceStep( + long long input, ///< [in] Calling thread's input item. + cub::Sum /*reduction_op*/, ///< [in] Binary reduction operator + int last_lane, ///< [in] Index of last lane in segment + int offset) ///< [in] Up-offset to pull from + { + long long output; + int shfl_c = last_lane | SHFL_C; // Shuffle control (mask and last_lane) + + // Use predicate set from SHFL to guard against invalid peers +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile( + "{" + " .reg .u32 lo;" + " .reg .u32 hi;" + " .reg .pred p;" + " mov.b64 {lo, hi}, %1;" + " shfl.sync.down.b32 lo|p, lo, %2, %3, %4;" + " shfl.sync.down.b32 hi|p, hi, %2, %3, %4;" + " mov.b64 %0, {lo, hi};" + " @p add.s64 %0, %0, %1;" + "}" + : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c), "r"(member_mask)); +#else + asm volatile( + "{" + " .reg .u32 lo;" + " .reg .u32 hi;" + " .reg .pred p;" + " mov.b64 {lo, hi}, %1;" + " shfl.down.b32 lo|p, lo, %2, %3;" + " shfl.down.b32 hi|p, hi, %2, %3;" + " mov.b64 %0, {lo, hi};" + " @p add.s64 %0, %0, %1;" + "}" + : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c)); +#endif + + return output; + } + + + /// Reduction (specialized for summation across double types) + __device__ __forceinline__ double ReduceStep( + double input, ///< [in] Calling thread's input item. 
+ cub::Sum /*reduction_op*/, ///< [in] Binary reduction operator + int last_lane, ///< [in] Index of last lane in segment + int offset) ///< [in] Up-offset to pull from + { + double output; + int shfl_c = last_lane | SHFL_C; // Shuffle control (mask and last_lane) + + // Use predicate set from SHFL to guard against invalid peers +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile( + "{" + " .reg .u32 lo;" + " .reg .u32 hi;" + " .reg .pred p;" + " .reg .f64 r0;" + " mov.b64 %0, %1;" + " mov.b64 {lo, hi}, %1;" + " shfl.sync.down.b32 lo|p, lo, %2, %3, %4;" + " shfl.sync.down.b32 hi|p, hi, %2, %3, %4;" + " mov.b64 r0, {lo, hi};" + " @p add.f64 %0, %0, r0;" + "}" + : "=d"(output) : "d"(input), "r"(offset), "r"(shfl_c), "r"(member_mask)); +#else + asm volatile( + "{" + " .reg .u32 lo;" + " .reg .u32 hi;" + " .reg .pred p;" + " .reg .f64 r0;" + " mov.b64 %0, %1;" + " mov.b64 {lo, hi}, %1;" + " shfl.down.b32 lo|p, lo, %2, %3;" + " shfl.down.b32 hi|p, hi, %2, %3;" + " mov.b64 r0, {lo, hi};" + " @p add.f64 %0, %0, r0;" + "}" + : "=d"(output) : "d"(input), "r"(offset), "r"(shfl_c)); +#endif + + return output; + } + + + /// Reduction (specialized for swizzled ReduceByKeyOp across KeyValuePair types) + template + __device__ __forceinline__ KeyValuePair ReduceStep( + KeyValuePair input, ///< [in] Calling thread's input item. + SwizzleScanOp > /*reduction_op*/, ///< [in] Binary reduction operator + int last_lane, ///< [in] Index of last lane in segment + int offset) ///< [in] Up-offset to pull from + { + KeyValuePair output; + + KeyT other_key = ShuffleDown(input.key, offset, last_lane, member_mask); + + output.key = input.key; + output.value = ReduceStep( + input.value, + cub::Sum(), + last_lane, + offset, + Int2Type::IS_SMALL_UNSIGNED>()); + + if (input.key != other_key) + output.value = input.value; + + return output; + } + + + + /// Reduction (specialized for swizzled ReduceBySegmentOp across KeyValuePair types) + template + __device__ __forceinline__ KeyValuePair ReduceStep( + KeyValuePair input, ///< [in] Calling thread's input item. + SwizzleScanOp > /*reduction_op*/, ///< [in] Binary reduction operator + int last_lane, ///< [in] Index of last lane in segment + int offset) ///< [in] Up-offset to pull from + { + KeyValuePair output; + + output.value = ReduceStep(input.value, cub::Sum(), last_lane, offset, Int2Type::IS_SMALL_UNSIGNED>()); + output.key = ReduceStep(input.key, cub::Sum(), last_lane, offset, Int2Type::IS_SMALL_UNSIGNED>()); + + if (input.key > 0) + output.value = input.value; + + return output; + } + + + /// Reduction step (generic) + template + __device__ __forceinline__ _T ReduceStep( + _T input, ///< [in] Calling thread's input item. + ReductionOp reduction_op, ///< [in] Binary reduction operator + int last_lane, ///< [in] Index of last lane in segment + int offset) ///< [in] Up-offset to pull from + { + _T output = input; + + _T temp = ShuffleDown(output, offset, last_lane, member_mask); + + // Perform reduction op if valid + if (offset + lane_id <= last_lane) + output = reduction_op(input, temp); + + return output; + } + + + /// Reduction step (specialized for small unsigned integers size 32b or less) + template + __device__ __forceinline__ _T ReduceStep( + _T input, ///< [in] Calling thread's input item. 
+ ReductionOp reduction_op, ///< [in] Binary reduction operator + int last_lane, ///< [in] Index of last lane in segment + int offset, ///< [in] Up-offset to pull from + Int2Type /*is_small_unsigned*/) ///< [in] Marker type indicating whether T is a small unsigned integer + { + return ReduceStep(input, reduction_op, last_lane, offset); + } + + + /// Reduction step (specialized for types other than small unsigned integers size 32b or less) + template + __device__ __forceinline__ _T ReduceStep( + _T input, ///< [in] Calling thread's input item. + ReductionOp reduction_op, ///< [in] Binary reduction operator + int last_lane, ///< [in] Index of last lane in segment + int offset, ///< [in] Up-offset to pull from + Int2Type /*is_small_unsigned*/) ///< [in] Marker type indicating whether T is a small unsigned integer + { + return ReduceStep(input, reduction_op, last_lane, offset); + } + + + //--------------------------------------------------------------------- + // Templated inclusive scan iteration + //--------------------------------------------------------------------- + + template + __device__ __forceinline__ void ReduceStep( + T& input, ///< [in] Calling thread's input item. + ReductionOp reduction_op, ///< [in] Binary reduction operator + int last_lane, ///< [in] Index of last lane in segment + Int2Type /*step*/) + { + input = ReduceStep(input, reduction_op, last_lane, 1 << STEP, Int2Type::IS_SMALL_UNSIGNED>()); + + ReduceStep(input, reduction_op, last_lane, Int2Type()); + } + + template + __device__ __forceinline__ void ReduceStep( + T& /*input*/, ///< [in] Calling thread's input item. + ReductionOp /*reduction_op*/, ///< [in] Binary reduction operator + int /*last_lane*/, ///< [in] Index of last lane in segment + Int2Type /*step*/) + {} + + + //--------------------------------------------------------------------- + // Reduction operations + //--------------------------------------------------------------------- + + /// Reduction + template < + bool ALL_LANES_VALID, ///< Whether all lanes in each warp are contributing a valid fold of items + typename ReductionOp> + __device__ __forceinline__ T Reduce( + T input, ///< [in] Calling thread's input + int valid_items, ///< [in] Total number of valid items across the logical warp + ReductionOp reduction_op) ///< [in] Binary reduction operator + { + int last_lane = (ALL_LANES_VALID) ? + LOGICAL_WARP_THREADS - 1 : + valid_items - 1; + + T output = input; + +// // Iterate reduction steps +// #pragma unroll +// for (int STEP = 0; STEP < STEPS; STEP++) +// { +// output = ReduceStep(output, reduction_op, last_lane, 1 << STEP, Int2Type::IS_SMALL_UNSIGNED>()); +// } + + // Template-iterate reduction steps + ReduceStep(output, reduction_op, last_lane, Int2Type<0>()); + + return output; + } + + + /// Segmented reduction + template < + bool HEAD_SEGMENTED, ///< Whether flags indicate a segment-head or a segment-tail + typename FlagT, + typename ReductionOp> + __device__ __forceinline__ T SegmentedReduce( + T input, ///< [in] Calling thread's input + FlagT flag, ///< [in] Whether or not the current lane is a segment head/tail + ReductionOp reduction_op) ///< [in] Binary reduction operator + { + // Get the start flags for each thread in the warp. 
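+        // The segment boundary for each lane is found with bit tricks rather than a loop:
+        // a warp-wide ballot packs every lane's flag into one 32-bit mask, head flags are
+        // converted to tail flags when necessary, bits below the calling lane are discarded,
+        // and __brev + __clz return the index of the lowest remaining set bit, i.e. the last
+        // lane of this lane's segment. The shuffle-based reduction steps below then ignore
+        // contributions pulled from peers beyond that lane.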
+ int warp_flags = WARP_BALLOT(flag, member_mask); + + // Convert to tail-segmented + if (HEAD_SEGMENTED) + warp_flags >>= 1; + + // Mask out the bits below the current thread + warp_flags &= LaneMaskGe(); + + // Mask of physical lanes outside the logical warp and convert to logical lanemask + if (!IS_ARCH_WARP) + { + warp_flags = (warp_flags & member_mask) >> (warp_id * LOGICAL_WARP_THREADS); + } + + // Mask in the last lane of logical warp + warp_flags |= 1u << (LOGICAL_WARP_THREADS - 1); + + // Find the next set flag + int last_lane = __clz(__brev(warp_flags)); + + T output = input; + +// // Iterate reduction steps +// #pragma unroll +// for (int STEP = 0; STEP < STEPS; STEP++) +// { +// output = ReduceStep(output, reduction_op, last_lane, 1 << STEP, Int2Type::IS_SMALL_UNSIGNED>()); +// } + + // Template-iterate reduction steps + ReduceStep(output, reduction_op, last_lane, Int2Type<0>()); + + return output; + } +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/warp/specializations/warp_reduce_smem.cuh b/fastertransformer/cuda/cub/warp/specializations/warp_reduce_smem.cuh new file mode 100644 index 000000000..7baa573be --- /dev/null +++ b/fastertransformer/cuda/cub/warp/specializations/warp_reduce_smem.cuh @@ -0,0 +1,372 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::WarpReduceSmem provides smem-based variants of parallel reduction of items partitioned across a CUDA thread warp. 
+ */ + +#pragma once + +#include "../../thread/thread_operators.cuh" +#include "../../thread/thread_load.cuh" +#include "../../thread/thread_store.cuh" +#include "../../util_type.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief WarpReduceSmem provides smem-based variants of parallel reduction of items partitioned across a CUDA thread warp. + */ +template < + typename T, ///< Data type being reduced + int LOGICAL_WARP_THREADS, ///< Number of threads per logical warp + int PTX_ARCH> ///< The PTX compute capability for which to to specialize this collective +struct WarpReduceSmem +{ + /****************************************************************************** + * Constants and type definitions + ******************************************************************************/ + + enum + { + /// Whether the logical warp size and the PTX warp size coincide + IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)), + + /// Whether the logical warp size is a power-of-two + IS_POW_OF_TWO = PowerOfTwo::VALUE, + + /// The number of warp scan steps + STEPS = Log2::VALUE, + + /// The number of threads in half a warp + HALF_WARP_THREADS = 1 << (STEPS - 1), + + /// The number of shared memory elements per warp + WARP_SMEM_ELEMENTS = LOGICAL_WARP_THREADS + HALF_WARP_THREADS, + + /// FlagT status (when not using ballot) + UNSET = 0x0, // Is initially unset + SET = 0x1, // Is initially set + SEEN = 0x2, // Has seen another head flag from a successor peer + }; + + /// Shared memory flag type + typedef unsigned char SmemFlag; + + /// Shared memory storage layout type (1.5 warps-worth of elements for each warp) + struct _TempStorage + { + T reduce[WARP_SMEM_ELEMENTS]; + SmemFlag flags[WARP_SMEM_ELEMENTS]; + }; + + // Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + _TempStorage &temp_storage; + unsigned int lane_id; + unsigned int member_mask; + + + /****************************************************************************** + * Construction + ******************************************************************************/ + + /// Constructor + __device__ __forceinline__ WarpReduceSmem( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + + lane_id(IS_ARCH_WARP ? + LaneId() : + LaneId() % LOGICAL_WARP_THREADS), + + member_mask((0xffffffff >> (32 - LOGICAL_WARP_THREADS)) << ((IS_ARCH_WARP || !IS_POW_OF_TWO ) ? 
+ 0 : // arch-width and non-power-of-two subwarps cannot be tiled with the arch-warp + ((LaneId() / LOGICAL_WARP_THREADS) * LOGICAL_WARP_THREADS))) + {} + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + //--------------------------------------------------------------------- + // Regular reduction + //--------------------------------------------------------------------- + + /** + * Reduction step + */ + template < + bool ALL_LANES_VALID, ///< Whether all lanes in each warp are contributing a valid fold of items + typename ReductionOp, + int STEP> + __device__ __forceinline__ T ReduceStep( + T input, ///< [in] Calling thread's input + int valid_items, ///< [in] Total number of valid items across the logical warp + ReductionOp reduction_op, ///< [in] Reduction operator + Int2Type /*step*/) + { + const int OFFSET = 1 << STEP; + + // Share input through buffer + ThreadStore(&temp_storage.reduce[lane_id], input); + + WARP_SYNC(member_mask); + + // Update input if peer_addend is in range + if ((ALL_LANES_VALID && IS_POW_OF_TWO) || ((lane_id + OFFSET) < valid_items)) + { + T peer_addend = ThreadLoad(&temp_storage.reduce[lane_id + OFFSET]); + input = reduction_op(input, peer_addend); + } + + WARP_SYNC(member_mask); + + return ReduceStep(input, valid_items, reduction_op, Int2Type()); + } + + + /** + * Reduction step (terminate) + */ + template < + bool ALL_LANES_VALID, ///< Whether all lanes in each warp are contributing a valid fold of items + typename ReductionOp> + __device__ __forceinline__ T ReduceStep( + T input, ///< [in] Calling thread's input + int valid_items, ///< [in] Total number of valid items across the logical warp + ReductionOp /*reduction_op*/, ///< [in] Reduction operator + Int2Type /*step*/) + { + return input; + } + + + //--------------------------------------------------------------------- + // Segmented reduction + //--------------------------------------------------------------------- + + + /** + * Ballot-based segmented reduce + */ + template < + bool HEAD_SEGMENTED, ///< Whether flags indicate a segment-head or a segment-tail + typename FlagT, + typename ReductionOp> + __device__ __forceinline__ T SegmentedReduce( + T input, ///< [in] Calling thread's input + FlagT flag, ///< [in] Whether or not the current lane is a segment head/tail + ReductionOp reduction_op, ///< [in] Reduction operator + Int2Type /*has_ballot*/) ///< [in] Marker type for whether the target arch has ballot functionality + { + // Get the start flags for each thread in the warp. + int warp_flags = WARP_BALLOT(flag, member_mask); + + if (!HEAD_SEGMENTED) + warp_flags <<= 1; + + // Keep bits above the current thread. 
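+        // Each lane only folds in values from higher lanes that belong to its own segment:
+        // keeping just the ballot bits strictly above this lane and locating the lowest one
+        // with __brev + __clz yields the first lane past the end of this lane's segment, and
+        // the shared-memory exchange steps below only accept a peer's value when that peer
+        // sits before this boundary.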
+ warp_flags &= LaneMaskGt(); + + // Accommodate packing of multiple logical warps in a single physical warp + if (!IS_ARCH_WARP) + { + warp_flags >>= (LaneId() / LOGICAL_WARP_THREADS) * LOGICAL_WARP_THREADS; + } + + // Find next flag + int next_flag = __clz(__brev(warp_flags)); + + // Clip the next segment at the warp boundary if necessary + if (LOGICAL_WARP_THREADS != 32) + next_flag = CUB_MIN(next_flag, LOGICAL_WARP_THREADS); + + #pragma unroll + for (int STEP = 0; STEP < STEPS; STEP++) + { + const int OFFSET = 1 << STEP; + + // Share input into buffer + ThreadStore(&temp_storage.reduce[lane_id], input); + + WARP_SYNC(member_mask); + + // Update input if peer_addend is in range + if (OFFSET + lane_id < next_flag) + { + T peer_addend = ThreadLoad(&temp_storage.reduce[lane_id + OFFSET]); + input = reduction_op(input, peer_addend); + } + + WARP_SYNC(member_mask); + } + + return input; + } + + + /** + * Smem-based segmented reduce + */ + template < + bool HEAD_SEGMENTED, ///< Whether flags indicate a segment-head or a segment-tail + typename FlagT, + typename ReductionOp> + __device__ __forceinline__ T SegmentedReduce( + T input, ///< [in] Calling thread's input + FlagT flag, ///< [in] Whether or not the current lane is a segment head/tail + ReductionOp reduction_op, ///< [in] Reduction operator + Int2Type /*has_ballot*/) ///< [in] Marker type for whether the target arch has ballot functionality + { + enum + { + UNSET = 0x0, // Is initially unset + SET = 0x1, // Is initially set + SEEN = 0x2, // Has seen another head flag from a successor peer + }; + + // Alias flags onto shared data storage + volatile SmemFlag *flag_storage = temp_storage.flags; + + SmemFlag flag_status = (flag) ? SET : UNSET; + + for (int STEP = 0; STEP < STEPS; STEP++) + { + const int OFFSET = 1 << STEP; + + // Share input through buffer + ThreadStore(&temp_storage.reduce[lane_id], input); + + WARP_SYNC(member_mask); + + // Get peer from buffer + T peer_addend = ThreadLoad(&temp_storage.reduce[lane_id + OFFSET]); + + WARP_SYNC(member_mask); + + // Share flag through buffer + flag_storage[lane_id] = flag_status; + + // Get peer flag from buffer + SmemFlag peer_flag_status = flag_storage[lane_id + OFFSET]; + + // Update input if peer was in range + if (lane_id < LOGICAL_WARP_THREADS - OFFSET) + { + if (HEAD_SEGMENTED) + { + // Head-segmented + if ((flag_status & SEEN) == 0) + { + // Has not seen a more distant head flag + if (peer_flag_status & SET) + { + // Has now seen a head flag + flag_status |= SEEN; + } + else + { + // Peer is not a head flag: grab its count + input = reduction_op(input, peer_addend); + } + + // Update seen status to include that of peer + flag_status |= (peer_flag_status & SEEN); + } + } + else + { + // Tail-segmented. 
Simply propagate flag status + if (!flag_status) + { + input = reduction_op(input, peer_addend); + flag_status |= peer_flag_status; + } + + } + } + } + + return input; + } + + + /****************************************************************************** + * Interface + ******************************************************************************/ + + /** + * Reduction + */ + template < + bool ALL_LANES_VALID, ///< Whether all lanes in each warp are contributing a valid fold of items + typename ReductionOp> + __device__ __forceinline__ T Reduce( + T input, ///< [in] Calling thread's input + int valid_items, ///< [in] Total number of valid items across the logical warp + ReductionOp reduction_op) ///< [in] Reduction operator + { + return ReduceStep(input, valid_items, reduction_op, Int2Type<0>()); + } + + + /** + * Segmented reduction + */ + template < + bool HEAD_SEGMENTED, ///< Whether flags indicate a segment-head or a segment-tail + typename FlagT, + typename ReductionOp> + __device__ __forceinline__ T SegmentedReduce( + T input, ///< [in] Calling thread's input + FlagT flag, ///< [in] Whether or not the current lane is a segment head/tail + ReductionOp reduction_op) ///< [in] Reduction operator + { + return SegmentedReduce(input, flag, reduction_op, Int2Type<(PTX_ARCH >= 200)>()); + } + + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/warp/specializations/warp_scan_shfl.cuh b/fastertransformer/cuda/cub/warp/specializations/warp_scan_shfl.cuh new file mode 100644 index 000000000..7f4e1c94b --- /dev/null +++ b/fastertransformer/cuda/cub/warp/specializations/warp_scan_shfl.cuh @@ -0,0 +1,632 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * \file + * cub::WarpScanShfl provides SHFL-based variants of parallel prefix scan of items partitioned across a CUDA thread warp. + */ + +#pragma once + +#include "../../thread/thread_operators.cuh" +#include "../../util_type.cuh" +#include "../../util_ptx.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief WarpScanShfl provides SHFL-based variants of parallel prefix scan of items partitioned across a CUDA thread warp. + * + * LOGICAL_WARP_THREADS must be a power-of-two + */ +template < + typename T, ///< Data type being scanned + int LOGICAL_WARP_THREADS, ///< Number of threads per logical warp + int PTX_ARCH> ///< The PTX compute capability for which to to specialize this collective +struct WarpScanShfl +{ + //--------------------------------------------------------------------- + // Constants and type definitions + //--------------------------------------------------------------------- + + enum + { + /// Whether the logical warp size and the PTX warp size coincide + IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)), + + /// The number of warp scan steps + STEPS = Log2::VALUE, + + /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up + SHFL_C = (CUB_WARP_THREADS(PTX_ARCH) - LOGICAL_WARP_THREADS) << 8 + }; + + template + struct IntegerTraits + { + enum { + ///Whether the data type is a small (32b or less) integer for which we can use a single SFHL instruction per exchange + IS_SMALL_UNSIGNED = (Traits::CATEGORY == UNSIGNED_INTEGER) && (sizeof(S) <= sizeof(unsigned int)) + }; + }; + + /// Shared memory storage layout type + struct TempStorage {}; + + + //--------------------------------------------------------------------- + // Thread fields + //--------------------------------------------------------------------- + + /// Lane index in logical warp + unsigned int lane_id; + + /// Logical warp index in 32-thread physical warp + unsigned int warp_id; + + /// 32-thread physical warp member mask of logical warp + unsigned int member_mask; + + //--------------------------------------------------------------------- + // Construction + //--------------------------------------------------------------------- + + /// Constructor + __device__ __forceinline__ WarpScanShfl( + TempStorage &/*temp_storage*/) + { + lane_id = LaneId(); + warp_id = 0; + member_mask = 0xffffffffu >> (CUB_WARP_THREADS(PTX_ARCH) - LOGICAL_WARP_THREADS); + + if (!IS_ARCH_WARP) + { + warp_id = lane_id / LOGICAL_WARP_THREADS; + lane_id = lane_id % LOGICAL_WARP_THREADS; + member_mask = member_mask << (warp_id * LOGICAL_WARP_THREADS); + } + } + + + //--------------------------------------------------------------------- + // Inclusive scan steps + //--------------------------------------------------------------------- + + /// Inclusive prefix scan step (specialized for summation across int32 types) + __device__ __forceinline__ int InclusiveScanStep( + int input, ///< [in] Calling thread's input item. 
+ cub::Sum /*scan_op*/, ///< [in] Binary scan operator + int first_lane, ///< [in] Index of first lane in segment + int offset) ///< [in] Up-offset to pull from + { + int output; + int shfl_c = first_lane | SHFL_C; // Shuffle control (mask and first-lane) + + // Use predicate set from SHFL to guard against invalid peers +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile( + "{" + " .reg .s32 r0;" + " .reg .pred p;" + " shfl.sync.up.b32 r0|p, %1, %2, %3, %5;" + " @p add.s32 r0, r0, %4;" + " mov.s32 %0, r0;" + "}" + : "=r"(output) : "r"(input), "r"(offset), "r"(shfl_c), "r"(input), "r"(member_mask)); +#else + asm volatile( + "{" + " .reg .s32 r0;" + " .reg .pred p;" + " shfl.up.b32 r0|p, %1, %2, %3;" + " @p add.s32 r0, r0, %4;" + " mov.s32 %0, r0;" + "}" + : "=r"(output) : "r"(input), "r"(offset), "r"(shfl_c), "r"(input)); +#endif + + return output; + } + + /// Inclusive prefix scan step (specialized for summation across uint32 types) + __device__ __forceinline__ unsigned int InclusiveScanStep( + unsigned int input, ///< [in] Calling thread's input item. + cub::Sum /*scan_op*/, ///< [in] Binary scan operator + int first_lane, ///< [in] Index of first lane in segment + int offset) ///< [in] Up-offset to pull from + { + unsigned int output; + int shfl_c = first_lane | SHFL_C; // Shuffle control (mask and first-lane) + + // Use predicate set from SHFL to guard against invalid peers +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile( + "{" + " .reg .u32 r0;" + " .reg .pred p;" + " shfl.sync.up.b32 r0|p, %1, %2, %3, %5;" + " @p add.u32 r0, r0, %4;" + " mov.u32 %0, r0;" + "}" + : "=r"(output) : "r"(input), "r"(offset), "r"(shfl_c), "r"(input), "r"(member_mask)); +#else + asm volatile( + "{" + " .reg .u32 r0;" + " .reg .pred p;" + " shfl.up.b32 r0|p, %1, %2, %3;" + " @p add.u32 r0, r0, %4;" + " mov.u32 %0, r0;" + "}" + : "=r"(output) : "r"(input), "r"(offset), "r"(shfl_c), "r"(input)); +#endif + + return output; + } + + + /// Inclusive prefix scan step (specialized for summation across fp32 types) + __device__ __forceinline__ float InclusiveScanStep( + float input, ///< [in] Calling thread's input item. + cub::Sum /*scan_op*/, ///< [in] Binary scan operator + int first_lane, ///< [in] Index of first lane in segment + int offset) ///< [in] Up-offset to pull from + { + float output; + int shfl_c = first_lane | SHFL_C; // Shuffle control (mask and first-lane) + + // Use predicate set from SHFL to guard against invalid peers +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile( + "{" + " .reg .f32 r0;" + " .reg .pred p;" + " shfl.sync.up.b32 r0|p, %1, %2, %3, %5;" + " @p add.f32 r0, r0, %4;" + " mov.f32 %0, r0;" + "}" + : "=f"(output) : "f"(input), "r"(offset), "r"(shfl_c), "f"(input), "r"(member_mask)); +#else + asm volatile( + "{" + " .reg .f32 r0;" + " .reg .pred p;" + " shfl.up.b32 r0|p, %1, %2, %3;" + " @p add.f32 r0, r0, %4;" + " mov.f32 %0, r0;" + "}" + : "=f"(output) : "f"(input), "r"(offset), "r"(shfl_c), "f"(input)); +#endif + + return output; + } + + + /// Inclusive prefix scan step (specialized for summation across unsigned long long types) + __device__ __forceinline__ unsigned long long InclusiveScanStep( + unsigned long long input, ///< [in] Calling thread's input item. 
+ cub::Sum /*scan_op*/, ///< [in] Binary scan operator + int first_lane, ///< [in] Index of first lane in segment + int offset) ///< [in] Up-offset to pull from + { + unsigned long long output; + int shfl_c = first_lane | SHFL_C; // Shuffle control (mask and first-lane) + + // Use predicate set from SHFL to guard against invalid peers +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile( + "{" + " .reg .u64 r0;" + " .reg .u32 lo;" + " .reg .u32 hi;" + " .reg .pred p;" + " mov.b64 {lo, hi}, %1;" + " shfl.sync.up.b32 lo|p, lo, %2, %3, %5;" + " shfl.sync.up.b32 hi|p, hi, %2, %3, %5;" + " mov.b64 r0, {lo, hi};" + " @p add.u64 r0, r0, %4;" + " mov.u64 %0, r0;" + "}" + : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c), "l"(input), "r"(member_mask)); +#else + asm volatile( + "{" + " .reg .u64 r0;" + " .reg .u32 lo;" + " .reg .u32 hi;" + " .reg .pred p;" + " mov.b64 {lo, hi}, %1;" + " shfl.up.b32 lo|p, lo, %2, %3;" + " shfl.up.b32 hi|p, hi, %2, %3;" + " mov.b64 r0, {lo, hi};" + " @p add.u64 r0, r0, %4;" + " mov.u64 %0, r0;" + "}" + : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c), "l"(input)); +#endif + + return output; + } + + + /// Inclusive prefix scan step (specialized for summation across long long types) + __device__ __forceinline__ long long InclusiveScanStep( + long long input, ///< [in] Calling thread's input item. + cub::Sum /*scan_op*/, ///< [in] Binary scan operator + int first_lane, ///< [in] Index of first lane in segment + int offset) ///< [in] Up-offset to pull from + { + long long output; + int shfl_c = first_lane | SHFL_C; // Shuffle control (mask and first-lane) + + // Use predicate set from SHFL to guard against invalid peers +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile( + "{" + " .reg .s64 r0;" + " .reg .u32 lo;" + " .reg .u32 hi;" + " .reg .pred p;" + " mov.b64 {lo, hi}, %1;" + " shfl.sync.up.b32 lo|p, lo, %2, %3, %5;" + " shfl.sync.up.b32 hi|p, hi, %2, %3, %5;" + " mov.b64 r0, {lo, hi};" + " @p add.s64 r0, r0, %4;" + " mov.s64 %0, r0;" + "}" + : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c), "l"(input), "r"(member_mask)); +#else + asm volatile( + "{" + " .reg .s64 r0;" + " .reg .u32 lo;" + " .reg .u32 hi;" + " .reg .pred p;" + " mov.b64 {lo, hi}, %1;" + " shfl.up.b32 lo|p, lo, %2, %3;" + " shfl.up.b32 hi|p, hi, %2, %3;" + " mov.b64 r0, {lo, hi};" + " @p add.s64 r0, r0, %4;" + " mov.s64 %0, r0;" + "}" + : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c), "l"(input)); +#endif + + return output; + } + + + /// Inclusive prefix scan step (specialized for summation across fp64 types) + __device__ __forceinline__ double InclusiveScanStep( + double input, ///< [in] Calling thread's input item. 
+ cub::Sum /*scan_op*/, ///< [in] Binary scan operator + int first_lane, ///< [in] Index of first lane in segment + int offset) ///< [in] Up-offset to pull from + { + double output; + int shfl_c = first_lane | SHFL_C; // Shuffle control (mask and first-lane) + + // Use predicate set from SHFL to guard against invalid peers +#ifdef CUB_USE_COOPERATIVE_GROUPS + asm volatile( + "{" + " .reg .u32 lo;" + " .reg .u32 hi;" + " .reg .pred p;" + " .reg .f64 r0;" + " mov.b64 %0, %1;" + " mov.b64 {lo, hi}, %1;" + " shfl.sync.up.b32 lo|p, lo, %2, %3, %4;" + " shfl.sync.up.b32 hi|p, hi, %2, %3, %4;" + " mov.b64 r0, {lo, hi};" + " @p add.f64 %0, %0, r0;" + "}" + : "=d"(output) : "d"(input), "r"(offset), "r"(shfl_c), "r"(member_mask)); +#else + asm volatile( + "{" + " .reg .u32 lo;" + " .reg .u32 hi;" + " .reg .pred p;" + " .reg .f64 r0;" + " mov.b64 %0, %1;" + " mov.b64 {lo, hi}, %1;" + " shfl.up.b32 lo|p, lo, %2, %3;" + " shfl.up.b32 hi|p, hi, %2, %3;" + " mov.b64 r0, {lo, hi};" + " @p add.f64 %0, %0, r0;" + "}" + : "=d"(output) : "d"(input), "r"(offset), "r"(shfl_c)); +#endif + + return output; + } + + +/* + /// Inclusive prefix scan (specialized for ReduceBySegmentOp across KeyValuePair types) + template + __device__ __forceinline__ KeyValuePairInclusiveScanStep( + KeyValuePair input, ///< [in] Calling thread's input item. + ReduceBySegmentOp scan_op, ///< [in] Binary scan operator + int first_lane, ///< [in] Index of first lane in segment + int offset) ///< [in] Up-offset to pull from + { + KeyValuePair output; + + output.value = InclusiveScanStep(input.value, cub::Sum(), first_lane, offset, Int2Type::IS_SMALL_UNSIGNED>()); + output.key = InclusiveScanStep(input.key, cub::Sum(), first_lane, offset, Int2Type::IS_SMALL_UNSIGNED>()); + + if (input.key > 0) + output.value = input.value; + + return output; + } +*/ + + /// Inclusive prefix scan step (generic) + template + __device__ __forceinline__ _T InclusiveScanStep( + _T input, ///< [in] Calling thread's input item. + ScanOpT scan_op, ///< [in] Binary scan operator + int first_lane, ///< [in] Index of first lane in segment + int offset) ///< [in] Up-offset to pull from + { + _T temp = ShuffleUp(input, offset, first_lane, member_mask); + + // Perform scan op if from a valid peer + _T output = scan_op(temp, input); + if (static_cast(lane_id) < first_lane + offset) + output = input; + + return output; + } + + + /// Inclusive prefix scan step (specialized for small integers size 32b or less) + template + __device__ __forceinline__ _T InclusiveScanStep( + _T input, ///< [in] Calling thread's input item. + ScanOpT scan_op, ///< [in] Binary scan operator + int first_lane, ///< [in] Index of first lane in segment + int offset, ///< [in] Up-offset to pull from + Int2Type /*is_small_unsigned*/) ///< [in] Marker type indicating whether T is a small integer + { + return InclusiveScanStep(input, scan_op, first_lane, offset); + } + + + /// Inclusive prefix scan step (specialized for types other than small integers size 32b or less) + template + __device__ __forceinline__ _T InclusiveScanStep( + _T input, ///< [in] Calling thread's input item. 
+ ScanOpT scan_op, ///< [in] Binary scan operator + int first_lane, ///< [in] Index of first lane in segment + int offset, ///< [in] Up-offset to pull from + Int2Type /*is_small_unsigned*/) ///< [in] Marker type indicating whether T is a small integer + { + return InclusiveScanStep(input, scan_op, first_lane, offset); + } + + + /****************************************************************************** + * Interface + ******************************************************************************/ + + //--------------------------------------------------------------------- + // Broadcast + //--------------------------------------------------------------------- + + /// Broadcast + __device__ __forceinline__ T Broadcast( + T input, ///< [in] The value to broadcast + int src_lane) ///< [in] Which warp lane is to do the broadcasting + { + return ShuffleIndex(input, src_lane, member_mask); + } + + + //--------------------------------------------------------------------- + // Inclusive operations + //--------------------------------------------------------------------- + + /// Inclusive scan + template + __device__ __forceinline__ void InclusiveScan( + _T input, ///< [in] Calling thread's input item. + _T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOpT scan_op) ///< [in] Binary scan operator + { + inclusive_output = input; + + // Iterate scan steps + int segment_first_lane = 0; + + // Iterate scan steps + #pragma unroll + for (int STEP = 0; STEP < STEPS; STEP++) + { + inclusive_output = InclusiveScanStep( + inclusive_output, + scan_op, + segment_first_lane, + (1 << STEP), + Int2Type::IS_SMALL_UNSIGNED>()); + } + + } + + /// Inclusive scan, specialized for reduce-value-by-key + template + __device__ __forceinline__ void InclusiveScan( + KeyValuePair input, ///< [in] Calling thread's input item. + KeyValuePair &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ReduceByKeyOp scan_op) ///< [in] Binary scan operator + { + inclusive_output = input; + + KeyT pred_key = ShuffleUp(inclusive_output.key, 1, 0, member_mask); + + unsigned int ballot = WARP_BALLOT((pred_key != inclusive_output.key), member_mask); + + // Mask away all lanes greater than ours + ballot = ballot & LaneMaskLe(); + + // Find index of first set bit + int segment_first_lane = CUB_MAX(0, 31 - __clz(ballot)); + + // Iterate scan steps + #pragma unroll + for (int STEP = 0; STEP < STEPS; STEP++) + { + inclusive_output.value = InclusiveScanStep( + inclusive_output.value, + scan_op.op, + segment_first_lane, + (1 << STEP), + Int2Type::IS_SMALL_UNSIGNED>()); + } + } + + + /// Inclusive scan with aggregate + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOpT scan_op, ///< [in] Binary scan operator + T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. 
+ { + InclusiveScan(input, inclusive_output, scan_op); + + // Grab aggregate from last warp lane + warp_aggregate = ShuffleIndex(inclusive_output, LOGICAL_WARP_THREADS - 1, member_mask); + } + + + //--------------------------------------------------------------------- + // Get exclusive from inclusive + //--------------------------------------------------------------------- + + /// Update inclusive and exclusive using input and inclusive + template + __device__ __forceinline__ void Update( + T /*input*/, ///< [in] + T &inclusive, ///< [in, out] + T &exclusive, ///< [out] + ScanOpT /*scan_op*/, ///< [in] + IsIntegerT /*is_integer*/) ///< [in] + { + // initial value unknown + exclusive = ShuffleUp(inclusive, 1, 0, member_mask); + } + + /// Update inclusive and exclusive using input and inclusive (specialized for summation of integer types) + __device__ __forceinline__ void Update( + T input, + T &inclusive, + T &exclusive, + cub::Sum /*scan_op*/, + Int2Type /*is_integer*/) + { + // initial value presumed 0 + exclusive = inclusive - input; + } + + /// Update inclusive and exclusive using initial value using input, inclusive, and initial value + template + __device__ __forceinline__ void Update ( + T /*input*/, + T &inclusive, + T &exclusive, + ScanOpT scan_op, + T initial_value, + IsIntegerT /*is_integer*/) + { + inclusive = scan_op(initial_value, inclusive); + exclusive = ShuffleUp(inclusive, 1, 0, member_mask); + + if (lane_id == 0) + exclusive = initial_value; + } + + /// Update inclusive and exclusive using initial value using input and inclusive (specialized for summation of integer types) + __device__ __forceinline__ void Update ( + T input, + T &inclusive, + T &exclusive, + cub::Sum scan_op, + T initial_value, + Int2Type /*is_integer*/) + { + inclusive = scan_op(initial_value, inclusive); + exclusive = inclusive - input; + } + + + /// Update inclusive, exclusive, and warp aggregate using input and inclusive + template + __device__ __forceinline__ void Update ( + T input, + T &inclusive, + T &exclusive, + T &warp_aggregate, + ScanOpT scan_op, + IsIntegerT is_integer) + { + warp_aggregate = ShuffleIndex(inclusive, LOGICAL_WARP_THREADS - 1, member_mask); + Update(input, inclusive, exclusive, scan_op, is_integer); + } + + /// Update inclusive, exclusive, and warp aggregate using input, inclusive, and initial value + template + __device__ __forceinline__ void Update ( + T input, + T &inclusive, + T &exclusive, + T &warp_aggregate, + ScanOpT scan_op, + T initial_value, + IsIntegerT is_integer) + { + warp_aggregate = ShuffleIndex(inclusive, LOGICAL_WARP_THREADS - 1, member_mask); + Update(input, inclusive, exclusive, scan_op, initial_value, is_integer); + } + + + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/warp/specializations/warp_scan_smem.cuh b/fastertransformer/cuda/cub/warp/specializations/warp_scan_smem.cuh new file mode 100644 index 000000000..3237fcbfe --- /dev/null +++ b/fastertransformer/cuda/cub/warp/specializations/warp_scan_smem.cuh @@ -0,0 +1,397 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::WarpScanSmem provides smem-based variants of parallel prefix scan of items partitioned across a CUDA thread warp. + */ + +#pragma once + +#include "../../thread/thread_operators.cuh" +#include "../../thread/thread_load.cuh" +#include "../../thread/thread_store.cuh" +#include "../../util_type.cuh" +#include "../../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief WarpScanSmem provides smem-based variants of parallel prefix scan of items partitioned across a CUDA thread warp. 
+ */ +template < + typename T, ///< Data type being scanned + int LOGICAL_WARP_THREADS, ///< Number of threads per logical warp + int PTX_ARCH> ///< The PTX compute capability for which to to specialize this collective +struct WarpScanSmem +{ + /****************************************************************************** + * Constants and type definitions + ******************************************************************************/ + + enum + { + /// Whether the logical warp size and the PTX warp size coincide + IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)), + + /// Whether the logical warp size is a power-of-two + IS_POW_OF_TWO = PowerOfTwo::VALUE, + + /// The number of warp scan steps + STEPS = Log2::VALUE, + + /// The number of threads in half a warp + HALF_WARP_THREADS = 1 << (STEPS - 1), + + /// The number of shared memory elements per warp + WARP_SMEM_ELEMENTS = LOGICAL_WARP_THREADS + HALF_WARP_THREADS, + }; + + /// Storage cell type (workaround for SM1x compiler bugs with custom-ops like Max() on signed chars) + typedef typename If<((Equals::VALUE || Equals::VALUE) && (PTX_ARCH < 200)), int, T>::Type CellT; + + /// Shared memory storage layout type (1.5 warps-worth of elements for each warp) + typedef CellT _TempStorage[WARP_SMEM_ELEMENTS]; + + // Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + _TempStorage &temp_storage; + unsigned int lane_id; + unsigned int member_mask; + + + /****************************************************************************** + * Construction + ******************************************************************************/ + + /// Constructor + __device__ __forceinline__ WarpScanSmem( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + + lane_id(IS_ARCH_WARP ? + LaneId() : + LaneId() % LOGICAL_WARP_THREADS), + + member_mask((0xffffffff >> (32 - LOGICAL_WARP_THREADS)) << ((IS_ARCH_WARP || !IS_POW_OF_TWO ) ? 
+ 0 : // arch-width and non-power-of-two subwarps cannot be tiled with the arch-warp + ((LaneId() / LOGICAL_WARP_THREADS) * LOGICAL_WARP_THREADS))) + {} + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /// Basic inclusive scan iteration (template unrolled, inductive-case specialization) + template < + bool HAS_IDENTITY, + int STEP, + typename ScanOp> + __device__ __forceinline__ void ScanStep( + T &partial, + ScanOp scan_op, + Int2Type /*step*/) + { + const int OFFSET = 1 << STEP; + + // Share partial into buffer + ThreadStore(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) partial); + + WARP_SYNC(member_mask); + + // Update partial if addend is in range + if (HAS_IDENTITY || (lane_id >= OFFSET)) + { + T addend = (T) ThreadLoad(&temp_storage[HALF_WARP_THREADS + lane_id - OFFSET]); + partial = scan_op(addend, partial); + } + WARP_SYNC(member_mask); + + ScanStep(partial, scan_op, Int2Type()); + } + + + /// Basic inclusive scan iteration(template unrolled, base-case specialization) + template < + bool HAS_IDENTITY, + typename ScanOp> + __device__ __forceinline__ void ScanStep( + T &/*partial*/, + ScanOp /*scan_op*/, + Int2Type /*step*/) + {} + + + /// Inclusive prefix scan (specialized for summation across primitive types) + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item. + T &output, ///< [out] Calling thread's output item. May be aliased with \p input. + Sum scan_op, ///< [in] Binary scan operator + Int2Type /*is_primitive*/) ///< [in] Marker type indicating whether T is primitive type + { + T identity = 0; + ThreadStore(&temp_storage[lane_id], (CellT) identity); + + WARP_SYNC(member_mask); + + // Iterate scan steps + output = input; + ScanStep(output, scan_op, Int2Type<0>()); + } + + + /// Inclusive prefix scan + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item. + T &output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOp scan_op, ///< [in] Binary scan operator + Int2Type /*is_primitive*/) ///< [in] Marker type indicating whether T is primitive type + { + // Iterate scan steps + output = input; + ScanStep(output, scan_op, Int2Type<0>()); + } + + + /****************************************************************************** + * Interface + ******************************************************************************/ + + //--------------------------------------------------------------------- + // Broadcast + //--------------------------------------------------------------------- + + /// Broadcast + __device__ __forceinline__ T Broadcast( + T input, ///< [in] The value to broadcast + unsigned int src_lane) ///< [in] Which warp lane is to do the broadcasting + { + if (lane_id == src_lane) + { + ThreadStore(temp_storage, (CellT) input); + } + + WARP_SYNC(member_mask); + + return (T)ThreadLoad(temp_storage); + } + + + //--------------------------------------------------------------------- + // Inclusive operations + //--------------------------------------------------------------------- + + /// Inclusive scan + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. 
+ ScanOp scan_op) ///< [in] Binary scan operator + { + InclusiveScan(input, inclusive_output, scan_op, Int2Type::PRIMITIVE>()); + } + + + /// Inclusive scan with aggregate + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOp scan_op, ///< [in] Binary scan operator + T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. + { + InclusiveScan(input, inclusive_output, scan_op); + + // Retrieve aggregate + ThreadStore(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive_output); + + WARP_SYNC(member_mask); + + warp_aggregate = (T) ThreadLoad(&temp_storage[WARP_SMEM_ELEMENTS - 1]); + + WARP_SYNC(member_mask); + } + + + //--------------------------------------------------------------------- + // Get exclusive from inclusive + //--------------------------------------------------------------------- + + /// Update inclusive and exclusive using input and inclusive + template + __device__ __forceinline__ void Update( + T /*input*/, ///< [in] + T &inclusive, ///< [in, out] + T &exclusive, ///< [out] + ScanOpT /*scan_op*/, ///< [in] + IsIntegerT /*is_integer*/) ///< [in] + { + // initial value unknown + ThreadStore(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive); + + WARP_SYNC(member_mask); + + exclusive = (T) ThreadLoad(&temp_storage[HALF_WARP_THREADS + lane_id - 1]); + } + + /// Update inclusive and exclusive using input and inclusive (specialized for summation of integer types) + __device__ __forceinline__ void Update( + T input, + T &inclusive, + T &exclusive, + cub::Sum /*scan_op*/, + Int2Type /*is_integer*/) + { + // initial value presumed 0 + exclusive = inclusive - input; + } + + /// Update inclusive and exclusive using initial value using input, inclusive, and initial value + template + __device__ __forceinline__ void Update ( + T /*input*/, + T &inclusive, + T &exclusive, + ScanOpT scan_op, + T initial_value, + IsIntegerT /*is_integer*/) + { + inclusive = scan_op(initial_value, inclusive); + ThreadStore(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive); + + WARP_SYNC(member_mask); + + exclusive = (T) ThreadLoad(&temp_storage[HALF_WARP_THREADS + lane_id - 1]); + if (lane_id == 0) + exclusive = initial_value; + } + + /// Update inclusive and exclusive using initial value using input and inclusive (specialized for summation of integer types) + __device__ __forceinline__ void Update ( + T input, + T &inclusive, + T &exclusive, + cub::Sum scan_op, + T initial_value, + Int2Type /*is_integer*/) + { + inclusive = scan_op(initial_value, inclusive); + exclusive = inclusive - input; + } + + + /// Update inclusive, exclusive, and warp aggregate using input and inclusive + template + __device__ __forceinline__ void Update ( + T /*input*/, + T &inclusive, + T &exclusive, + T &warp_aggregate, + ScanOpT /*scan_op*/, + IsIntegerT /*is_integer*/) + { + // Initial value presumed to be unknown or identity (either way our padding is correct) + ThreadStore(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive); + + WARP_SYNC(member_mask); + + exclusive = (T) ThreadLoad(&temp_storage[HALF_WARP_THREADS + lane_id - 1]); + warp_aggregate = (T) ThreadLoad(&temp_storage[WARP_SMEM_ELEMENTS - 1]); + } + + /// Update inclusive, exclusive, and warp aggregate using input and inclusive (specialized for summation of integer types) + __device__ __forceinline__ void Update ( + T input, + T &inclusive, + T 
&exclusive, + T &warp_aggregate, + cub::Sum /*scan_o*/, + Int2Type /*is_integer*/) + { + // Initial value presumed to be unknown or identity (either way our padding is correct) + ThreadStore(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive); + + WARP_SYNC(member_mask); + + warp_aggregate = (T) ThreadLoad(&temp_storage[WARP_SMEM_ELEMENTS - 1]); + exclusive = inclusive - input; + } + + /// Update inclusive, exclusive, and warp aggregate using input, inclusive, and initial value + template + __device__ __forceinline__ void Update ( + T /*input*/, + T &inclusive, + T &exclusive, + T &warp_aggregate, + ScanOpT scan_op, + T initial_value, + IsIntegerT /*is_integer*/) + { + // Broadcast warp aggregate + ThreadStore(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive); + + WARP_SYNC(member_mask); + + warp_aggregate = (T) ThreadLoad(&temp_storage[WARP_SMEM_ELEMENTS - 1]); + + WARP_SYNC(member_mask); + + // Update inclusive with initial value + inclusive = scan_op(initial_value, inclusive); + + // Get exclusive from exclusive + ThreadStore(&temp_storage[HALF_WARP_THREADS + lane_id - 1], (CellT) inclusive); + + WARP_SYNC(member_mask); + + exclusive = (T) ThreadLoad(&temp_storage[HALF_WARP_THREADS + lane_id - 2]); + + if (lane_id == 0) + exclusive = initial_value; + } + + +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/warp/warp_reduce.cuh b/fastertransformer/cuda/cub/warp/warp_reduce.cuh new file mode 100644 index 000000000..189896b07 --- /dev/null +++ b/fastertransformer/cuda/cub/warp/warp_reduce.cuh @@ -0,0 +1,612 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * \file + * The cub::WarpReduce class provides [collective](index.html#sec0) methods for computing a parallel reduction of items partitioned across a CUDA thread warp. + */ + +#pragma once + +#include "specializations/warp_reduce_shfl.cuh" +#include "specializations/warp_reduce_smem.cuh" +#include "../thread/thread_operators.cuh" +#include "../util_arch.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + + +/** + * \addtogroup WarpModule + * @{ + */ + +/** + * \brief The WarpReduce class provides [collective](index.html#sec0) methods for computing a parallel reduction of items partitioned across a CUDA thread warp. ![](warp_reduce_logo.png) + * + * \tparam T The reduction input/output element type + * \tparam LOGICAL_WARP_THREADS [optional] The number of threads per "logical" warp (may be less than the number of hardware warp threads). Default is the warp size of the targeted CUDA compute-capability (e.g., 32 threads for SM20). + * \tparam PTX_ARCH [optional] \ptxversion + * + * \par Overview + * - A reduction (or fold) + * uses a binary combining operator to compute a single aggregate from a list of input elements. + * - Supports "logical" warps smaller than the physical warp size (e.g., logical warps of 8 threads) + * - The number of entrant threads must be an multiple of \p LOGICAL_WARP_THREADS + * + * \par Performance Considerations + * - Uses special instructions when applicable (e.g., warp \p SHFL instructions) + * - Uses synchronization-free communication between warp lanes when applicable + * - Incurs zero bank conflicts for most types + * - Computation is slightly more efficient (i.e., having lower instruction overhead) for: + * - Summation (vs. generic reduction) + * - The architecture's warp size is a whole multiple of \p LOGICAL_WARP_THREADS + * + * \par Simple Examples + * \warpcollective{WarpReduce} + * \par + * The code snippet below illustrates four concurrent warp sum reductions within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpReduce for type int + * typedef cub::WarpReduce WarpReduce; + * + * // Allocate WarpReduce shared memory for 4 warps + * __shared__ typename WarpReduce::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Return the warp-wide sums to each lane0 (threads 0, 32, 64, and 96) + * int warp_id = threadIdx.x / 32; + * int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {0, 1, 2, 3, ..., 127}. + * The corresponding output \p aggregate in threads 0, 32, 64, and 96 will \p 496, \p 1520, + * \p 2544, and \p 3568, respectively (and is undefined in other threads). + * + * \par + * The code snippet below illustrates a single warp sum reduction within a block of + * 128 threads. + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpReduce for type int + * typedef cub::WarpReduce WarpReduce; + * + * // Allocate WarpReduce shared memory for one warp + * __shared__ typename WarpReduce::TempStorage temp_storage; + * ... 
+ * + * // Only the first warp performs a reduction + * if (threadIdx.x < 32) + * { + * // Obtain one input item per thread + * int thread_data = ... + * + * // Return the warp-wide sum to lane0 + * int aggregate = WarpReduce(temp_storage).Sum(thread_data); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the warp of threads is {0, 1, 2, 3, ..., 31}. + * The corresponding output \p aggregate in thread0 will be \p 496 (and is undefined in other threads). + * + */ +template < + typename T, + int LOGICAL_WARP_THREADS = CUB_PTX_WARP_THREADS, + int PTX_ARCH = CUB_PTX_ARCH> +class WarpReduce +{ +private: + + /****************************************************************************** + * Constants and type definitions + ******************************************************************************/ + + enum + { + /// Whether the logical warp size and the PTX warp size coincide + IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)), + + /// Whether the logical warp size is a power-of-two + IS_POW_OF_TWO = PowerOfTwo::VALUE, + }; + +public: + + #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + /// Internal specialization. Use SHFL-based reduction if (architecture is >= SM30) and (LOGICAL_WARP_THREADS is a power-of-two) + typedef typename If<(PTX_ARCH >= 300) && (IS_POW_OF_TWO), + WarpReduceShfl, + WarpReduceSmem >::Type InternalWarpReduce; + + #endif // DOXYGEN_SHOULD_SKIP_THIS + + +private: + + /// Shared memory storage layout type for WarpReduce + typedef typename InternalWarpReduce::TempStorage _TempStorage; + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + +public: + + /// \smemstorage{WarpReduce} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. Logical warp and lane identifiers are constructed from threadIdx.x. + */ + __device__ __forceinline__ WarpReduce( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()) + {} + + + //@} end member group + /******************************************************************//** + * \name Summation reductions + *********************************************************************/ + //@{ + + + /** + * \brief Computes a warp-wide sum in the calling warp. The output is valid in warp lane0. + * + * \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp sum reductions within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpReduce for type int + * typedef cub::WarpReduce WarpReduce; + * + * // Allocate WarpReduce shared memory for 4 warps + * __shared__ typename WarpReduce::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... 
+ * + * // Return the warp-wide sums to each lane0 + * int warp_id = threadIdx.x / 32; + * int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {0, 1, 2, 3, ..., 127}. + * The corresponding output \p aggregate in threads 0, 32, 64, and 96 will \p 496, \p 1520, + * \p 2544, and \p 3568, respectively (and is undefined in other threads). + * + */ + __device__ __forceinline__ T Sum( + T input) ///< [in] Calling thread's input + { + return InternalWarpReduce(temp_storage).template Reduce(input, LOGICAL_WARP_THREADS, cub::Sum()); + } + + /** + * \brief Computes a partially-full warp-wide sum in the calling warp. The output is valid in warp lane0. + * + * All threads across the calling warp must agree on the same value for \p valid_items. Otherwise the result is undefined. + * + * \smemreuse + * + * \par Snippet + * The code snippet below illustrates a sum reduction within a single, partially-full + * block of 32 threads (one warp). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(int *d_data, int valid_items) + * { + * // Specialize WarpReduce for type int + * typedef cub::WarpReduce WarpReduce; + * + * // Allocate WarpReduce shared memory for one warp + * __shared__ typename WarpReduce::TempStorage temp_storage; + * + * // Obtain one input item per thread if in range + * int thread_data; + * if (threadIdx.x < valid_items) + * thread_data = d_data[threadIdx.x]; + * + * // Return the warp-wide sums to each lane0 + * int aggregate = WarpReduce(temp_storage).Sum( + * thread_data, valid_items); + * + * \endcode + * \par + * Suppose the input \p d_data is {0, 1, 2, 3, 4, ... and \p valid_items + * is \p 4. The corresponding output \p aggregate in thread0 is \p 6 (and is + * undefined in other threads). + * + */ + __device__ __forceinline__ T Sum( + T input, ///< [in] Calling thread's input + int valid_items) ///< [in] Total number of valid items in the calling thread's logical warp (may be less than \p LOGICAL_WARP_THREADS) + { + // Determine if we don't need bounds checking + return InternalWarpReduce(temp_storage).template Reduce(input, valid_items, cub::Sum()); + } + + + /** + * \brief Computes a segmented sum in the calling warp where segments are defined by head-flags. The sum of each segment is returned to the first lane in that segment (which always includes lane0). + * + * \smemreuse + * + * \par Snippet + * The code snippet below illustrates a head-segmented warp sum + * reduction within a block of 32 threads (one warp). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpReduce for type int + * typedef cub::WarpReduce WarpReduce; + * + * // Allocate WarpReduce shared memory for one warp + * __shared__ typename WarpReduce::TempStorage temp_storage; + * + * // Obtain one input item and flag per thread + * int thread_data = ... + * int head_flag = ... + * + * // Return the warp-wide sums to each lane0 + * int aggregate = WarpReduce(temp_storage).HeadSegmentedSum( + * thread_data, head_flag); + * + * \endcode + * \par + * Suppose the set of input \p thread_data and \p head_flag across the block of threads + * is {0, 1, 2, 3, ..., 31 and is {1, 0, 0, 0, 1, 0, 0, 0, ..., 1, 0, 0, 0, + * respectively. The corresponding output \p aggregate in threads 0, 4, 8, etc. will be + * \p 6, \p 22, \p 38, etc. (and is undefined in other threads). 
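+ * (Each head flag begins a new four-item segment, so the segment sums are
+ * 0+1+2+3 = 6, 4+5+6+7 = 22, 8+9+10+11 = 38, and so on.)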
+ * + * \tparam ReductionOp [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b) + * + */ + template < + typename FlagT> + __device__ __forceinline__ T HeadSegmentedSum( + T input, ///< [in] Calling thread's input + FlagT head_flag) ///< [in] Head flag denoting whether or not \p input is the start of a new segment + { + return HeadSegmentedReduce(input, head_flag, cub::Sum()); + } + + + /** + * \brief Computes a segmented sum in the calling warp where segments are defined by tail-flags. The sum of each segment is returned to the first lane in that segment (which always includes lane0). + * + * \smemreuse + * + * \par Snippet + * The code snippet below illustrates a tail-segmented warp sum + * reduction within a block of 32 threads (one warp). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpReduce for type int + * typedef cub::WarpReduce WarpReduce; + * + * // Allocate WarpReduce shared memory for one warp + * __shared__ typename WarpReduce::TempStorage temp_storage; + * + * // Obtain one input item and flag per thread + * int thread_data = ... + * int tail_flag = ... + * + * // Return the warp-wide sums to each lane0 + * int aggregate = WarpReduce(temp_storage).TailSegmentedSum( + * thread_data, tail_flag); + * + * \endcode + * \par + * Suppose the set of input \p thread_data and \p tail_flag across the block of threads + * is {0, 1, 2, 3, ..., 31 and is {0, 0, 0, 1, 0, 0, 0, 1, ..., 0, 0, 0, 1, + * respectively. The corresponding output \p aggregate in threads 0, 4, 8, etc. will be + * \p 6, \p 22, \p 38, etc. (and is undefined in other threads). + * + * \tparam ReductionOp [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b) + */ + template < + typename FlagT> + __device__ __forceinline__ T TailSegmentedSum( + T input, ///< [in] Calling thread's input + FlagT tail_flag) ///< [in] Head flag denoting whether or not \p input is the start of a new segment + { + return TailSegmentedReduce(input, tail_flag, cub::Sum()); + } + + + + //@} end member group + /******************************************************************//** + * \name Generic reductions + *********************************************************************/ + //@{ + + /** + * \brief Computes a warp-wide reduction in the calling warp using the specified binary reduction functor. The output is valid in warp lane0. + * + * Supports non-commutative reduction operators + * + * \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp max reductions within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpReduce for type int + * typedef cub::WarpReduce WarpReduce; + * + * // Allocate WarpReduce shared memory for 4 warps + * __shared__ typename WarpReduce::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Return the warp-wide reductions to each lane0 + * int warp_id = threadIdx.x / 32; + * int aggregate = WarpReduce(temp_storage[warp_id]).Reduce( + * thread_data, cub::Max()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {0, 1, 2, 3, ..., 127}. + * The corresponding output \p aggregate in threads 0, 32, 64, and 96 will \p 31, \p 63, + * \p 95, and \p 127, respectively (and is undefined in other threads). 
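+ * (Since the inputs increase monotonically, each warp's maximum is simply its
+ * highest-indexed input: max{0, ..., 31} = 31, max{32, ..., 63} = 63, and so on.)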
+ * + * \tparam ReductionOp [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ T Reduce( + T input, ///< [in] Calling thread's input + ReductionOp reduction_op) ///< [in] Binary reduction operator + { + return InternalWarpReduce(temp_storage).template Reduce(input, LOGICAL_WARP_THREADS, reduction_op); + } + + /** + * \brief Computes a partially-full warp-wide reduction in the calling warp using the specified binary reduction functor. The output is valid in warp lane0. + * + * All threads across the calling warp must agree on the same value for \p valid_items. Otherwise the result is undefined. + * + * Supports non-commutative reduction operators + * + * \smemreuse + * + * \par Snippet + * The code snippet below illustrates a max reduction within a single, partially-full + * block of 32 threads (one warp). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(int *d_data, int valid_items) + * { + * // Specialize WarpReduce for type int + * typedef cub::WarpReduce WarpReduce; + * + * // Allocate WarpReduce shared memory for one warp + * __shared__ typename WarpReduce::TempStorage temp_storage; + * + * // Obtain one input item per thread if in range + * int thread_data; + * if (threadIdx.x < valid_items) + * thread_data = d_data[threadIdx.x]; + * + * // Return the warp-wide reductions to each lane0 + * int aggregate = WarpReduce(temp_storage).Reduce( + * thread_data, cub::Max(), valid_items); + * + * \endcode + * \par + * Suppose the input \p d_data is {0, 1, 2, 3, 4, ... and \p valid_items + * is \p 4. The corresponding output \p aggregate in thread0 is \p 3 (and is + * undefined in other threads). + * + * \tparam ReductionOp [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ T Reduce( + T input, ///< [in] Calling thread's input + ReductionOp reduction_op, ///< [in] Binary reduction operator + int valid_items) ///< [in] Total number of valid items in the calling thread's logical warp (may be less than \p LOGICAL_WARP_THREADS) + { + return InternalWarpReduce(temp_storage).template Reduce(input, valid_items, reduction_op); + } + + + /** + * \brief Computes a segmented reduction in the calling warp where segments are defined by head-flags. The reduction of each segment is returned to the first lane in that segment (which always includes lane0). + * + * Supports non-commutative reduction operators + * + * \smemreuse + * + * \par Snippet + * The code snippet below illustrates a head-segmented warp max + * reduction within a block of 32 threads (one warp). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpReduce for type int + * typedef cub::WarpReduce WarpReduce; + * + * // Allocate WarpReduce shared memory for one warp + * __shared__ typename WarpReduce::TempStorage temp_storage; + * + * // Obtain one input item and flag per thread + * int thread_data = ... + * int head_flag = ... + * + * // Return the warp-wide reductions to each lane0 + * int aggregate = WarpReduce(temp_storage).HeadSegmentedReduce( + * thread_data, head_flag, cub::Max()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data and \p head_flag across the block of threads + * is {0, 1, 2, 3, ..., 31 and is {1, 0, 0, 0, 1, 0, 0, 0, ..., 1, 0, 0, 0, + * respectively. The corresponding output \p aggregate in threads 0, 4, 8, etc. will be + * \p 3, \p 7, \p 11, etc. 
(and is undefined in other threads). + * + * \tparam ReductionOp [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b) + */ + template < + typename ReductionOp, + typename FlagT> + __device__ __forceinline__ T HeadSegmentedReduce( + T input, ///< [in] Calling thread's input + FlagT head_flag, ///< [in] Head flag denoting whether or not \p input is the start of a new segment + ReductionOp reduction_op) ///< [in] Reduction operator + { + return InternalWarpReduce(temp_storage).template SegmentedReduce(input, head_flag, reduction_op); + } + + + /** + * \brief Computes a segmented reduction in the calling warp where segments are defined by tail-flags. The reduction of each segment is returned to the first lane in that segment (which always includes lane0). + * + * Supports non-commutative reduction operators + * + * \smemreuse + * + * \par Snippet + * The code snippet below illustrates a tail-segmented warp max + * reduction within a block of 32 threads (one warp). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpReduce for type int + * typedef cub::WarpReduce WarpReduce; + * + * // Allocate WarpReduce shared memory for one warp + * __shared__ typename WarpReduce::TempStorage temp_storage; + * + * // Obtain one input item and flag per thread + * int thread_data = ... + * int tail_flag = ... + * + * // Return the warp-wide reductions to each lane0 + * int aggregate = WarpReduce(temp_storage).TailSegmentedReduce( + * thread_data, tail_flag, cub::Max()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data and \p tail_flag across the block of threads + * is {0, 1, 2, 3, ..., 31 and is {0, 0, 0, 1, 0, 0, 0, 1, ..., 0, 0, 0, 1, + * respectively. The corresponding output \p aggregate in threads 0, 4, 8, etc. will be + * \p 3, \p 7, \p 11, etc. (and is undefined in other threads). + * + * \tparam ReductionOp [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b) + */ + template < + typename ReductionOp, + typename FlagT> + __device__ __forceinline__ T TailSegmentedReduce( + T input, ///< [in] Calling thread's input + FlagT tail_flag, ///< [in] Tail flag denoting whether or not \p input is the end of the current segment + ReductionOp reduction_op) ///< [in] Reduction operator + { + return InternalWarpReduce(temp_storage).template SegmentedReduce(input, tail_flag, reduction_op); + } + + + + //@} end member group +}; + +/** @} */ // end group WarpModule + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cub/warp/warp_scan.cuh b/fastertransformer/cuda/cub/warp/warp_scan.cuh new file mode 100644 index 000000000..c7af0d343 --- /dev/null +++ b/fastertransformer/cuda/cub/warp/warp_scan.cuh @@ -0,0 +1,936 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * The cub::WarpScan class provides [collective](index.html#sec0) methods for computing a parallel prefix scan of items partitioned across a CUDA thread warp. + */ + +#pragma once + +#include "specializations/warp_scan_shfl.cuh" +#include "specializations/warp_scan_smem.cuh" +#include "../thread/thread_operators.cuh" +#include "../util_arch.cuh" +#include "../util_type.cuh" +#include "../util_namespace.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \addtogroup WarpModule + * @{ + */ + +/** + * \brief The WarpScan class provides [collective](index.html#sec0) methods for computing a parallel prefix scan of items partitioned across a CUDA thread warp. ![](warp_scan_logo.png) + * + * \tparam T The scan input/output element type + * \tparam LOGICAL_WARP_THREADS [optional] The number of threads per "logical" warp (may be less than the number of hardware warp threads). Default is the warp size associated with the CUDA Compute Capability targeted by the compiler (e.g., 32 threads for SM20). + * \tparam PTX_ARCH [optional] \ptxversion + * + * \par Overview + * - Given a list of input elements and a binary reduction operator, a [prefix scan](http://en.wikipedia.org/wiki/Prefix_sum) + * produces an output list where each element is computed to be the reduction + * of the elements occurring earlier in the input list. Prefix sum + * connotes a prefix scan with the addition operator. The term \em inclusive indicates + * that the ith output reduction incorporates the ith input. + * The term \em exclusive indicates the ith input is not incorporated into + * the ith output reduction. + * - Supports non-commutative scan operators + * - Supports "logical" warps smaller than the physical warp size (e.g., a logical warp of 8 threads) + * - The number of entrant threads must be an multiple of \p LOGICAL_WARP_THREADS + * + * \par Performance Considerations + * - Uses special instructions when applicable (e.g., warp \p SHFL) + * - Uses synchronization-free communication between warp lanes when applicable + * - Incurs zero bank conflicts for most types + * - Computation is slightly more efficient (i.e., having lower instruction overhead) for: + * - Summation (vs. 
generic scan) + * - The architecture's warp size is a whole multiple of \p LOGICAL_WARP_THREADS + * + * \par Simple Examples + * \warpcollective{WarpScan} + * \par + * The code snippet below illustrates four concurrent warp prefix sums within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute warp-wide prefix sums + * int warp_id = threadIdx.x / 32; + * WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {1, 1, 1, 1, ...}. + * The corresponding output \p thread_data in each of the four warps of threads will be + * 0, 1, 2, 3, ..., 31}. + * + * \par + * The code snippet below illustrates a single warp prefix sum within a block of + * 128 threads. + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for one warp + * __shared__ typename WarpScan::TempStorage temp_storage; + * ... + * + * // Only the first warp performs a prefix sum + * if (threadIdx.x < 32) + * { + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute warp-wide prefix sums + * WarpScan(temp_storage).ExclusiveSum(thread_data, thread_data); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the warp of threads is {1, 1, 1, 1, ...}. + * The corresponding output \p thread_data will be {0, 1, 2, 3, ..., 31}. + * + */ +template < + typename T, + int LOGICAL_WARP_THREADS = CUB_PTX_WARP_THREADS, + int PTX_ARCH = CUB_PTX_ARCH> +class WarpScan +{ +private: + + /****************************************************************************** + * Constants and type definitions + ******************************************************************************/ + + enum + { + /// Whether the logical warp size and the PTX warp size coincide + IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)), + + /// Whether the logical warp size is a power-of-two + IS_POW_OF_TWO = ((LOGICAL_WARP_THREADS & (LOGICAL_WARP_THREADS - 1)) == 0), + + /// Whether the data type is an integer (which has fully-associative addition) + IS_INTEGER = ((Traits::CATEGORY == SIGNED_INTEGER) || (Traits::CATEGORY == UNSIGNED_INTEGER)) + }; + + /// Internal specialization. 
Use SHFL-based scan if (architecture is >= SM30) and (LOGICAL_WARP_THREADS is a power-of-two) + typedef typename If<(PTX_ARCH >= 300) && (IS_POW_OF_TWO), + WarpScanShfl, + WarpScanSmem >::Type InternalWarpScan; + + /// Shared memory storage layout type for WarpScan + typedef typename InternalWarpScan::TempStorage _TempStorage; + + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + unsigned int lane_id; + + + + /****************************************************************************** + * Public types + ******************************************************************************/ + +public: + + /// \smemstorage{WarpScan} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. Logical warp and lane identifiers are constructed from threadIdx.x. + */ + __device__ __forceinline__ WarpScan( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + lane_id(IS_ARCH_WARP ? + LaneId() : + LaneId() % LOGICAL_WARP_THREADS) + {} + + + //@} end member group + /******************************************************************//** + * \name Inclusive prefix sums + *********************************************************************/ + //@{ + + + /** + * \brief Computes an inclusive prefix sum across the calling warp. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp-wide inclusive prefix sums within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute inclusive warp-wide prefix sums + * int warp_id = threadIdx.x / 32; + * WarpScan(temp_storage[warp_id]).InclusiveSum(thread_data, thread_data); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {1, 1, 1, 1, ...}. + * The corresponding output \p thread_data in each of the four warps of threads will be + * 1, 2, 3, ..., 32}. + */ + __device__ __forceinline__ void InclusiveSum( + T input, ///< [in] Calling thread's input item. + T &inclusive_output) ///< [out] Calling thread's output item. May be aliased with \p input. + { + InclusiveScan(input, inclusive_output, cub::Sum()); + } + + + /** + * \brief Computes an inclusive prefix sum across the calling warp. Also provides every thread with the warp-wide \p warp_aggregate of all inputs. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp-wide inclusive prefix sums within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) 
+ * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute inclusive warp-wide prefix sums + * int warp_aggregate; + * int warp_id = threadIdx.x / 32; + * WarpScan(temp_storage[warp_id]).InclusiveSum(thread_data, thread_data, warp_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {1, 1, 1, 1, ...}. + * The corresponding output \p thread_data in each of the four warps of threads will be + * 1, 2, 3, ..., 32}. Furthermore, \p warp_aggregate for all threads in all warps will be \p 32. + */ + __device__ __forceinline__ void InclusiveSum( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. + { + InclusiveScan(input, inclusive_output, cub::Sum(), warp_aggregate); + } + + + //@} end member group + /******************************************************************//** + * \name Exclusive prefix sums + *********************************************************************/ + //@{ + + + /** + * \brief Computes an exclusive prefix sum across the calling warp. The value of 0 is applied as the initial value, and is assigned to \p exclusive_output in thread0. + * + * \par + * - \identityzero + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp-wide exclusive prefix sums within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute exclusive warp-wide prefix sums + * int warp_id = threadIdx.x / 32; + * WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {1, 1, 1, 1, ...}. + * The corresponding output \p thread_data in each of the four warps of threads will be + * 0, 1, 2, ..., 31}. + * + */ + __device__ __forceinline__ void ExclusiveSum( + T input, ///< [in] Calling thread's input item. + T &exclusive_output) ///< [out] Calling thread's output item. May be aliased with \p input. + { + T initial_value = 0; + ExclusiveScan(input, exclusive_output, initial_value, cub::Sum()); + } + + + /** + * \brief Computes an exclusive prefix sum across the calling warp. The value of 0 is applied as the initial value, and is assigned to \p exclusive_output in thread0. Also provides every thread with the warp-wide \p warp_aggregate of all inputs. + * + * \par + * - \identityzero + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp-wide exclusive prefix sums within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) 
+ * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute exclusive warp-wide prefix sums + * int warp_aggregate; + * int warp_id = threadIdx.x / 32; + * WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {1, 1, 1, 1, ...}. + * The corresponding output \p thread_data in each of the four warps of threads will be + * 0, 1, 2, ..., 31}. Furthermore, \p warp_aggregate for all threads in all warps will be \p 32. + */ + __device__ __forceinline__ void ExclusiveSum( + T input, ///< [in] Calling thread's input item. + T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. + { + T initial_value = 0; + ExclusiveScan(input, exclusive_output, initial_value, cub::Sum(), warp_aggregate); + } + + + //@} end member group + /******************************************************************//** + * \name Inclusive prefix scans + *********************************************************************/ + //@{ + + /** + * \brief Computes an inclusive prefix scan using the specified binary scan functor across the calling warp. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp-wide inclusive prefix max scans within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute inclusive warp-wide prefix max scans + * int warp_id = threadIdx.x / 32; + * WarpScan(temp_storage[warp_id]).InclusiveScan(thread_data, thread_data, cub::Max()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {0, -1, 2, -3, ..., 126, -127}. + * The corresponding output \p thread_data in the first warp would be + * 0, 0, 2, 2, ..., 30, 30, the output for the second warp would be 32, 32, 34, 34, ..., 62, 62, etc. + * + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOp scan_op) ///< [in] Binary scan operator + { + InternalWarpScan(temp_storage).InclusiveScan(input, inclusive_output, scan_op); + } + + + /** + * \brief Computes an inclusive prefix scan using the specified binary scan functor across the calling warp. Also provides every thread with the warp-wide \p warp_aggregate of all inputs. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp-wide inclusive prefix max scans within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) 
+ * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute inclusive warp-wide prefix max scans + * int warp_aggregate; + * int warp_id = threadIdx.x / 32; + * WarpScan(temp_storage[warp_id]).InclusiveScan( + * thread_data, thread_data, cub::Max(), warp_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {0, -1, 2, -3, ..., 126, -127}. + * The corresponding output \p thread_data in the first warp would be + * 0, 0, 2, 2, ..., 30, 30, the output for the second warp would be 32, 32, 34, 34, ..., 62, 62, etc. + * Furthermore, \p warp_aggregate would be assigned \p 30 for threads in the first warp, \p 62 for threads + * in the second warp, etc. + * + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void InclusiveScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOp scan_op, ///< [in] Binary scan operator + T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. + { + InternalWarpScan(temp_storage).InclusiveScan(input, inclusive_output, scan_op, warp_aggregate); + } + + + //@} end member group + /******************************************************************//** + * \name Exclusive prefix scans + *********************************************************************/ + //@{ + + /** + * \brief Computes an exclusive prefix scan using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p output computed for warp-lane0 is undefined. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp-wide exclusive prefix max scans within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute exclusive warp-wide prefix max scans + * int warp_id = threadIdx.x / 32; + * WarpScan(temp_storage[warp_id]).ExclusiveScan(thread_data, thread_data, cub::Max()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {0, -1, 2, -3, ..., 126, -127}. + * The corresponding output \p thread_data in the first warp would be + * ?, 0, 0, 2, ..., 28, 30, the output for the second warp would be ?, 32, 32, 34, ..., 60, 62, etc. + * (The output \p thread_data in warp lane0 is undefined.) + * + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item. + T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. 
+ ScanOp scan_op) ///< [in] Binary scan operator + { + InternalWarpScan internal(temp_storage); + + T inclusive_output; + internal.InclusiveScan(input, inclusive_output, scan_op); + + internal.Update( + input, + inclusive_output, + exclusive_output, + scan_op, + Int2Type()); + } + + + /** + * \brief Computes an exclusive prefix scan using the specified binary scan functor across the calling warp. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp-wide exclusive prefix max scans within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute exclusive warp-wide prefix max scans + * int warp_id = threadIdx.x / 32; + * WarpScan(temp_storage[warp_id]).ExclusiveScan(thread_data, thread_data, INT_MIN, cub::Max()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {0, -1, 2, -3, ..., 126, -127}. + * The corresponding output \p thread_data in the first warp would be + * INT_MIN, 0, 0, 2, ..., 28, 30, the output for the second warp would be 30, 32, 32, 34, ..., 60, 62, etc. + * + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item. + T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + T initial_value, ///< [in] Initial value to seed the exclusive scan + ScanOp scan_op) ///< [in] Binary scan operator + { + InternalWarpScan internal(temp_storage); + + T inclusive_output; + internal.InclusiveScan(input, inclusive_output, scan_op); + + internal.Update( + input, + inclusive_output, + exclusive_output, + scan_op, + initial_value, + Int2Type()); + } + + + /** + * \brief Computes an exclusive prefix scan using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p output computed for warp-lane0 is undefined. Also provides every thread with the warp-wide \p warp_aggregate of all inputs. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp-wide exclusive prefix max scans within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute exclusive warp-wide prefix max scans + * int warp_aggregate; + * int warp_id = threadIdx.x / 32; + * WarpScan(temp_storage[warp_id]).ExclusiveScan(thread_data, thread_data, cub::Max(), warp_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {0, -1, 2, -3, ..., 126, -127}. + * The corresponding output \p thread_data in the first warp would be + * ?, 0, 0, 2, ..., 28, 30, the output for the second warp would be ?, 32, 32, 34, ..., 60, 62, etc. 
+ * (The output \p thread_data in warp lane0 is undefined.) Furthermore, \p warp_aggregate would be assigned \p 30 for threads in the first warp, \p 62 for threads + * in the second warp, etc. + * + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item. + T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOp scan_op, ///< [in] Binary scan operator + T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. + { + InternalWarpScan internal(temp_storage); + + T inclusive_output; + internal.InclusiveScan(input, inclusive_output, scan_op); + + internal.Update( + input, + inclusive_output, + exclusive_output, + warp_aggregate, + scan_op, + Int2Type()); + } + + + /** + * \brief Computes an exclusive prefix scan using the specified binary scan functor across the calling warp. Also provides every thread with the warp-wide \p warp_aggregate of all inputs. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp-wide exclusive prefix max scans within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute exclusive warp-wide prefix max scans + * int warp_aggregate; + * int warp_id = threadIdx.x / 32; + * WarpScan(temp_storage[warp_id]).ExclusiveScan(thread_data, thread_data, INT_MIN, cub::Max(), warp_aggregate); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {0, -1, 2, -3, ..., 126, -127}. + * The corresponding output \p thread_data in the first warp would be + * INT_MIN, 0, 0, 2, ..., 28, 30, the output for the second warp would be 30, 32, 32, 34, ..., 60, 62, etc. + * Furthermore, \p warp_aggregate would be assigned \p 30 for threads in the first warp, \p 62 for threads + * in the second warp, etc. + * + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void ExclusiveScan( + T input, ///< [in] Calling thread's input item. + T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + T initial_value, ///< [in] Initial value to seed the exclusive scan + ScanOp scan_op, ///< [in] Binary scan operator + T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. + { + InternalWarpScan internal(temp_storage); + + T inclusive_output; + internal.InclusiveScan(input, inclusive_output, scan_op); + + internal.Update( + input, + inclusive_output, + exclusive_output, + warp_aggregate, + scan_op, + initial_value, + Int2Type()); + } + + + //@} end member group + /******************************************************************//** + * \name Combination (inclusive & exclusive) prefix scans + *********************************************************************/ + //@{ + + + /** + * \brief Computes both inclusive and exclusive prefix scans using the specified binary scan functor across the calling warp. 
Because no initial value is supplied, the \p exclusive_output computed for warp-lane0 is undefined. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp-wide exclusive prefix max scans within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute exclusive warp-wide prefix max scans + * int inclusive_partial, exclusive_partial; + * WarpScan(temp_storage[warp_id]).Scan(thread_data, inclusive_partial, exclusive_partial, cub::Max()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {0, -1, 2, -3, ..., 126, -127}. + * The corresponding output \p inclusive_partial in the first warp would be + * 0, 0, 2, 2, ..., 30, 30, the output for the second warp would be 32, 32, 34, 34, ..., 62, 62, etc. + * The corresponding output \p exclusive_partial in the first warp would be + * ?, 0, 0, 2, ..., 28, 30, the output for the second warp would be ?, 32, 32, 34, ..., 60, 62, etc. + * (The output \p thread_data in warp lane0 is undefined.) + * + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void Scan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item. + T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item. + ScanOp scan_op) ///< [in] Binary scan operator + { + InternalWarpScan internal(temp_storage); + + internal.InclusiveScan(input, inclusive_output, scan_op); + + internal.Update( + input, + inclusive_output, + exclusive_output, + scan_op, + Int2Type()); + } + + + /** + * \brief Computes both inclusive and exclusive prefix scans using the specified binary scan functor across the calling warp. + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates four concurrent warp-wide prefix max scans within a block of + * 128 threads (one per each of the 32-thread warps). + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Compute inclusive warp-wide prefix max scans + * int warp_id = threadIdx.x / 32; + * int inclusive_partial, exclusive_partial; + * WarpScan(temp_storage[warp_id]).Scan(thread_data, inclusive_partial, exclusive_partial, INT_MIN, cub::Max()); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {0, -1, 2, -3, ..., 126, -127}. + * The corresponding output \p inclusive_partial in the first warp would be + * 0, 0, 2, 2, ..., 30, 30, the output for the second warp would be 32, 32, 34, 34, ..., 62, 62, etc. + * The corresponding output \p exclusive_partial in the first warp would be + * INT_MIN, 0, 0, 2, ..., 28, 30, the output for the second warp would be 30, 32, 32, 34, ..., 60, 62, etc. 
+ * + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template + __device__ __forceinline__ void Scan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item. + T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item. + T initial_value, ///< [in] Initial value to seed the exclusive scan + ScanOp scan_op) ///< [in] Binary scan operator + { + InternalWarpScan internal(temp_storage); + + internal.InclusiveScan(input, inclusive_output, scan_op); + + internal.Update( + input, + inclusive_output, + exclusive_output, + scan_op, + initial_value, + Int2Type()); + } + + + + //@} end member group + /******************************************************************//** + * \name Data exchange + *********************************************************************/ + //@{ + + /** + * \brief Broadcast the value \p input from warp-lanesrc_lane to all lanes in the warp + * + * \par + * - \smemreuse + * + * \par Snippet + * The code snippet below illustrates the warp-wide broadcasts of values from + * lanes0 in each of four warps to all other threads in those warps. + * \par + * \code + * #include + * + * __global__ void ExampleKernel(...) + * { + * // Specialize WarpScan for type int + * typedef cub::WarpScan WarpScan; + * + * // Allocate WarpScan shared memory for 4 warps + * __shared__ typename WarpScan::TempStorage temp_storage[4]; + * + * // Obtain one input item per thread + * int thread_data = ... + * + * // Broadcast from lane0 in each warp to all other threads in the warp + * int warp_id = threadIdx.x / 32; + * thread_data = WarpScan(temp_storage[warp_id]).Broadcast(thread_data, 0); + * + * \endcode + * \par + * Suppose the set of input \p thread_data across the block of threads is {0, 1, 2, 3, ..., 127}. + * The corresponding output \p thread_data will be + * {0, 0, ..., 0} in warp0, + * {32, 32, ..., 32} in warp1, + * {64, 64, ..., 64} in warp2, etc. + */ + __device__ __forceinline__ T Broadcast( + T input, ///< [in] The value to broadcast + unsigned int src_lane) ///< [in] Which warp lane is to do the broadcasting + { + return InternalWarpScan(temp_storage).Broadcast(input, src_lane); + } + + //@} end member group + +}; + +/** @} */ // end group WarpModule + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/fastertransformer/cuda/cuda_kernels.cu b/fastertransformer/cuda/cuda_kernels.cu index f4db19305..633c16c1b 100644 --- a/fastertransformer/cuda/cuda_kernels.cu +++ b/fastertransformer/cuda/cuda_kernels.cu @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + #include "fastertransformer/common.h" #include "cuda_kernels.h" @@ -168,8 +169,7 @@ void add_bias_input_layernorm(T* out, const T* input, const T* bias, const T* ga float variance = 0.0f; float local_out = 0.0f; - for(int i = tid; i < n; i += blockDim.x) - local_out += (float)(out[blockIdx.x * n + i] + input[blockIdx.x * n + i] + __ldg(&bias[i])); + local_out += (float)(out[blockIdx.x * n + tid] + input[blockIdx.x * n + tid] + __ldg(&bias[tid])); mean = blockReduceSum(local_out); if(threadIdx.x == 0) @@ -181,9 +181,8 @@ void add_bias_input_layernorm(T* out, const T* input, const T* bias, const T* ga s_variance = variance / n + 1e-6f; __syncthreads(); - for(int i = tid; i < n; i += blockDim.x) - out[blockIdx.x * n + i] = - (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[i])) + (float)(__ldg(&beta[i]))); + out[blockIdx.x * n + tid] = + (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); } template <> @@ -204,7 +203,7 @@ void add_bias_input_layernorm(half* out, const half* input, const half* bias, const half2* bias_ptr = (const half2*)bias; const half2* gamma_ptr = (const half2*)gamma; const half2* beta_ptr = (const half2*)beta; - + float local_out = 0.0f; int id = blockIdx.x * n / 2 + tid; local_out_fp2 = __half22float2(__hadd2(__hadd2(out_ptr[id], input_ptr[id]), __ldg(&bias_ptr[tid]))); @@ -230,25 +229,127 @@ void add_bias_input_layernorm(half* out, const half* input, const half* bias, out_ptr[id] = __float22half2_rn(local_out_fp2); } -template -__global__ -void broadcast_kernel(T* log_probs, T* cum_log_probs, const int batch_size, const int beam_width, const int vocab_size, const int N) + +template +__global__ +void add_bias_input_layernorm_v2(T* out, const T* __restrict input, const T* __restrict bias, + const T* __restrict gamma, const T* __restrict beta, int n) +{ + const int ite = 4; + const int tid = threadIdx.x; + const int bid = blockIdx.x; + + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + float local_out[ite]; + + float sum = 0.0f; + #pragma unroll + for(int i = 0; i < ite; i++) + { + int col_id = i * blockDim.x + tid; + int id = bid * n + col_id; + local_out[i] = (float)(out[id] + __ldg(&input[id]) + __ldg(&bias[col_id])); + sum += local_out[i]; + } + + mean = blockReduceSum(sum); + if(tid == 0) + s_mean = mean / n; + __syncthreads(); + + float var = 0.0f; + #pragma unroll + for(int i = 0; i < ite; i++) + { + float diff = local_out[i] - s_mean; + var += diff * diff; + } + + variance = blockReduceSum(var); + if(tid == 0) + s_variance = rsqrtf(variance / n + 1e-6f); + __syncthreads(); + + #pragma unroll + for(int i = 0; i < ite; i++) + { + int col_id = i * blockDim.x + tid; + int id = bid * n + col_id; + out[id] = (T)((local_out[i] - s_mean) * s_variance * (float)__ldg(&gamma[col_id]) + (float)__ldg(&beta[col_id])); + } +} + +template <> +__global__ +void add_bias_input_layernorm_v2(half* out, const half* __restrict input, const half* __restrict bias, + const half* __restrict gamma, const half* __restrict beta, int n) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int bid = tid / vocab_size; + const int ite = 4; + const int tid = threadIdx.x; + const int bid = blockIdx.x; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + half2 local_out_half2[ite]; + + half2* out_ptr = (half2*)out; + const half2* input_ptr = (const half2*)input; + const half2* bias_ptr = (const half2*)bias; + 
const half2* gamma_ptr = (const half2*)gamma; + const half2* beta_ptr = (const half2*)beta; + + // float sum = 0.0f; + half2 sum = __float2half2_rn(0.0f); + #pragma unroll + for(int i = 0; i < ite; i++) + { + int col_id = i * blockDim.x + tid; + int id = bid * n / 2 + col_id; + local_out_half2[i] = out_ptr[id] + __ldg(&input_ptr[id]) + __ldg(&bias_ptr[col_id]); + sum += local_out_half2[i]; + } + + mean = blockReduceSum((float)(sum.x + sum.y)); + if(threadIdx.x == 0) + s_mean = mean / n; + __syncthreads(); + + float var = 0.0f; + half2 s_mean_2 = __float2half2_rn(s_mean); + #pragma unroll + for(int i = 0; i < ite; i++) + { + local_out_half2[i] = local_out_half2[i] - s_mean_2; + float v1 = (float)local_out_half2[i].x; + float v2 = (float)local_out_half2[i].y; + var += v1 * v1 + v2 * v2; + } - if(tid < N) - log_probs[tid] += cum_log_probs[bid]; + variance = blockReduceSum(var); + if(threadIdx.x == 0) + s_variance = rsqrtf(variance / n + 1e-6f); + __syncthreads(); + + half2 s_var_2 = __float2half2_rn(s_variance); + #pragma unroll + for(int i = 0; i < ite; i++) + { + int col_id = i * blockDim.x + tid; + int id = bid * n / 2 + col_id; + out_ptr[id] = local_out_half2[i] * s_var_2 * __ldg(&gamma_ptr[col_id]) + __ldg(&beta_ptr[col_id]); + } } template void add_bias_act_kernelLauncher(T* out, const T* bias, int m, int n, cudaStream_t stream) { -// dim3 grid(m / 64); - dim3 grid(m / 4); + dim3 grid(ceil(m / 4.)); dim3 block(n / 4); assert(block.x <= 1024); -// dim3 block(n); add_bias_act<<>>(out, bias, m, n); } @@ -259,10 +360,12 @@ void add_bias_input_layernorm_kernelLauncher(T* out, const T* input, const T* bi dim3 grid(m); dim3 block(n); assert(n <= 1024); - add_bias_input_layernorm<<>>(out, input, bias, gamma, beta, m, n); + if(n == 768 || n == 1024) + add_bias_input_layernorm_v2<<>>(out, input, bias, gamma, beta, n); + else + add_bias_input_layernorm<<>>(out, input, bias, gamma, beta, m, n); } - template <> void add_bias_input_layernorm_kernelLauncher(half* out, const half* input, const half* bias, const half* gamma, const half* beta, int m, int n, cudaStream_t stream) @@ -270,20 +373,319 @@ void add_bias_input_layernorm_kernelLauncher(half* out, const half* input, const dim3 grid(m); dim3 block(n / 2); assert(n / 2 <= 1024); - add_bias_input_layernorm<<>>(out, input, bias, gamma, beta, m, n); + + if(m >= 512 && (n == 768 || n == 1024)) + add_bias_input_layernorm_v2<<>>(out, input, bias, gamma, beta, n); + else + add_bias_input_layernorm<<>>(out, input, bias, gamma, beta, m, n); } -void broadcast_kernelLauncher(float* log_probs, float* cum_log_probs, const int batch_size, const int beam_width, - const int vocab_size, cudaStream_t stream) +template +__global__ void update_logits_kernel(T* logits, const T* bias, const int end_id, const bool* finished, const int n) { - - int N = batch_size * beam_width * vocab_size; - dim3 block(1024); - dim3 grid((N - 1) / block.x + 1); + int bid = blockIdx.x; + bool finish = finished[bid]; + int offset = bid * n; + + float max_val = -1 * FLT_MAX; + __shared__ float s_max_val; + __shared__ float s_sum_val; + + for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + { + if(finish) + logits[offset + tid] = (tid == end_id) ? 
FLT_MAX : -1 * FLT_MAX; + else + logits[offset + tid] += bias[tid]; + max_val = max(max_val, logits[offset + tid]); + } + + max_val = blockReduceMax((float)max_val); + if(threadIdx.x == 0) + s_max_val = max_val; + __syncthreads(); + + float sum_val = 0.0f; + for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + { + logits[offset + tid] = __expf((float)logits[offset + tid] - s_max_val); + sum_val += (float)logits[offset + tid]; + } - broadcast_kernel<<>>(log_probs, cum_log_probs, batch_size, beam_width, vocab_size, N); + sum_val = blockReduceSum(sum_val); + if(threadIdx.x == 0) + s_sum_val = sum_val; + __syncthreads(); + + for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + { + logits[offset + tid] = logf((float)logits[offset + tid] / s_sum_val); + } +} + +template +__global__ void update_logits_kernel_without_softmax(T* logits, const T* bias, const int end_id, const bool* finished, const int n) +{ + int bid = blockIdx.x; + bool finish = finished[bid]; + int offset = bid * n; + + for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + { + if(finish) + logits[offset + tid] = (tid == end_id) ? FLT_MAX : -1 * FLT_MAX; + else + logits[offset + tid] += bias[tid]; + } } +template +__global__ void update_logits_kernel_without_log(T* logits, const T* bias, const int end_id, const bool* finished, const int n) +{ + int bid = blockIdx.x; + bool finish = finished[bid]; + int offset = bid * n; + + float max_val = -1 * FLT_MAX; + __shared__ float s_max_val; + __shared__ float s_sum_val; + + for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + { + if(finish) + logits[offset + tid] = (tid == end_id) ? FLT_MAX : -1 * FLT_MAX; + else + logits[offset + tid] += bias[tid]; + max_val = max(max_val, logits[offset + tid]); + } + + max_val = blockReduceMax((float)max_val); + if(threadIdx.x == 0) + s_max_val = max_val; + __syncthreads(); + + float sum_val = 0.0f; + for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + { + logits[offset + tid] = __expf((float)logits[offset + tid] - s_max_val); + sum_val += (float)logits[offset + tid]; + } + + sum_val = blockReduceSum(sum_val); + if(threadIdx.x == 0) + s_sum_val = sum_val; + __syncthreads(); + + for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + { + logits[offset + tid] = ((float)logits[offset + tid] / s_sum_val); + } +} + +template +__global__ void remove_sequence_length_padding(const T* src, T* tgt, + const int* tmp_mask_offset, + int* mask_offset, + const int n) +{ + const int tid = threadIdx.x; + const int bid = blockIdx.x; + mask_offset[bid] = tmp_mask_offset[bid]; + const int src_seq_id = bid + mask_offset[bid]; + const int tgt_seq_id = bid; + + + for(int i = tid; i < n; i += blockDim.x) + { + tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i]; + } +} + +template +void remove_sequence_length_padding_kernelLauncher(const T* src, T* tgt, + const int* tmp_mask_offset, + int* mask_offset, + const int m, const int n, cudaStream_t stream) +{ + // src: [batch_size*max_seq_len, hidden_dim] + // tgt: [valid_word_num, hidden_dim] + remove_sequence_length_padding<<>>(src, tgt, tmp_mask_offset, mask_offset, n); +} + +template +__global__ void rebuild_sequence_length_padding(const T* src, T* tgt, + const int* mask_offset, + const int n) +{ + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int tgt_seq_id = bid + mask_offset[bid]; + const int src_seq_id = bid; + + for(int i = tid; i < n; i += blockDim.x) + { + tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i]; + } +} + +template +void rebuild_sequence_length_padding_kernelLauncher(const 
T* src, T* tgt, + const int* mask_offset, const int m, + const int n, cudaStream_t stream) +{ + // src: [valid_word_num, hidden_dim] + // tgt: [batch_size*max_seq_len, hidden_dim] + rebuild_sequence_length_padding<<>>(src, tgt, mask_offset, n); +} + +__global__ void build_sequence_length_padding_offset(const int* sequence_length, + const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset) +{ + // do cumulated sum + int total_seq_len = 0; + int cum_offset = 0; + int index = 0; + for(int i = 0; i < batch_size; i++) + { + const int seq_len = sequence_length[i]; + for(int j = 0; j < seq_len; j++) + { + tmp_mask_offset[index] = cum_offset; + index++; + } + cum_offset += max_seq_len - seq_len; + total_seq_len += seq_len; + } + valid_word_num[0] = total_seq_len; +} + +void build_sequence_length_padding_offset_kernelLauncher(const int* sequence_length, + const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset, + cudaStream_t stream) +{ + build_sequence_length_padding_offset<<<1, 1, 0, stream>>>(sequence_length, + batch_size, max_seq_len, valid_word_num, tmp_mask_offset); +} + +template void rebuild_sequence_length_padding_kernelLauncher(const float* src, float* tgt, + const int* mask_offset, const int m, + const int n, cudaStream_t stream); + + +template void rebuild_sequence_length_padding_kernelLauncher(const half* src, half* tgt, + const int* mask_offset, const int m, + const int n, cudaStream_t stream); + +template void remove_sequence_length_padding_kernelLauncher(const float* src, float* tgt, + const int* tmp_mask_offset, + int* mask_offset, const int m, + const int n, cudaStream_t stream); + +template void remove_sequence_length_padding_kernelLauncher(const half* src, half* tgt, + const int* tmp_mask_offset, + int* mask_offset, const int m, + const int n, cudaStream_t stream); + +void update_logits(float* logits, const float* bias, const int end_id, const bool* finished, + const int m, const int n, cudaStream_t stream) +{ + dim3 grid(m); + dim3 block(min(n, 1024)); + /*n is the vocab_size, e.g., 30000, 7000.... vocab_size is usually very big. */ + update_logits_kernel<<>>(logits, bias, end_id, finished, n); +} + +void update_logits_without_softmax(float* logits, const float* bias, const int end_id, const bool* finished, + const int m, const int n, cudaStream_t stream) +{ + dim3 grid(m); + dim3 block(min(n, 1024)); + /*n is the vocab_size, e.g., 30000, 7000.... vocab_size is usually very big. */ + update_logits_kernel_without_softmax<<>>(logits, bias, end_id, finished, n); +} + +void update_logits_without_log(float* logits, const float* bias, const int end_id, const bool* finished, + const int m, const int n, cudaStream_t stream) +{ + dim3 grid(m); + dim3 block(min(n, 1024)); + /*n is the vocab_size, e.g., 30000, 7000.... vocab_size is usually very big. 
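Since n is typically much larger than the 1024-thread block limit, the launcher gives each row of logits its own block and the kernel strides each thread over the vocabulary in steps of blockDim.x.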
*/ + update_logits_kernel_without_log<<>>(logits, bias, end_id, finished, n); +} + +template void add_bias_act_kernelLauncher( + float* out, const float* bias, int m, int n, cudaStream_t stream); + +template void add_bias_input_layernorm_kernelLauncher( + float* out, const float* input, const float* bias, const float* gamma, const float* beta, + int m, int n, cudaStream_t stream); + +template void add_bias_act_kernelLauncher( + half* out, const half* bias, int m, int n, cudaStream_t stream); + +template void add_bias_input_layernorm_kernelLauncher( + half* out, const half* input, const half* bias, const half* gamma, const half* beta, + int m, int n, cudaStream_t stream); + +/* *********************************** Debug tools *********************************** */ + +template +__global__ +void print_abs_mean_kernel(const T* buf, uint size) +{ + float sum = 0.0f; + for(int i = 0; i < size; i++) + { + sum += abs((float)buf[i]); + // printf("[INFO] buf[%d] %f \n", i, buf[i]); + } + printf("mean: %f \n", (float) sum / (float) size); + printf("sum: %f \n", sum); +} + +template +__global__ +void print_kernel(const T* buf, uint size) +{ + for(int i = 0; i < size; i++) + { + printf("%f ", (float(buf[i]))); + } + printf("\n"); +} + +template +void print_first_k(const T* buf, uint size, cudaStream_t stream) +{ + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + print_kernel<<<1, 1, 0, stream>>>(buf, size); + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +} + +template +void print_abs_mean(const T* buf, uint size, cudaStream_t stream) +{ + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + print_abs_mean_kernel<<<1, 1, 0, stream>>>(buf, size); + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +} + +template void print_first_k(const float*, uint size, cudaStream_t); +template void print_first_k(const half*, uint size, cudaStream_t); +template void print_first_k(const int*, uint size, cudaStream_t); + +template void print_abs_mean(const float* buf, uint size, cudaStream_t stream); +template void print_abs_mean(const half* buf, uint size, cudaStream_t stream); +template void print_abs_mean(const int* buf, uint size, cudaStream_t stream); + +/* **************************** end of Debug tools *********************************** */ + +/* *************************** depreciated kernels *********************************** */ + template __global__ void topK_kernel(const T* log_probs, int* ids, const int batch_size, const int N, const int K) @@ -295,11 +697,11 @@ void topK_kernel(const T* log_probs, int* ids, const int batch_size, const int N { bool choosed = false; val = (tid < N ) ? 
(float)log_probs[ite * N + tid] : -1e20f; - + for(int kids = 0; kids < K; ++kids) { max_val = blockReduceMax(val); - + if(threadIdx.x == 0) s_max_val = max_val; __syncthreads(); @@ -351,7 +753,7 @@ void topK_kernel_2nd(const T* log_probs, int* ids, const int batch_size, const i choosed = true; } __syncthreads(); - + // simply sort the ids if(threadIdx.x == 0 && beam_index - begin_beam_index > 1){ for(int i = begin_beam_index; i < beam_index; i++){ @@ -383,36 +785,6 @@ void topK(const float* log_probs, int* ids, const int batch_size, const int beam topK_kernel_2nd<<<1, block, 0, stream>>>(log_probs, ids, batch_size, beam_width * grid.x, beam_width, N); } -template -__global__ -void update_kernel(T* log_probs, T* cum_log_probs, - int* ids, bool* finished, - int* parent_ids, int* sequence_length, - int* word_ids, int* output_ids, - const int batch_size, const int beam_width, - const int vocab_size, const int end_id, - int* finished_count) -{ - int tid = threadIdx.x; - sequence_length[tid] = finished[tid] ? sequence_length[tid] : sequence_length[tid] + 1; - - int beam_id = ids[tid]; - beam_id /= vocab_size; - int word_id = ids[tid]; - word_id %= vocab_size; - - cum_log_probs[tid] = log_probs[ids[tid]]; - sequence_length[tid] = sequence_length[beam_id]; - finished[tid] = word_id == end_id ? 1 : 0; - parent_ids[tid] = beam_id; - word_ids[tid] = word_id; - output_ids[tid] = word_id; - - // TODO use reduce sum to compute how many sentence are finished - // int fi = finished[tid] - // int total_finish = reduceSum(fi); -} - template __global__ void embedding_lookup_kernel(const T* embedding_table, const int* word_ids, const int hidden_units, T* from_tensor) @@ -421,157 +793,16 @@ __global__ void embedding_lookup_kernel(const T* embedding_table, const int* wor from_tensor[write_pos] = embedding_table[word_ids[blockIdx.x] * hidden_units + threadIdx.x]; } -void update(float* log_probs, float* cum_log_probs, - int* ids, bool* finished, - int* parent_ids, int* sequence_length, - int* word_ids, int* output_ids, - const int batch_size, const int beam_width, - const int vocab_size, cudaStream_t stream, - const int end_id, int* finished_count) -{ - - dim3 grid(1); - dim3 block(batch_size * beam_width); - - assert(block.x <= 1024); - - update_kernel<<>>(log_probs, cum_log_probs, ids, - finished, parent_ids, sequence_length, - word_ids, output_ids, batch_size, - beam_width, vocab_size, end_id, - finished_count); -} - template void embedding_lookup(const T* embedding_table, const int* word_ids, T* from_tensor, const int batch_size, const int beam_width, const int hidden_units, cudaStream_t stream) { - dim3 grid(batch_size * beam_width); - dim3 block(hidden_units); - assert(hidden_units <= 1024); - embedding_lookup_kernel<<>>(embedding_table, word_ids, hidden_units, from_tensor); -} - -template -__global__ void update_logits_kernel(T* logits, const T* bias, const int end_id, const bool* finished, const int n) -{ - int bid = blockIdx.x; - bool finish = finished[bid]; - int offset = bid * n; - - float max_val = -1 * FLT_MAX; - __shared__ float s_max_val; - __shared__ float s_sum_val; - - for(int tid = threadIdx.x; tid < n; tid += blockDim.x) - { - if(finish) - logits[offset + tid] = (tid == end_id) ? 
FLT_MAX : -1 * FLT_MAX; - else - logits[offset + tid] += bias[tid]; - max_val = max(max_val, logits[offset + tid]); - } - - max_val = blockReduceMax((float)max_val); - if(threadIdx.x == 0) - s_max_val = max_val; - __syncthreads(); - - float sum_val = 0.0f; - for(int tid = threadIdx.x; tid < n; tid += blockDim.x) - { - logits[offset + tid] = __expf((float)logits[offset + tid] - s_max_val); - sum_val += (float)logits[offset + tid]; - } - - sum_val = blockReduceSum(sum_val); - if(threadIdx.x == 0) - s_sum_val = sum_val; - __syncthreads(); - - for(int tid = threadIdx.x; tid < n; tid += blockDim.x) - { - logits[offset + tid] = logf((float)logits[offset + tid] / s_sum_val); - } + dim3 grid(batch_size * beam_width); + dim3 block(hidden_units); + assert(hidden_units <= 1024); + embedding_lookup_kernel<<>>(embedding_table, word_ids, hidden_units, from_tensor); } -void update_logits(float* logits, const float* bias, const int end_id, const bool* finished, - const int m, const int n, cudaStream_t stream) -{ - dim3 grid(m); - dim3 block(min(n, 1024)); - /*n is the vocab_size, e.g., 30000, 7000.... vocab_size is usually very big. */ - update_logits_kernel<<>>(logits, bias, end_id, finished, n); -} - -template -__global__ void init_kernel(bool* finished, int* sequence_length, int* word_ids, T* cum_log_probs, const int sentence_id, const int n, const int beam_width) -{ - int tid = threadIdx.x; - finished[tid] = false; - sequence_length[tid] = 0; - word_ids[tid] = sentence_id; - cum_log_probs[tid] = (T)(tid % beam_width == 0 ? 0.0f: -1e20f); -} - -template -__global__ void update_KV_cache_kernel( - T* key_src_cache, T* key_tgt_cache, - T* value_src_cache, T* value_tgt_cache, - const int* beam_ids, const int batch_size, const int beam_width, const int hidden_dim, const int cache_size, const int step, const int decoder_layers) -{ - int layer_id = blockIdx.x / batch_size / beam_width / step; - int batch_id = (blockIdx.x % (batch_size * beam_width * step)) / (beam_width * step); - int beam_id = (blockIdx.x % (beam_width * step)) / step; - int step_id = blockIdx.x % step; - - int hidden_id = step_id * batch_size * beam_width * hidden_dim + - beam_ids[batch_id * beam_width + beam_id] * hidden_dim; - - int tgt_hidden_id = step_id * batch_size * beam_width * hidden_dim + - batch_id * beam_width * hidden_dim + beam_id * hidden_dim; - - T* key_src_ptr = key_src_cache + layer_id * cache_size; - T* key_tgt_ptr = key_tgt_cache + layer_id * cache_size; - T* value_src_ptr = value_src_cache + layer_id * cache_size; - T* value_tgt_ptr = value_tgt_cache + layer_id * cache_size; - - - for(int tid = threadIdx.x; tid < hidden_dim; tid += blockDim.x) - { - key_tgt_ptr[tgt_hidden_id + tid] = key_src_ptr[hidden_id + tid]; - value_tgt_ptr[tgt_hidden_id + tid] = value_src_ptr[hidden_id + tid]; - } - -} -template -void update_KV_cache(T** key_cache, T** value_cache, const int* beam_ids, const int batch_size, const int beam_width, const int hidden_dim, - const int step, const int cache_size, const int decoder_layers, cudaStream_t stream) -{ - dim3 grid(decoder_layers * batch_size * beam_width * step); - dim3 block(min(1024, hidden_dim)); - - int src_id = step & 0x1; - int tgt_id = 1 - src_id; - - update_KV_cache_kernel<<>>( - key_cache[src_id], key_cache[tgt_id], - value_cache[src_id], value_cache[tgt_id], - beam_ids, batch_size, beam_width, hidden_dim, cache_size, step, decoder_layers); -} - -void init(bool* finished, int* sequence_length, int* word_ids, float* cum_log_probs, const int sentence_id, const int batch_size, - const int 
beam_width, cudaStream_t stream) -{ - dim3 grid(1); - dim3 block(min(1024, batch_size * beam_width)); - - assert(batch_size * beam_width <= 1024); - - init_kernel<<>>(finished, sequence_length, word_ids, cum_log_probs, sentence_id, batch_size * beam_width, beam_width); -} - - template __global__ void sine_position_encoder_kernel(T* output, int step, int n){ @@ -602,32 +833,12 @@ void sine_position_encoder( sine_position_encoder_kernel<<>>(output, step, n); } -template void add_bias_act_kernelLauncher( - float* out, const float* bias, int m, int n, cudaStream_t stream); - -template void add_bias_input_layernorm_kernelLauncher( - float* out, const float* input, const float* bias, const float* gamma, const float* beta, - int m, int n, cudaStream_t stream); - -template void add_bias_act_kernelLauncher( - half* out, const half* bias, int m, int n, cudaStream_t stream); - -template void add_bias_input_layernorm_kernelLauncher( - half* out, const half* input, const half* bias, const half* gamma, const half* beta, - int m, int n, cudaStream_t stream); - template void embedding_lookup(const float* embedding_table, const int* word_ids, float* from_tensor, const int batch_size, const int beam_width, const int hidden_units, cudaStream_t stream); template void embedding_lookup(const half* embedding_table, const int* word_ids, half* from_tensor, const int batch_size, const int beam_width, const int hidden_units, cudaStream_t stream); -template void update_KV_cache(float** key_cache, float** value_cache, const int* beam_ids, const int batch_size, const int beam_width, const int hidden_dim, - const int step, const int cache_size, const int decoder_layers, cudaStream_t stream); - -template void update_KV_cache(half** key_cache, half** value_cache, const int* beam_ids, const int batch_size, const int beam_width, const int hidden_dim, - const int step, const int cache_size, const int decoder_layers, cudaStream_t stream); - template void sine_position_encoder( float* output, int step, @@ -640,4 +851,6 @@ template void sine_position_encoder( int m, int n, cudaStream_t stream); +/* *************************** end of depreciated kernels *********************************** */ + }//namespace diff --git a/fastertransformer/cuda/cuda_kernels.h b/fastertransformer/cuda/cuda_kernels.h index f127e2818..b4753b314 100644 --- a/fastertransformer/cuda/cuda_kernels.h +++ b/fastertransformer/cuda/cuda_kernels.h @@ -16,9 +16,18 @@ #pragma once #include #include +#include +#include "fastertransformer/arguments.h" +#include "fastertransformer/cuda/topk_kernels.cuh" namespace fastertransformer{ +/* ********************************** common kernel *********************************** */ + +void init_kernelLauncher(bool* finished, int* sequence_length, int* word_ids, + float* cum_log_probs, const int sentence_id, + const int batch_size, const int beam_width, cudaStream_t stream); + template void add_bias_act_kernelLauncher(T* out, const T* bias, int m, int n, cudaStream_t stream); @@ -28,43 +37,125 @@ void add_bias_input_layernorm_kernelLauncher(T* out, const T* input_tensor, const T* beta, int m, int n, cudaStream_t stream); +template +void embedding_lookup_sine_position_encoding_kernel_launcher(T* from_tensor, + const T* embedding_table, + const T* position_encoding_table, + const int* word_ids, + const int batch_size, + const int hidden_units, + cudaStream_t stream); + +template +void remove_sequence_length_padding_kernelLauncher(const T* src, T* tgt, + const int* tmp_mask_offset, + int* mask_offset, const int m, + const int n, 
cudaStream_t stream); + +template +void rebuild_sequence_length_padding_kernelLauncher(const T* src, T* tgt, + const int* mask_offset, const int m, + const int n, cudaStream_t stream); + +void build_sequence_length_padding_offset_kernelLauncher(const int* sequence_length, + const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset, + cudaStream_t stream); + +/* *************************** end of common kernel *********************************** */ + + +/* ********************************** BeamSearch kernel *********************************** */ + void broadcast_kernelLauncher(float* log_probs, float* cum_log_probs, const int batch_size, const int beam_width, const int vocab_size, cudaStream_t stream); +void update_logits(float* logits, const float* bias, const int end_ids, + const bool* finished, const int m, const int n, + cudaStream_t stream); + +void update_kernelLauncher(float* log_probs, float* cum_log_probs, int* ids, + bool* finished, int* parent_ids, int* sequence_length, + int* word_ids, int* output_ids, + const int batch_size, const int beam_width, + const int vocab_size, cudaStream_t stream, + const int end_id, + int* finished_count); + +void update_kernelLauncher_v2(bool* finished, int* parent_ids, + int* sequence_length, int* word_ids, + int* output_ids, + int* finished_count, + DecodingBeamsearchArguments args, + cudaStream_t stream); + +template +void update_KV_cache_kernelLauncher(T** key_cache, T** value_cache, const int* beam_ids, + const int batch_size, const int beam_width, + const int hidden_dim, const int step, + const int cache_size, const int decoder_layers, + cudaStream_t stream); + +/* *************************** end of BeamSearch kernel *********************************** */ + +/* ********************************** Sampling kernel *********************************** */ + +size_t get_topp_sort_temp_storage_size(const float* log_probs, + const int* id_vals, + float* sorted_log_probs, + int* sorted_id_vals, + int* topp_offset_buf, + const int batch_size, + const int vocab_size); + +void topp_initialization_kernelLauncher(bool* finished, + int* sequence_length, + int* word_ids, + int* topp_id_val_buf, + int* topp_offset_buf, + DecodingSamplingArguments args, + cudaStream_t stream); + +void init_topp_id_val_kernel_kernelLauncher(int* topp_id_val_buf, + int* topp_offset_buf, + const int batch_size, + const int vocab_size, + cudaStream_t stream); + +void update_logits_without_softmax(float* logits, const float* bias, const int end_ids, + const bool* finished, const int m, const int n, + cudaStream_t stream); + +void update_logits_without_log(float* logits, const float* bias, const int end_ids, + const bool* finished, const int m, const int n, + cudaStream_t stream); + +/* *************************** end of Sampling kernel *********************************** */ + +/* *********************************** Debug tools *********************************** */ + +template +void print_first_k(const T* buf, uint size, cudaStream_t stream); + +template +void print_abs_mean(const T* buf, uint size, cudaStream_t stream); + +/* **************************** end of Debug tools *********************************** */ + +/* *************************** depreciated kernels *********************************** */ + void topK(const float* log_probs, int* ids, const int batch_size, const int beam_width, const int vocab_size, cudaStream_t stream); -void update(float* log_probs, float* cum_log_probs, int* ids, - bool* finished, int* parent_ids, int* sequence_length, - 
int* word_ids, int* output_ids, - const int batch_size, const int beam_width, - const int vocab_size, cudaStream_t stream, - const int end_id, - int* finished_count); - template void embedding_lookup(const T* embedding_table, const int* word_ids, T* from_tensor, const int batch_size, const int beam_width, const int hidden_units, cudaStream_t stream); -void update_logits(float* logits, const float* bias, const int end_ids, - const bool* finished, const int m, const int n, - cudaStream_t stream); - -void init(bool* finished, int* sequence_length, int* word_ids, - float* cum_log_probs, const int sentence_id, - const int batch_size, const int beam_width, cudaStream_t stream); - -template -void update_KV_cache(T** key_cache, T** value_cache, const int* beam_ids, - const int batch_size, const int beam_width, - const int hidden_dim, const int step, - const int cache_size, const int decoder_layers, - cudaStream_t stream); - template void sine_position_encoder(T* output, int step, int m, int n, cudaStream_t stream); +/* ******************** end of depreciated kernels *********************************** */ + }//namespace fastertransformer diff --git a/fastertransformer/cuda/decoding_kernel_check.cpp b/fastertransformer/cuda/decoding_kernel_check.cpp index 4e6a0ed50..ee26f86db 100644 --- a/fastertransformer/cuda/decoding_kernel_check.cpp +++ b/fastertransformer/cuda/decoding_kernel_check.cpp @@ -30,7 +30,7 @@ void init_kernel_check(bool *d_finished, int *d_sequence_length, int *d_word_ids int *h_word_ids = new int[batch_size * beam_width]; float *h_cum_log_probs = new float[batch_size * beam_width]; - init(d_finished, d_sequence_length, d_word_ids, d_cum_log_probs, + init_kernelLauncher(d_finished, d_sequence_length, d_word_ids, d_cum_log_probs, sentence_id, batch_size, beam_width, stream); cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); @@ -346,7 +346,7 @@ void update_kernel_check(float *log_probs, float *cum_log_probs, int *ids, bool check_cuda_error(cudaMemcpy(h_output_ids, output_ids, sizeof(int) * batch_size * beam_width, cudaMemcpyDeviceToHost)); // compute on GPU and copy to GPU output - update(log_probs, cum_log_probs, ids, finished, parent_ids, sequence_length, word_ids, output_ids, + update_kernelLauncher(log_probs, cum_log_probs, ids, finished, parent_ids, sequence_length, word_ids, output_ids, batch_size, beam_width, vocab_size, stream, end_id, finished_count); cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); diff --git a/fastertransformer/cuda/decoding_kernel_check.h b/fastertransformer/cuda/decoding_kernel_check.h index 1e8e903c0..6308a11ef 100644 --- a/fastertransformer/cuda/decoding_kernel_check.h +++ b/fastertransformer/cuda/decoding_kernel_check.h @@ -26,48 +26,6 @@ namespace fastertransformer{ void init_kernel_check(bool* d_finished, int* d_sequence_length, int* d_word_ids, float* d_cum_log_probs, const int sentence_id, const int batch_size, const int beam_width, cudaStream_t stream); -template -void embedding_lookup_kernel_check(const T* embedding_table, const int* word_ids, T* from_tensor, const int batch_size, const int beam_width, - const int hidden_units, const int vocab_size, cudaStream_t stream){ - - printf("[INFO] decoding embedding_lookup check. 
\n"); - - T *h_embedding_table = new T[vocab_size * hidden_units]; - int *h_word_ids = new int[batch_size * beam_width]; - T *h_from_tensor = new T[batch_size * beam_width * hidden_units]; - - embedding_lookup(embedding_table, word_ids, from_tensor, batch_size, beam_width, hidden_units, stream); - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); - check_cuda_error(cudaMemcpy(h_embedding_table, embedding_table, sizeof(T) * vocab_size * hidden_units, cudaMemcpyDeviceToHost)); - check_cuda_error(cudaMemcpy(h_word_ids, word_ids, sizeof(int) * batch_size * beam_width, cudaMemcpyDeviceToHost)); - check_cuda_error(cudaMemcpy(h_from_tensor, from_tensor, sizeof(T) * batch_size * beam_width * hidden_units, cudaMemcpyDeviceToHost)); - - T *h_from_tensor_cpu = new T[batch_size * beam_width * hidden_units]; - - for(int i = 0; i < batch_size * beam_width; i++){ - const int row_id = h_word_ids[i]; - for(int j = 0; j < hidden_units; j++){ - h_from_tensor_cpu[i * hidden_units + j] = h_embedding_table[row_id * hidden_units + j]; - } - } - - for(int i = 0; i < batch_size * beam_width * hidden_units; i++){ - float diff = (float)(h_from_tensor_cpu[i] - h_from_tensor[i]); - if(diff < 0) diff = diff * -1; - if(diff > 1e-6){ - printf("[ERROR] embedding lookup fail with difference %f. \n", diff); - exit(-1); - } - } - - delete [] h_from_tensor_cpu; - delete [] h_embedding_table; - delete [] h_word_ids; - delete [] h_from_tensor; - printf("[INFO] decoding embedding_lookup check finish. \n"); -} - void update_logits_kernel_check(float* logits, const float* bias, const int end_id, const bool* finished, const int m, const int n, cudaStream_t stream); void broadcast_kernel_check(float* log_probs, float* cum_log_probs, const int batch_size, const int beam_width, @@ -109,7 +67,7 @@ void update_KV_cache_kernel_check(T** key_cache, T** value_cache, const int* bea check_cuda_error(cudaMemcpy(h_beam_ids, beam_ids, sizeof(int) * batch_size * beam_width, cudaMemcpyDeviceToHost)); // compute on GPU and copy the result to CPU - update_KV_cache(key_cache, value_cache, beam_ids, batch_size, beam_width, hidden_dim, step, cache_size, decoder_layers, stream); + update_KV_cache_kernelLauncher(key_cache, value_cache, beam_ids, batch_size, beam_width, hidden_dim, step, cache_size, decoder_layers, stream); cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); check_cuda_error(cudaMemcpy(h_key_cache_tgt_after_update, key_cache[tgt_id], sizeof(T) * cache_size * decoder_layers, cudaMemcpyDeviceToHost)); diff --git a/fastertransformer/cuda/decoding_kernels.cu b/fastertransformer/cuda/decoding_kernels.cu new file mode 100644 index 000000000..1753a72f9 --- /dev/null +++ b/fastertransformer/cuda/decoding_kernels.cu @@ -0,0 +1,443 @@ +/* +* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +#include "fastertransformer/common.h" + +#include "cuda_kernels.h" +#include "cub/cub.cuh" +#include +#include +#include +#include +#include + +namespace fastertransformer +{ + /* ********************************** common kernel *********************************** */ + + template + __global__ void init_kernel(bool* finished, + int* sequence_length, + int* word_ids, + T* cum_log_probs, + const int sentence_id, + const int n, + const int beam_width) + { + int tid = threadIdx.x; + finished[tid] = false; + sequence_length[tid] = 0; + word_ids[tid] = sentence_id; + cum_log_probs[tid] = (T)(tid % beam_width == 0 ? 0.0f: -1e20f); + } + + void init_kernelLauncher(bool* finished, + int* sequence_length, + int* word_ids, + float* cum_log_probs, + const int sentence_id, + const int batch_size, + const int beam_width, + cudaStream_t stream) + { + dim3 grid(1); + dim3 block(min(1024, batch_size * beam_width)); + assert(batch_size * beam_width <= 1024); + + init_kernel<<>>(finished, + sequence_length, + word_ids, + cum_log_probs, + sentence_id, + batch_size * beam_width, + beam_width); + } + + template + __global__ void embedding_lookup_sine_position_encoding_kernel(T* from_tensor, + const T* embedding_table, + const T* position_encoding, + const int* word_ids, + const int hidden_units) + { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int write_pos = tid + bid * blockDim.x; + // 1. lookup the table + // 2. multiply hidden_dim**0.5 + // 3. add the position encoding + from_tensor[write_pos] = embedding_table[word_ids[bid] * hidden_units + tid] * + (T)sqrtf(float(hidden_units)) + position_encoding[tid]; + } + + template + void embedding_lookup_sine_position_encoding_kernel_launcher(T* from_tensor, + const T* embedding_table, + const T* position_encoding, + const int* word_ids, + const int batch_size, + const int hidden_units, + cudaStream_t stream) + { + assert(hidden_units <= 1024); + dim3 grid(batch_size); + dim3 block(hidden_units); + embedding_lookup_sine_position_encoding_kernel<<>>(from_tensor, + embedding_table, + position_encoding, + word_ids, + hidden_units); + } + + /* *************************** end of common kernel *********************************** */ + + /* ********************************** BeamSearch kernel *********************************** */ + + template + __global__ + void broadcast_kernel(T* log_probs, + T* cum_log_probs, + const int vocab_size, + const int N) + { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int bid = tid / vocab_size; + + if(tid < N) + log_probs[tid] += cum_log_probs[bid]; +} + + void broadcast_kernelLauncher(float* log_probs, + float* cum_log_probs, + const int batch_size, + const int beam_width, + const int vocab_size, + cudaStream_t stream) + { + + int N = batch_size * beam_width * vocab_size; + dim3 block(1024); + dim3 grid((N - 1) / block.x + 1); + + broadcast_kernel<<>>(log_probs, cum_log_probs, vocab_size, N); + } + + template + __global__ + void update_kernel(T* log_probs, T* cum_log_probs, + int* ids, bool* finished, + int* parent_ids, int* sequence_length, + int* word_ids, int* output_ids, + const int batch_size, const int beam_width, + const int vocab_size, const int end_id, + int* finished_count) + { + int tid = threadIdx.x; + sequence_length[tid] = finished[tid] ? 
sequence_length[tid] : sequence_length[tid] + 1; + + int beam_id = word_ids[tid] / vocab_size; + int word_id = word_ids[tid] % vocab_size; + + cum_log_probs[tid] = log_probs[word_ids[tid]]; + sequence_length[tid] = sequence_length[beam_id]; + finished[tid] = word_id == end_id ? 1 : 0; + parent_ids[tid] = beam_id; + word_ids[tid] = word_id; + output_ids[tid] = word_id; + } + + void update_kernelLauncher(float* log_probs, float* cum_log_probs, + int* ids, bool* finished, + int* parent_ids, int* sequence_length, + int* word_ids, int* output_ids, + const int batch_size, const int beam_width, + const int vocab_size, cudaStream_t stream, + const int end_id, int* finished_count) + { + dim3 grid(1); + dim3 block(batch_size * beam_width); + + assert(block.x <= 1024); + + update_kernel<<>>(log_probs, cum_log_probs, ids, + finished, parent_ids, sequence_length, + word_ids, output_ids, batch_size, + beam_width, vocab_size, end_id, + finished_count); + } + + template + __global__ + void update_kernel_v2(bool* finished, int* parent_ids, + int* sequence_length, + int* word_ids, int* output_ids, + const int vocab_size, const int end_id, + int* finished_count) + { + int tid = threadIdx.x; + sequence_length[tid] = finished[tid] ? sequence_length[tid] : sequence_length[tid] + 1; + + int beam_id = word_ids[tid] / vocab_size; + int word_id = word_ids[tid] % vocab_size; + + sequence_length[tid] = sequence_length[beam_id]; + finished[tid] = word_id == end_id ? 1 : 0; + parent_ids[tid] = beam_id; + word_ids[tid] = word_id; + output_ids[tid] = word_id; + } + + void update_kernelLauncher_v2(bool* finished, int* parent_ids, + int* sequence_length, int* word_ids, + int* output_ids, + int* finished_count, + DecodingBeamsearchArguments args, + cudaStream_t stream) + { + dim3 grid(1); + dim3 block(args.batch_size_ * args.beam_width_); + assert(block.x <= 1024); + + update_kernel_v2<<>>(finished, parent_ids, + sequence_length, word_ids, + output_ids, args.vocab_size_, + args.end_id_, finished_count); + } + + template + __global__ void update_KV_cache_kernel(const T* __restrict key_src_cache, + T* key_tgt_cache, + const T* __restrict value_src_cache, + T* value_tgt_cache, + const int* beam_ids, + const int batch_size, + const int beam_width, + const int hidden_dim, + const int cache_size, + const int step, + const int decoder_layers) + { + int layer_id = blockIdx.x / batch_size / beam_width / step; + int batch_id = (blockIdx.x % (batch_size * beam_width * step)) / (beam_width * step); + int beam_id = (blockIdx.x % (beam_width * step)) / step; + int step_id = blockIdx.x % step; + + int hidden_id = step_id * batch_size * beam_width * hidden_dim + + beam_ids[batch_id * beam_width + beam_id] * hidden_dim; + + int tgt_hidden_id = step_id * batch_size * beam_width * hidden_dim + + batch_id * beam_width * hidden_dim + beam_id * hidden_dim; + + const T* key_src_ptr = key_src_cache + layer_id * cache_size; + T* key_tgt_ptr = key_tgt_cache + layer_id * cache_size; + const T* value_src_ptr = value_src_cache + layer_id * cache_size; + T* value_tgt_ptr = value_tgt_cache + layer_id * cache_size; + + + for(int tid = threadIdx.x; tid < hidden_dim; tid += blockDim.x) + { + key_tgt_ptr[tgt_hidden_id + tid] = key_src_ptr[hidden_id + tid]; + value_tgt_ptr[tgt_hidden_id + tid] = value_src_ptr[hidden_id + tid]; + } + + } + + template <> + __global__ void update_KV_cache_kernel(const half* __restrict key_src_cache, + half* key_tgt_cache, + const half* __restrict value_src_cache, + half* value_tgt_cache, + const int* beam_ids, + const int 
batch_size, + const int beam_width, + const int hidden_dim, + const int cache_size, + const int step, + const int decoder_layers) + { + int layer_id = blockIdx.x / batch_size / beam_width / step; + int batch_id = (blockIdx.x % (batch_size * beam_width * step)) / (beam_width * step); + int beam_id = (blockIdx.x % (beam_width * step)) / step; + int step_id = blockIdx.x % step; + + int hidden_id = (step_id * batch_size * beam_width * hidden_dim + + beam_ids[batch_id * beam_width + beam_id] * hidden_dim) / 2; + + int tgt_hidden_id = (step_id * batch_size * beam_width * hidden_dim + + batch_id * beam_width * hidden_dim + beam_id * hidden_dim) / 2; + + const half2* key_src_ptr = (const half2*)key_src_cache + layer_id * cache_size / 2; + half2* key_tgt_ptr = (half2*)key_tgt_cache + layer_id * cache_size / 2; + const half2* value_src_ptr = (const half2*)value_src_cache + layer_id * cache_size / 2; + half2* value_tgt_ptr = (half2*)value_tgt_cache + layer_id * cache_size / 2; + + for(int tid = threadIdx.x; tid < hidden_dim / 2; tid += blockDim.x) + { + key_tgt_ptr[tgt_hidden_id + tid] = key_src_ptr[hidden_id + tid]; + value_tgt_ptr[tgt_hidden_id + tid] = value_src_ptr[hidden_id + tid]; + } + + } + + template + void update_KV_cache_kernelLauncher(T** key_cache, + T** value_cache, + const int* beam_ids, + const int batch_size, + const int beam_width, + const int hidden_dim, + const int step, + const int cache_size, + const int decoder_layers, + cudaStream_t stream) + { + dim3 grid(decoder_layers * batch_size * beam_width * step); + dim3 block(min(1024, hidden_dim)); + block.x = block.x / (4 / sizeof(T)); + + int src_id = step & 0x1; + int tgt_id = 1 - src_id; + + update_KV_cache_kernel<<>>( + key_cache[src_id], key_cache[tgt_id], + value_cache[src_id], value_cache[tgt_id], + beam_ids, batch_size, beam_width, hidden_dim, cache_size, step, decoder_layers); + } + + /* *************************** end of BeamSearch kernel *********************************** */ + + /* ********************************** Sampling kernel *********************************** */ + __global__ void topp_initialization_kernel(bool* finished, + int* sequence_length, + int* word_ids, + int* topp_id_val_buf, + int* topp_offset_buf, + const int batch_size, + const int vocab_size, + const int start_id) + { + int tid = threadIdx.x; + int bid = blockIdx.x; + + if(bid == 0) + { + for(int i = tid; i < batch_size + 1; i+= blockDim.x) + { + topp_offset_buf[i] = i * vocab_size; + } + + for(int i = tid; i < batch_size; i+= blockDim.x) + { + finished[i] = false; + sequence_length[i] = 0; + word_ids[i] = start_id; + } + } + + int index = tid + bid * blockDim.x; + while(index < batch_size * vocab_size) + { + topp_id_val_buf[index] = index % vocab_size; + index += blockDim.x * gridDim.x; + } + } + + void topp_initialization_kernelLauncher(bool* finished, + int* sequence_length, + int* word_ids, + int* topp_id_val_buf, + int* topp_offset_buf, + DecodingSamplingArguments args, + cudaStream_t stream) + { + topp_initialization_kernel<<<32, 512, 0, stream>>>(finished, + sequence_length, + word_ids, + topp_id_val_buf, + topp_offset_buf, + args.batch_size_, + args.vocab_size_, + args.start_id_); + } + + size_t get_topp_sort_temp_storage_size(const float* log_probs, + const int* id_vals, + float* sorted_log_probs, + int* sorted_id_vals, + int* topp_offset_buf, + const int batch_size, + const int vocab_size) + { + void *d_temp_storage = NULL; + size_t temp_storage_bytes = 0; + + cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, + 
temp_storage_bytes, + log_probs, + sorted_log_probs, + id_vals, + sorted_id_vals, + vocab_size * batch_size, + batch_size, + topp_offset_buf, topp_offset_buf + 1); + return temp_storage_bytes; + } + /* *************************** end of Sampling kernel *********************************** */ + + /* ********************************** Instantiation *********************************** */ + template + void embedding_lookup_sine_position_encoding_kernel_launcher(float* from_tensor, + const float* embedding_table, + const float* position_encoding, + const int* word_ids, + const int batch_size, + const int hidden_units, + cudaStream_t stream); + + template + void embedding_lookup_sine_position_encoding_kernel_launcher(half* from_tensor, + const half* embedding_table, + const half* position_encoding, + const int* word_ids, + const int batch_size, + const int hidden_units, + cudaStream_t stream); + + template void update_KV_cache_kernelLauncher(float** key_cache, + float** value_cache, + const int* beam_ids, + const int batch_size, + const int beam_width, + const int hidden_dim, + const int step, + const int cache_size, + const int decoder_layers, + cudaStream_t stream); + + template void update_KV_cache_kernelLauncher(half** key_cache, + half** value_cache, + const int* beam_ids, + const int batch_size, + const int beam_width, + const int hidden_dim, + const int step, + const int cache_size, + const int decoder_layers, + cudaStream_t stream); + + /* *************************** end of Instantiation *********************************** */ + +} // end of name space fastertransformer \ No newline at end of file diff --git a/fastertransformer/cuda/multi_head_attention.h b/fastertransformer/cuda/multi_head_attention.h index 983cbb146..975b0f105 100644 --- a/fastertransformer/cuda/multi_head_attention.h +++ b/fastertransformer/cuda/multi_head_attention.h @@ -33,6 +33,9 @@ class MultiHeadInitParam{ AttentionWeight self_attention; const T* attr_mask; T* attr_out; + + const int* sequence_id_offset; + int valid_word_num; cublasHandle_t cublas_handle; cudaStream_t stream; MultiHeadInitParam(){ @@ -41,6 +44,7 @@ class MultiHeadInitParam{ attr_mask = nullptr; attr_out = nullptr; cublas_handle = nullptr; + sequence_id_offset = nullptr; stream = 0; } }; diff --git a/fastertransformer/cuda/online_softmax_beamsearch_kernels.cu b/fastertransformer/cuda/online_softmax_beamsearch_kernels.cu new file mode 100644 index 000000000..74727ac19 --- /dev/null +++ b/fastertransformer/cuda/online_softmax_beamsearch_kernels.cu @@ -0,0 +1,536 @@ +/* +* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +#include "fastertransformer/cuda/topk_kernels.cuh" +#include "cub/cub.cuh" + +namespace fastertransformer +{ + +template +__launch_bounds__(THREADBLOCK_SIZE) +__global__ +void batch_topK_kernel(const int* __restrict topk_tmp_id_buf, + const T* __restrict topk_tmp_val_buf, + int* __restrict id_buf, + T* __restrict val_buf) +{ + int thread_id = threadIdx.x; + int block_id = blockIdx.x; + TopK partial; + if (thread_id == 0) + { + for(int i = 0; i < MAX_K; ++i) + { + partial.p[i] = -1; + partial.u[i] = -FLT_MAX; + } + + int index = block_id * MAX_K * MAX_K; + for(int i = 0; i < MAX_K * MAX_K; i++) + { + partial.insert( (T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]); + } + + index = block_id * MAX_K; + for(int i = 0; i < MAX_K; i++) + { + id_buf[index + i] = partial.p[i]; + val_buf[index + i] = partial.u[i]; + } + } +} + + +template +__launch_bounds__(THREADBLOCK_SIZE) +__global__ void batch_topk_kernel( + const int * __restrict x, + const T * __restrict y, + int * __restrict z, + T * __restrict v, + int V, + int K, + T diversity_rate) +{ + int thread_id = threadIdx.x; + int vector_id = blockIdx.x; + + // reposition x, y to data for the current vector + x += vector_id * V; + y += vector_id * V; + + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + + __shared__ typename BlockReduce::TempStorage temp_storage; + + TopK partial; + for(int i = 0; i < MAX_K; ++i) + { + partial.p[i] = -1; + partial.u[i] = -FLT_MAX; + } + for(int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) + { + int i = elem_id % K; + T elem = y[elem_id] + diversity_rate * (T) i; + int elem_idx = elem_id; //x[elem_id]; + partial.insert(elem, elem_idx); + } + + TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); + + if (thread_id == 0) + { + z += vector_id * K; + v += vector_id * K; + + for(int i = 0; i < MAX_K; ++i) + { + if (i < K) + { + z[i] = x[total.p[i]]; + v[i] = y[total.p[i]]; + } + } + } +} + +struct __align__(8) MD +{ + float m; + float d; +}; + +__device__ __forceinline__ MD reduce_md_op(MD a, MD b) +{ + bool a_bigger = (a.m > b.m); + MD bigger_m = a_bigger ? a : b; + MD smaller_m = a_bigger ? b : a; + MD res; + res.d = bigger_m.d + smaller_m.d * __expf(smaller_m.m - bigger_m.m); + res.m = bigger_m.m; + return res; +} + +template +struct TopKMD +{ + MD md; + TopK topk; +}; + +template +__device__ __forceinline__ TopKMD reduce_topk_md_op(const TopKMD& a, const TopKMD& b) +{ + TopKMD res; + res.md = reduce_md_op(a.md, b.md); + res.topk = reduce_topk_op(a.topk, b.topk); + return res; +} + +template +__launch_bounds__(THREADBLOCK_SIZE) +__global__ void beam_online_softmax_topk_kernel( + const T * __restrict x, + const float * __restrict b, + const T * __restrict c, + const bool * __restrict finished, + int * __restrict z, + T * __restrict v, + int V, + int K, + int E) +{ + int thread_id = threadIdx.x; + int vector_id = blockIdx.x; + + // reposition y to data for the current vector + x += vector_id * V; + + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + TopKMD partial; + bool finish = finished[vector_id]; + for(int i = 0; i < MAX_K; ++i) + { + partial.topk.p[i] = -1; + partial.topk.u[i] = -FLT_MAX; + } + partial.md.m = -FLT_MAX; + partial.md.d = 0.0F; + + if (finish) + { + for(int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) + { + float elem = (elem_id == E) ? 
FLT_MAX : -FLT_MAX; + MD new_elem{elem, 1.0F}; + partial.md = reduce_md_op(partial.md, new_elem); + partial.topk.insert(elem, elem_id); + //if (elem_id > THREADBLOCK_SIZE * MAX_K && (elem_id == E)) break; + } + } + else + { + for(int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) + { + float elem = x[elem_id] + b[elem_id]; + MD new_elem{elem, 1.0F}; + partial.md = reduce_md_op(partial.md, new_elem); + partial.topk.insert(elem, elem_id); + } + } + + TopKMD total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op); + + if (thread_id == 0) + { + z += vector_id * K; + v += vector_id * K; + c += vector_id; + + //float d_total_inverse = __fdividef(1.0F, total.md.d); + float d_total_log = logf(total.md.d); + for(int i = 0; i < MAX_K; ++i) + { + //float val = __expf(total.topk.u[i] - total.md.m) * d_total_inverse; + float val = total.topk.u[i] - total.md.m - d_total_log; + if (i < K) + { + z[i] = total.topk.p[i] + vector_id * V; // faster transformer needs absolute id + v[i] = val + c[0]; + } + } + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) +__global__ void beam_online_softmax_topk_stage1_kernel( + const T * __restrict x, + const float * __restrict b, + const bool * __restrict finished, + float * __restrict t, + int V, + int K, + int E) +{ + int thread_id = threadIdx.x; + int vector_id = blockIdx.x; + + const int PACKED_TOP_KMD_SIZE = 2 * MAX_K + 2; + + // one will have multiple sections per V + const int v_local = (V + gridDim.y - 1) / gridDim.y; + const int section_start = v_local * blockIdx.y; + int section_end = section_start + v_local; + section_end = (section_end > V)? V : section_end; + + // reposition x to data for the current vector + x += vector_id * V; + + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ float buf_s[PACKED_TOP_KMD_SIZE]; // save intermediate result + + TopKMD partial; + bool finish = finished[vector_id]; + for(int i = 0; i < MAX_K; ++i) + { + partial.topk.p[i] = -1; + partial.topk.u[i] = -FLT_MAX; + } + partial.md.m = -FLT_MAX; + partial.md.d = 0.0F; + + if (finish) + { + for(int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) + { + float elem = (elem_id == E) ? 
FLT_MAX : -FLT_MAX; + MD new_elem{elem, 1.0F}; + partial.md = reduce_md_op(partial.md, new_elem); + partial.topk.insert(elem, elem_id); + } + } + else + { + for(int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) + { + T elem = x[elem_id] + b[elem_id]; + MD new_elem{elem, 1.0F}; + partial.md = reduce_md_op(partial.md, new_elem); + partial.topk.insert(elem, elem_id); + } + } + + TopKMD total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op); + + if (thread_id == 0) + { + for (int i = 0; i < K; i++) + { + reinterpret_cast(buf_s)[i] = total.topk.p[i] + vector_id * V; // faster transformer needs absolute id + buf_s[MAX_K + i] = total.topk.u[i]; + } + buf_s[2 * MAX_K] = total.md.d; + buf_s[2 * MAX_K + 1] = total.md.m; + } + __syncthreads(); + if (threadIdx.x < PACKED_TOP_KMD_SIZE) + { + t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + threadIdx.x] = buf_s[threadIdx.x]; + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) +__global__ void beam_online_softmax_topk_stage2_kernel( + const float * __restrict x, + const T * __restrict c, + int * __restrict z, + T * __restrict v, + int K, + int parts_per_beam) +{ + const int vector_id = blockIdx.x; + const int thread_id = threadIdx.x; + const int PACKED_TOP_KMD_SIZE = 2 * MAX_K + 2; + + extern __shared__ char buf_s_[]; // intermediate result + float * buf_s = reinterpret_cast(buf_s_); + //__shared__ float buf_s[PACKED_TOP_KMD_SIZE * THREADBLOCK_SIZE]; // intermediate result + + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + x += vector_id * PACKED_TOP_KMD_SIZE * parts_per_beam; + + TopKMD partial; + for(int i = 0; i < MAX_K; ++i) + { + partial.topk.p[i] = -1; + partial.topk.u[i] = -FLT_MAX; + } + partial.md.m = -FLT_MAX; + partial.md.d = 0.0F; + + // load and unpack into registers through smem + for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * parts_per_beam; idx += THREADBLOCK_SIZE) + { + buf_s[idx] = x[idx]; + } + __syncthreads(); + + if (threadIdx.x < parts_per_beam) + { + float * b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE; + for (int i = 0; i < K; i++) + { + partial.topk.p[i] = reinterpret_cast(b_s)[i]; + partial.topk.u[i] = b_s[MAX_K + i]; + } + partial.md.d = b_s[2 * MAX_K]; + partial.md.m = b_s[2 * MAX_K + 1]; + } + __syncthreads(); + + TopKMD total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op); + + if (thread_id == 0) + { + z += vector_id * K; + v += vector_id * K; + c += vector_id; + + float d_total_log = logf(total.md.d); + for(int i = 0; i < MAX_K; ++i) + { + float val = total.topk.u[i] - total.md.m - d_total_log; + if (i < K) + { + z[i] = total.topk.p[i]; + v[i] = val + c[0]; + } + } + } +} + +template +void beam_online_softmax_topk_stage2_kernelLauncher( + const float * temp_storage, + const T * cum_log_probs, + int * ids, + T * vals, + int batch_size, + int beam_width, + int parts_per_beam, + cudaStream_t stream) +{ + // might rewrite beam_online_softmax_topk_stage2_kernel no to depend on constant block size + // in oreder to reduce compilation time + int smem_stage2_size = parts_per_beam * (2 * MAX_K + 2) * sizeof(T); + + if (parts_per_beam <= 32) + { + beam_online_softmax_topk_stage2_kernel + <<>> + (temp_storage, cum_log_probs, ids, vals, + beam_width, parts_per_beam); + return; + } + if (parts_per_beam <= 64) + { + beam_online_softmax_topk_stage2_kernel + <<>> + (temp_storage, cum_log_probs, ids, vals, + beam_width, parts_per_beam); + return; + 
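/* Note: in stage 2 every thread with threadIdx.x < parts_per_beam unpacks one packed partial result before the block-wide reduction, so the threadblock must hold at least parts_per_beam threads; the 32/64/128 template dispatch covers the values of voc_parts produced by the caller, which clamps the split to 128 parts. */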
} + if (parts_per_beam <= 128) + { + beam_online_softmax_topk_stage2_kernel + <<>> + (temp_storage, cum_log_probs, ids, vals, + beam_width, parts_per_beam); + return; + } + assert(0); +} + +template +void topK_softMax_kernelLauncher(const T* log_probs, + const float* bias, + const bool* finished, + T* cum_log_probs, + int* ids, + void* temp_storage, + const int temp_storage_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int end_id, + T diversity_rate, + cudaStream_t stream) +{ + const int items_per_thread = 1; + const int block_sz = SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE; + + assert(temp_storage_size % 2 == 0); + assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width); + + int* topk_tmp_id_buf = reinterpret_cast(temp_storage); + T* topk_tmp_val_buf = reinterpret_cast(topk_tmp_id_buf + batch_size * beam_width * beam_width); + float* tmp_buffer = reinterpret_cast(topk_tmp_val_buf + batch_size * beam_width * beam_width); + +#ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX + int voc_parts = 4; + if (batch_size * beam_width < 256) + { + voc_parts = (256 + batch_size * beam_width - 1) / (batch_size * beam_width); + voc_parts = std::min(128, voc_parts); // we implment up to 128 + } + dim3 grid(batch_size * beam_width, voc_parts); + beam_online_softmax_topk_stage1_kernel + <<>> + (log_probs, bias, finished, tmp_buffer, + vocab_size, beam_width, end_id); +#endif + if (beam_width > 1) + { +#ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX + beam_online_softmax_topk_stage2_kernelLauncher + (tmp_buffer, cum_log_probs, topk_tmp_id_buf, topk_tmp_val_buf, + batch_size, beam_width, voc_parts, stream); +#else + beam_online_softmax_topk_kernel + <<>> + (log_probs, bias, cum_log_probs, finished, topk_tmp_id_buf, + topk_tmp_val_buf, vocab_size, beam_width, end_id); +#endif +#if 0 + // wrong result with diversity_rate != 0.f + batch_topK_kernel<<>> + (topk_tmp_id_buf, topk_tmp_val_buf, ids, cum_log_probs); +#else + batch_topk_kernel<<>> + (topk_tmp_id_buf, topk_tmp_val_buf, + ids, cum_log_probs, beam_width * beam_width, beam_width, diversity_rate); +#endif + } + else + { +#ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX + beam_online_softmax_topk_stage2_kernelLauncher + (tmp_buffer, cum_log_probs, ids, cum_log_probs, + batch_size, beam_width, voc_parts, stream); +#else + beam_online_softmax_topk_kernel + <<>> + (log_probs, bias, cum_log_probs, finished, ids, + cum_log_probs, vocab_size, beam_width, end_id); +#endif + } +} + +#define CASE_K(K) \ + case K : \ + topK_softMax_kernelLauncher \ + (log_probs, bias, finished, cum_log_probs, ids, temp_storage, temp_storage_size, \ + batch_size, beam_width, vocab_size, end_id, diversity_rate, stream); \ + break; \ + +template +void topK_softMax(const T* log_probs, + const float* bias, + const bool* finished, + T* cum_log_probs, + int* ids, + void* temp_storage, + DecodingBeamsearchArguments args, + cudaStream_t stream) +{ + const int temp_storage_size = args.temp_storage_size_; + const int batch_size = args.batch_size_; + const int beam_width = args.beam_width_; + const int vocab_size = args.vocab_size_; + const int end_id = args.end_id_; + const T diversity_rate = args.beam_search_diversity_rate_; + + switch(beam_width) + { + CASE_K(1); + CASE_K(2); + CASE_K(4); + default : + printf("[ERROR] Topk kernel does not support beamwidth = %d \n", beam_width); + exit(0); + break; + } +} +#undef CASE_K + +template void topK_softMax(const float* log_probs, + const float* bias, + const bool* finished, + float* cum_log_probs, + int* ids, + void * tmp_storage, + 
DecodingBeamsearchArguments args, + cudaStream_t stream); +} // end of namespace fastertransformer \ No newline at end of file diff --git a/fastertransformer/cuda/open_attention.cu b/fastertransformer/cuda/open_attention.cu index 4f2f3531e..293e86085 100644 --- a/fastertransformer/cuda/open_attention.cu +++ b/fastertransformer/cuda/open_attention.cu @@ -1,21 +1,21 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ /** - * Open sourced multi-head attention - **/ +* Open sourced multi-head attention +**/ #include "fastertransformer/allocator.h" #include "fastertransformer/cuda/multi_head_attention.h" @@ -27,14 +27,15 @@ namespace fastertransformer{ namespace cuda{ /** - * Multi-head attetion open sourced - */ +* Multi-head attetion open sourced +*/ #define FINAL_MASK 0xffffffff template __inline__ __device__ T warpReduceSum(T val) { + #pragma unroll for(int mask = 16; mask > 0; mask >>= 1) val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); return val; @@ -66,6 +67,7 @@ template __inline__ __device__ T warpReduceMax(T val) { + #pragma unroll for(int mask = 16; mask > 0; mask >>= 1) val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); return val; @@ -172,7 +174,6 @@ void add_QKV_bias(half* Q, const half* bias_Q, half* K, const half* bias_K, half half2* src_ptr = (half2*)Q; half2* dst_ptr = (half2*)q_buf_; const half2* bias_ptr = (const half2*)bias_Q; - dst_ptr[target_id] = __hadd2(src_ptr[tid], __ldg(&bias_ptr[bias_id])); src_ptr = (half2*)K; @@ -186,6 +187,31 @@ void add_QKV_bias(half* Q, const half* bias_Q, half* K, const half* bias_K, half dst_ptr[target_id] = __hadd2(src_ptr[tid], __ldg(&bias_ptr[bias_id])); } +template +__global__ +void add_QKV_bias_rebuild_padding(T* Q, const T* bias_Q, T* K, const T* bias_K, T* V, const T* bias_V, T* q_buf_, T* k_buf_, T* v_buf_, + const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int* mask_offset) +{ + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int bdim = blockDim.x; + + const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len; + const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len; + const int tgt_head_id = tid / size_per_head; + const int tgt_hidden_id = tid % size_per_head; + + const int src_id = bid * bdim + tid; + const int tgt_id = tgt_batch_id * head_num * seq_len * 
size_per_head + \ + tgt_head_id * seq_len * size_per_head + \ + tgt_seq_id * size_per_head + \ + tgt_hidden_id; + + q_buf_[tgt_id] = Q[src_id] + bias_Q[tid]; + k_buf_[tgt_id] = K[src_id] + bias_K[tid]; + v_buf_[tgt_id] = V[src_id] + bias_V[tid]; +} + template __global__ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, @@ -267,6 +293,143 @@ void softmax_kernel_v2(T* qk_buf_, const T* attr_mask, const int batch_size, con qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum); } +//grid = (seq_len, batch_size, head_num) +//block.x = max(32, (seq_len + 31)/32*32) +template +__global__ +void softmax_kernel_v3(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) +{ + + float tmp = -1e20f; + int qk_offset; + __shared__ float s_mean, s_max; + if (threadIdx.x < seq_len){ + qk_offset = ((blockIdx.y*head_num + blockIdx.z)*seq_len + blockIdx.x) *seq_len + threadIdx.x; + int mask_offset = (blockIdx.y * seq_len + blockIdx.x) * seq_len + threadIdx.x; + + float qk = static_cast(qk_buf_[qk_offset]); + float mask_val = static_cast(__ldg(&attr_mask[mask_offset])); + + mask_val = (1.0f - mask_val) * -10000.0f; + + tmp = qk * static_cast(scalar) + mask_val; + } + + float max_val = blockReduceMax(tmp); + if (threadIdx.x == 0){ + s_max = max_val; + } + __syncthreads(); + + float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - s_max) : 0.0f; + float sum_val = blockReduceSum(qk_tmp); + if (threadIdx.x == 0){ + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + if(threadIdx.x < seq_len) + qk_buf_[qk_offset] = (T)(qk_tmp * s_mean); +} + + +//grid = (seq_len, batch_size, head_num) +//block.x = max(32, (seq_len/2 + 31)/32*32) +//seq_len % 2 == 0 +template <> +__global__ +void softmax_kernel_v3(half* qk_buf_, const half* attr_mask, + const int batch_size, const int head_num, + const int seq_len, const half scalar) +{ + half2* qk_buf_half2Ptr = (half2*) qk_buf_; + const half2* attr_mask_half2Ptr = (const half2*) attr_mask; + + int qk_offset; + int threadIdx2 = threadIdx.x << 1; + __shared__ float s_mean, s_max; + half2 tmp = __float2half2_rn(0.0f); + + float max_val = -1e20f; + half2 qk; + if (threadIdx2 < seq_len){ + qk_offset = ((((blockIdx.y*head_num + blockIdx.z)*seq_len + blockIdx.x) *seq_len) >> 1) + threadIdx.x; + int mask_offset = (((blockIdx.y * seq_len + blockIdx.x) * seq_len) >> 1) + threadIdx.x; + + qk = qk_buf_half2Ptr[qk_offset]; + half2 mask_val = __ldg(&attr_mask_half2Ptr[mask_offset]); + half2 mask_val_tmp = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val), __float2half2_rn(-10000.0f)); + tmp = __hadd2(__hmul2(__half2half2(scalar), qk), mask_val_tmp); + max_val = fmax((float)tmp.x, (float)tmp.y); + } + + max_val = blockDim.x <= 32 ? warpReduceMax(max_val) : blockReduceMax(max_val); + + if (threadIdx.x == 0){ + s_max = max_val; + } + __syncthreads(); + + if (threadIdx2 < seq_len){ + tmp = h2exp(__hsub2(tmp, __float2half2_rn(s_max))); + } + float sum_val = blockDim.x <= 32 ? 
warpReduceSum((float)(tmp.x + tmp.y)) : blockReduceSum((float)(tmp.x + tmp.y)); + + if (threadIdx.x == 0){ + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + if(threadIdx2 < seq_len){ + qk = __hmul2(tmp, __float2half2_rn(s_mean)); + qk_buf_half2Ptr[qk_offset] = qk; + } +} + +//grid = (seq_len, batch_size, head_num) +//block.x = max(32, (seq_len + 31)/32*32) +//for seq_len not larger than 32 +template +__global__ +void softmax_kernel_v3_LE32(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) +{ + + int qk_offset; + __shared__ float s_mean, s_max; + float tmp = -1e20f; + if (threadIdx.x < seq_len){ + qk_offset = ((blockIdx.y*head_num + blockIdx.z)*seq_len + blockIdx.x) *seq_len + threadIdx.x; + int mask_offset = (blockIdx.y * seq_len + blockIdx.x) * seq_len + threadIdx.x; + + float qk = static_cast(qk_buf_[qk_offset]); + float mask_val = static_cast(__ldg(&attr_mask[mask_offset])); + + mask_val = (1.0f - mask_val) * -10000.0f; + + tmp = static_cast(qk) * static_cast(scalar) + mask_val; + } + float max_val = warpReduceMax(tmp); + + if (threadIdx.x == 0){ + s_max = max_val; + } + __syncthreads(); + + tmp = threadIdx.x < seq_len ? __expf(tmp - s_max) : 0.0f; + float sum_val = warpReduceSum(tmp); + + if (threadIdx.x == 0){ + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + if(threadIdx.x < seq_len) + qk_buf_[qk_offset] = (T)(tmp * s_mean); +} + template __global__ void transpose(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head) @@ -297,6 +460,44 @@ void transpose(half* src, half* dst, dst_ptr[target_id] = src_ptr[tid]; } + +template +__global__ +void transpose_rebuild_padding(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head, + const int* mask_offset) +{ + // TODO: optimize this kernel? 
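The *_rebuild_padding kernels above and below index a packed tensor, with padding tokens removed, through a per-token offset array: for the i-th valid token, i + mask_offset[i] is its position in the dense [batch, seq_len] layout. A small host-side sketch of one plausible way to build such offsets from per-sequence lengths (hypothetical helper, shown only to make the indexing concrete; the real offsets are produced elsewhere in FasterTransformer):

#include <vector>

// For every valid (non-padded) token, record how many padded positions precede it
// in the dense [batch, seq_len] layout, so that packed index i maps to padded index
// i + offsets[i]. Returns one entry per valid token.
std::vector<int> build_sequence_id_offsets(const std::vector<int>& seq_lengths,
                                           int max_seq_len)
{
    std::vector<int> offsets;
    int skipped = 0;                    // padded positions seen so far
    for (int b = 0; b < (int)seq_lengths.size(); ++b)
    {
        for (int s = 0; s < seq_lengths[b]; ++s)
            offsets.push_back(skipped); // every token of this sequence shares the offset
        skipped += max_seq_len - seq_lengths[b];
    }
    return offsets;
}

With offsets of this form, add_QKV_bias_rebuild_padding writes each valid token straight into its [batch, head, seq, size_per_head] slot without touching padded rows, and transpose_rebuild_padding below uses the same offsets to gather the attention output back into packed form.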
+ // do remove_sequence_length_padding + const int tid = threadIdx.x; // batch * seq_len or valid_word_num + const int bid = blockIdx.x; // head_num * size_per_head + + const int src_batch_id = (bid + mask_offset[bid]) / seq_len; + const int src_seq_id = (bid + mask_offset[bid]) % seq_len; + + const int dst_seq_id = bid; + + const int head_id = tid / size_per_head; + const int hidden_id = tid % size_per_head; + dst[dst_seq_id * head_num * size_per_head + tid] = src[ src_batch_id * head_num * seq_len * size_per_head + + head_id * seq_len * size_per_head + src_seq_id * size_per_head + hidden_id]; +} + +template +__global__ void rebuild_sequence_length_padding(const T* src, T* tgt, + const int* mask_offset, + const int n) +{ + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int tgt_seq_id = bid + mask_offset[bid]; + const int src_seq_id = bid; + + for(int i = tid; i < n; i += blockDim.x) + { + tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i]; + } +} + template void OpenMultiHeadAttention::multiHeadAttr_nofuse_kernelLauncher( cudaStream_t stream, @@ -315,36 +516,57 @@ void OpenMultiHeadAttention::multiHeadAttr_nofuse_kernelLauncher( const int size_per_head, const DataType_ scalar) { - - int m = batch_size * seq_len; - int k = head_num * size_per_head; + const int k = head_num * size_per_head; dim3 grid; dim3 block; - + if(OpType_ == OperationType::FP32) { - const int word_per_block = 1; - assert(k <= 1024); - assert(m / word_per_block * 3 <= 65536); - - dim3 grid(m / word_per_block * 3); - dim3 block(k); - add_QKV_bias<<>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf_, k_buf_, v_buf_, - batch_size, seq_len, head_num, size_per_head, word_per_block); + if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) + { + const int m = batch_size * seq_len; + const int word_per_block = 1; + assert(k <= 1024); + assert(m / word_per_block * 3 <= 65536); + + dim3 grid(m / word_per_block * 3); + dim3 block(k); + add_QKV_bias<<>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf_, k_buf_, v_buf_, + batch_size, seq_len, head_num, size_per_head, word_per_block); + } + else + { + add_QKV_bias_rebuild_padding<<>>(Q, bias_Q, K, bias_K, + V, bias_V, q_buf_, k_buf_, v_buf_, + batch_size, seq_len, head_num, size_per_head, param_.sequence_id_offset); + } } else { - const int word_per_block = 1; - grid.x = batch_size * seq_len / word_per_block; - block.x = head_num * size_per_head * word_per_block / 2; - - assert(block.x <= 1024); - - add_QKV_bias<<>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf_, k_buf_, - v_buf_, batch_size, seq_len, head_num, size_per_head / 2, word_per_block); + if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) + { + const int word_per_block = 1; + grid.x = batch_size * seq_len / word_per_block; + block.x = head_num * size_per_head * word_per_block / 2; + + assert(block.x <= 1024); + + add_QKV_bias<<>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf_, k_buf_, + v_buf_, batch_size, seq_len, head_num, size_per_head / 2, word_per_block); + } + else + { + add_QKV_bias_rebuild_padding<<>>((half2*)Q, (const half2*)bias_Q, + (half2*)K, (const half2*)bias_K, (half2*)V, (const half2*)bias_V, + (half2*)q_buf_, (half2*)k_buf_, (half2*)v_buf_, + batch_size, seq_len, head_num, size_per_head / 2, param_.sequence_id_offset); + } } + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle, @@ -359,28 +581,52 @@ void 
OpenMultiHeadAttention::multiHeadAttr_nofuse_kernelLauncher( computeType_, static_cast(cublasAlgo_[1]))); - if(seq_len <= 32) - block.x = 32; - else if(seq_len > 32 && seq_len <= 64) - block.x = 64; - else if(seq_len > 64 && seq_len <= 128) - block.x = 128; - else if(seq_len > 128 && seq_len <= 256) - block.x = 256; - else if(seq_len > 256 && seq_len <= 512) - block.x = 512; - else - block.x = 1024; - - if(batch_size * head_num <= 120) - { - grid.x = batch_size * head_num * seq_len; - softmax_kernel_v2<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scalar); + //deal with odd seq_len + if (seq_len % 2 != 0){ + if(seq_len <= 32) + block.x = 32; + else if(seq_len > 32 && seq_len <= 64) + block.x = 64; + else if(seq_len > 64 && seq_len <= 128) + block.x = 128; + else if(seq_len > 128 && seq_len <= 256) + block.x = 256; + else if(seq_len > 256 && seq_len <= 512) + block.x = 512; + else + block.x = 1024; + + if(batch_size * head_num <= 120) + { + grid.x = batch_size * head_num * seq_len; + softmax_kernel_v2<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scalar); + } + else + { + grid.x = batch_size * head_num; + softmax_kernel<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scalar); + } } - else - { - grid.x = batch_size * head_num; - softmax_kernel<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scalar); + //deal with even seq_len + else{ + grid.x = seq_len; + grid.y = batch_size; + grid.z = head_num; + if (seq_len <= 32){ + block.x = 32; + softmax_kernel_v3_LE32<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scalar); + } + else{ + if (OpType_ == OperationType::FP16){ + block.x = (seq_len/2 + 31)/32*32; + softmax_kernel_v3<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scalar); + } + else{ + block.x = (seq_len + 31)/32*32; + softmax_kernel_v3<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scalar); + } + } + grid.x = grid.y = grid.z = 1; } check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle, @@ -395,28 +641,45 @@ void OpenMultiHeadAttention::multiHeadAttr_nofuse_kernelLauncher( computeType_, static_cast(cublasAlgo_[2]))); -/* for half2 only */ + /* for half2 only */ if(OpType_ == OperationType::FP16) { - const int seq_per_block = 4; - // const int seq_per_block = 1; - grid.x = batch_size * head_num * seq_len / seq_per_block; - block.x = seq_per_block * size_per_head / 2; - - assert(grid.x * seq_per_block == batch_size * head_num * seq_len); - - transpose<<>>(transpose_dst_, dst, - batch_size, seq_len, head_num, size_per_head / 2); + if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) + { + const int seq_per_block = 4; + grid.x = batch_size * head_num * seq_len / seq_per_block; + block.x = seq_per_block * size_per_head / 2; + + assert(grid.x * seq_per_block == batch_size * head_num * seq_len); + + transpose<<>>(transpose_dst_, dst, + batch_size, seq_len, head_num, size_per_head / 2); + } + else + { + transpose_rebuild_padding<<>>( + (half2*)transpose_dst_, (half2*)dst, + batch_size, seq_len, head_num, size_per_head / 2, param_.sequence_id_offset); + } } else { - const int seq_per_block = 1; - grid.x = batch_size * head_num * seq_len / seq_per_block; - block.x = seq_per_block * size_per_head; - transpose<<>>(transpose_dst_, dst, - batch_size, seq_len, head_num, size_per_head); + if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) + { + const int seq_per_block = 1; + grid.x = batch_size * head_num * seq_len / seq_per_block; + block.x = seq_per_block * 
size_per_head; + transpose<<>>(transpose_dst_, dst, + batch_size, seq_len, head_num, size_per_head); + } + else + { + transpose_rebuild_padding<<>>(transpose_dst_, dst, + batch_size, seq_len, head_num, size_per_head, param_.sequence_id_offset); + } } - + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); } template void OpenMultiHeadAttention::multiHeadAttr_nofuse_kernelLauncher( diff --git a/fastertransformer/cuda/open_attention.h b/fastertransformer/cuda/open_attention.h index e876d189f..b3cb22d24 100644 --- a/fastertransformer/cuda/open_attention.h +++ b/fastertransformer/cuda/open_attention.h @@ -71,7 +71,8 @@ class OpenMultiHeadAttention: IMultiHeadAttention const IAllocator& allocator_; MultiHeadInitParam param_; - int cublasAlgo_[3]; + int cublasAlgo_[4]; + bool is_fuse_QKV; DataType_* buf_; DataType_* query_buf_; @@ -83,12 +84,16 @@ class OpenMultiHeadAttention: IMultiHeadAttention DataType_* qk_buf_; DataType_* transpose_dst_; + DataType_** qkv_kernel_; + DataType_** qkv_input_; + DataType_** qkv_buf_; int batch_size_; int from_seq_len_; int to_seq_len_; int head_num_; int size_per_head_; + public: //Ctor OpenMultiHeadAttention(const IAllocator& allocator, int batch_size, int from_seq_len, @@ -104,7 +109,7 @@ class OpenMultiHeadAttention: IMultiHeadAttention int qk_buf_size = batch_size_ * head_num_ * from_seq_len_ * from_seq_len_; try { - buf_ = (DataType_*) allocator_.malloc(sizeof(DataType_) * (buf_size * 7 + qk_buf_size)); + buf_ = (DataType_*) allocator_.malloc(sizeof(DataType_) * (buf_size * 7 + qk_buf_size) + sizeof(DataType_*) * 9); query_buf_ = buf_; key_buf_ = buf_ + buf_size; value_buf_ = buf_ + 2 * buf_size; @@ -113,6 +118,9 @@ class OpenMultiHeadAttention: IMultiHeadAttention v_buf_ = buf_ + 5 * buf_size; qk_buf_ = buf_ + 6 * buf_size; transpose_dst_ = qk_buf_ + qk_buf_size; + qkv_kernel_ = (DataType_**)(transpose_dst_ + buf_size); + qkv_input_ = qkv_kernel_ + 3; + qkv_buf_ = qkv_input_ + 3; FILE* fd = fopen("gemm_config.in", "r"); int err = 0; @@ -120,24 +128,30 @@ class OpenMultiHeadAttention: IMultiHeadAttention printf("gemm_config.in is not found\n"); else { - err = fscanf(fd, "%d%*d%*d%d%d", &cublasAlgo_[0], &cublasAlgo_[1], &cublasAlgo_[2]); + float split_time, fused_time; + err = fscanf(fd, "%d %f %*d %*f %*d %*f %d %*f %d %*f %d %f", + &cublasAlgo_[0], &split_time, &cublasAlgo_[1], &cublasAlgo_[2], &cublasAlgo_[3], &fused_time); + is_fuse_QKV = fused_time < split_time * 3 ? true : false; fclose(fd); } - if(err != 3) + if(err != 6) { - printf("loading GEMM algorithms error, using default GEMM algorithms\n"); - if(OpType_ == OperationType::FP32) - { - cublasAlgo_[0] = -1; - cublasAlgo_[1] = -1; - cublasAlgo_[2] = -1; - } - else - { - cublasAlgo_[0] = 99; - cublasAlgo_[1] = 99; - cublasAlgo_[2] = 99; - } + printf("loading GEMM algorithms error, using default GEMM algorithms\n"); + if(OpType_ == OperationType::FP32) + { + cublasAlgo_[0] = -1; + cublasAlgo_[1] = -1; + cublasAlgo_[2] = -1; + cublasAlgo_[3] = -1; + } + else + { + cublasAlgo_[0] = 99; + cublasAlgo_[1] = 99; + cublasAlgo_[2] = 99; + cublasAlgo_[3] = 99; + } + is_fuse_QKV = false; } } catch(std::runtime_error& error) @@ -151,57 +165,74 @@ class OpenMultiHeadAttention: IMultiHeadAttention #ifndef NDEBUG PRINT_FUNC_NAME_(); #endif - int m = batch_size_ * from_seq_len_; - int k = head_num_ * size_per_head_; - int n = k; + const int m = param_.sequence_id_offset == nullptr ? 
batch_size_ * from_seq_len_ : param_.valid_word_num; + const int k = head_num_ * size_per_head_; + const int n = k; + + const DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; - DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; - try { - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.query_weight.kernel, AType_, n, - param_.from_tensor, BType_, k, - &beta, - query_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); + if(is_fuse_QKV == true) + { + check_cuda_error(cublasGemmBatchedEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + (const void* const*) qkv_kernel_, AType_, n, + (const void* const*) qkv_input_, BType_, k, + &beta, + (void* const*)qkv_buf_, CType_, n, + 3, + computeType_, + static_cast(cublasAlgo_[3]))); + } + else + { + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.query_weight.kernel, AType_, n, + param_.from_tensor, BType_, k, + &beta, + query_buf_, CType_, n, + computeType_, + static_cast(cublasAlgo_[0]))); #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.key_weight.kernel, AType_, n, - param_.to_tensor, BType_, k, - &beta, - key_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.key_weight.kernel, AType_, n, + param_.to_tensor, BType_, k, + &beta, + key_buf_, CType_, n, + computeType_, + static_cast(cublasAlgo_[0]))); #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.value_weight.kernel, AType_, n, - param_.to_tensor, BType_, k, - &beta, - value_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); - + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.value_weight.kernel, AType_, n, + param_.to_tensor, BType_, k, + &beta, + value_buf_, CType_, n, + computeType_, + static_cast(cublasAlgo_[0]))); + } + #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); @@ -269,8 +300,17 @@ class OpenMultiHeadAttention: IMultiHeadAttention #ifndef NDEBUG PRINT_FUNC_NAME_(); #endif - //Do all the malloc here param_ = param; + if(is_fuse_QKV == true && param_.from_tensor != nullptr) + { + // For tensorrt, we cannot get the pointer of from tensor until enqueue + const DataType_* hA[] {param_.self_attention.query_weight.kernel, + param_.self_attention.key_weight.kernel, + param_.self_attention.value_weight.kernel, + param_.from_tensor, param_.from_tensor, param_.from_tensor, + query_buf_, key_buf_, value_buf_}; + cudaMemcpyAsync((void*)qkv_kernel_, hA, sizeof(DataType_*) * 9, cudaMemcpyHostToDevice, param_.stream); + } } void trt_initialize(DataType_* from_tensor, DataType_* to_tensor, DataType_* attr_mask, cudaStream_t stream, cublasHandle_t cublas_handle) @@ -280,6 +320,15 @@ class OpenMultiHeadAttention: IMultiHeadAttention param_.attr_mask = attr_mask; param_.stream = stream; param_.cublas_handle = cublas_handle; + if(is_fuse_QKV == true) + { + const DataType_* hA[] {param_.self_attention.query_weight.kernel, + 
param_.self_attention.key_weight.kernel, + param_.self_attention.value_weight.kernel, + param_.from_tensor, param_.from_tensor, param_.from_tensor, + query_buf_, key_buf_, value_buf_}; + cudaMemcpyAsync((void*)qkv_kernel_, hA, sizeof(DataType_*) * 9, cudaMemcpyHostToDevice, param_.stream); + } } ~OpenMultiHeadAttention() override diff --git a/fastertransformer/cuda/open_decoder.cu b/fastertransformer/cuda/open_decoder.cu index e615f2ffe..92365393e 100644 --- a/fastertransformer/cuda/open_decoder.cu +++ b/fastertransformer/cuda/open_decoder.cu @@ -1,870 +1,1679 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * Open sourced multi-head attention - **/ - -#include "fastertransformer/open_decoder.h" - -namespace fastertransformer{ - -/** - masked multi-head attention - */ -#define FINAL_MASK 0xffffffff -template -__inline__ __device__ -T warpReduceSum(T val) -{ - for(int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); - return val; -} -/* Calculate the sum of all elements in a block */ -template - __inline__ __device__ -T blockReduceSum(T val) -{ - static __shared__ T shared[32]; - //__shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - val = warpReduceSum(val); - - if(lane == 0) - shared[wid] = val; - - __syncthreads(); - - val = (threadIdx.x < (blockDim.x >> 5 )) ? shared[lane] : (T)(0.0f); - val = warpReduceSum(val); - - return val; -} -template -__global__ -void add_bias_relu(T* out, const T* bias, int m, int n) -{ - T val, reg_bias; - - int row_id = blockIdx.x; - int ite = n / blockDim.x; - int tid = threadIdx.x; - - for(int i = 0; i < ite; ++i) - { - reg_bias = __ldg(&bias[i * blockDim.x + tid]); - row_id = blockIdx.x; - - while(row_id < m) - { - val = out[tid + i * blockDim.x + row_id * n] + reg_bias; - out[tid + i * blockDim.x + row_id * n] = (T)(val > 0.0f ? val : 0.0f); - row_id += gridDim.x; - } - } -} - -template <> - __global__ -void add_bias_relu(half* out, const half* bias, int m, int n) -{ - half2 val, reg_bias; - int row_id = blockIdx.x; - int ite = n / blockDim.x / 2; - int tid = threadIdx.x; - - half2* out_ptr = (half2*) out; - const half2* bias_ptr = (half2*) bias; - for(int i = 0; i < ite; ++i) - { - reg_bias = __ldg(&bias_ptr[i * blockDim.x + tid]); - row_id = blockIdx.x; - - while(row_id < m) - { - val = out_ptr[tid + i * blockDim.x + row_id * n / 2]; - val = __hadd2(val, reg_bias); - val.x = val.x > (half)0.0f ? val.x : (half)0.0f; - val.y = val.y > (half)0.0f ? 
val.y : (half)0.0f; - out_ptr[tid + i * blockDim.x + row_id * n / 2] = val; - row_id += gridDim.x; - } - } -} -template - __inline__ __device__ -T warpReduceMax(T val) -{ - for(int mask = 16; mask > 0; mask >>= 1) - val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); - return val; -} -/* Calculate the maximum of all elements in a block */ -template - __inline__ __device__ -T blockReduceMax(T val) -{ - static __shared__ T shared[32]; -// __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx - - val = warpReduceMax(val); // get maxx in each warp - - if(lane == 0) // record in-warp maxx by warp Idx - shared[wid] = val; - - __syncthreads(); - - - val = (threadIdx.x < (blockDim.x >> 5 )) ? shared[lane] : 0; - val = warpReduceMax(val); - - return val; -} -template -__global__ -void masked_attention_kernel(T* query_buf, const T* self_Q_bias, - T* key_cache, const T* self_K_bias, T* value_cache, const T* self_V_bias, - T* context_buf, int batch_size, int head_num, int size_per_head, const int step, const T scalar) -{ - extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; - T* sq = reinterpret_cast(s_buf); - T* logits = reinterpret_cast(&sq[size_per_head]); - - int tid = threadIdx.x; - int bid = blockIdx.x / head_num; - int head_id = blockIdx.x % head_num; - - int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; - int qkv_bias_id = head_id * size_per_head + tid; - - if(tid < size_per_head) - sq[tid] = query_buf[qkv_id] + self_Q_bias[qkv_bias_id]; - __syncthreads(); - - //offset for each step - int offset = batch_size * head_num * size_per_head; - for(int ite = 0; ite < step; ++ite) - { - T key = tid < size_per_head ? key_cache[ite * offset + qkv_id] : (T)0.0f; - //for the last step, we should update K + bias_K to the cache - if(ite == step - 1 && tid < size_per_head) - { - key += self_K_bias[qkv_bias_id]; - key_cache[ite * offset + qkv_id] = key; - } - - T val = (tid < size_per_head) ? key * sq[tid] * scalar : (T)(0.0f); - T qk = blockReduceSum(val); - if(threadIdx.x == 0) - logits[ite] = qk; - __syncthreads(); //try to remove - } - __syncthreads(); //try to remove - - __shared__ float s_max_val, s_sum; - float local_i = tid < step ? (float)logits[tid] : -1e20f; - float max_val = blockReduceMax(local_i); - if(tid == 0) - s_max_val = max_val; - __syncthreads(); - - local_i -= s_max_val; - float local_o = tid < step ? 
__expf(local_i) : 0.0f; - float val = blockReduceSum(local_o); - - if(tid == 0) - s_sum = val + 1e-6; - __syncthreads(); - - if(tid < step) - logits[tid] = local_o / s_sum; - __syncthreads(); - - - if(tid < size_per_head) - { - T sum = (T)0.0f; - for(int ite = 0; ite < step; ++ite) - { - T value = value_cache[ite * offset + qkv_id]; - //for the last step, we should update K + bias_K to the cache - if(ite == step - 1) - { - value += self_V_bias[qkv_bias_id]; - value_cache[ite * offset + qkv_id] = value; - } - sum += value * logits[ite]; - } - context_buf[qkv_id] = sum; - } -} - -template -__global__ -void masked_attention_kernel_v2(T* query_buf, const T* self_Q_bias, - T* key_cache, const T* self_K_bias, T* value_cache, const T* self_V_bias, - T* context_buf, int batch_size, int head_num, int size_per_head, const int step, const T scalar) -{ - extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; - T* sq = reinterpret_cast(s_buf); - T* logits = reinterpret_cast(&sq[size_per_head]); - - int tid = threadIdx.x; - int bid = blockIdx.x / head_num; - int head_id = blockIdx.x % head_num; - - int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; - int qkv_bias_id = head_id * size_per_head + tid; - - if(tid < size_per_head) - sq[tid] = query_buf[qkv_id] + self_Q_bias[qkv_bias_id]; - __syncthreads(); - - int warp_size = 32; - int offset = batch_size * head_num * size_per_head; - int warp_ite = size_per_head / warp_size; - - T qk = (T)0.0f; - - //each warp process one step - int step_id = threadIdx.x >> 5; - if(step_id < step) - { - for(int wite = 0; wite < warp_ite; ++wite) - { - T key = key_cache[step_id * offset + bid * head_num * size_per_head + head_id * size_per_head - + tid % warp_size + wite * warp_size]; - //for the last step, we should update K + bias_K to the cache - if(step_id == step - 1) - { - key += self_K_bias[bid * head_num * size_per_head + head_id * size_per_head + - tid % warp_size + wite * warp_size]; - key_cache[step_id * offset + bid * head_num * size_per_head + head_id * size_per_head - + tid % warp_size + wite * warp_size] = key; - } - qk += key * sq[tid % warp_size + wite * warp_size]; - } - - qk = warpReduceSum(qk * scalar); - if(threadIdx.x % warp_size == 0) - { - logits[step_id] = qk; - printf("step_id %d %f\n", step_id, qk); - } - - } - __syncthreads(); - - __shared__ float s_max_val, s_sum; - float local_i = tid < step ? (float)logits[tid] : -1e20f; - float max_val = blockReduceMax(local_i); - if(tid == 0) - s_max_val = max_val; - __syncthreads(); - - local_i -= s_max_val; - float local_o = tid < step ? 
__expf(local_i) : 0.0f; - float val = blockReduceSum(local_o); - - if(tid == 0) - s_sum = val; - __syncthreads(); - if(tid < step) - logits[tid] = local_o / s_sum; - __syncthreads(); - - - if(tid < size_per_head) - { - T sum = (T)0.0f; - for(int ite = 0; ite < step; ++ite) - { - T value = value_cache[ite * offset + qkv_id]; - //for the last step, we should update K + bias_K to the cache - if(ite == step - 1) - { - value += self_V_bias[qkv_bias_id]; - value_cache[ite * offset + qkv_id] = value; - } - sum += value * logits[ite]; - } - context_buf[qkv_id] = sum; - } -} - -template -void OpenDecoder::masked_multi_head_attention( - const DataType_* from_tensor, - DataType_* key_cache_, - DataType_* value_cache_, - DataType_* decoder_output, - const int step) -{ - int m = batch_size_; - int n = hidden_units_; - int k = hidden_units_; - - DataType_* key_buf_ = key_cache_ + (step - 1) * m * n; - DataType_* value_buf_ = value_cache_ + (step - 1) * m * n; - - DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.query_weight.kernel , AType_, n, - from_tensor, BType_, k, - &beta, - query_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.key_weight.kernel, AType_, n, - from_tensor, BType_, k, - &beta, - key_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.value_weight.kernel, AType_, n, - from_tensor, BType_, k, - &beta, - value_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); - - dim3 grid(batch_size_ * head_num_); - dim3 block(128); - - //suppose size_per_head <= 128 - if(step <= 64) - block.x = 64; - else if(step <= 128 && step > size_per_head_) - block.x = 128; - else if(step > 128 && step <= 256) - block.x = 256; - else if(step > 256 && step <= 512) - block.x = 512; - else - block.x = 1024; - - if(block.x < size_per_head_) - block.x = size_per_head_; - - assert(block.x <= 1024); - - DataType_ scalar = 1 / sqrtf(size_per_head_ * 1.0f); - - int shared_size = sizeof(DataType_) * (size_per_head_ + step); - - masked_attention_kernel<<>>( - query_buf_, param_.self_attention.query_weight.bias, - key_cache_, param_.self_attention.key_weight.bias, - value_cache_, param_.self_attention.value_weight.bias, - context_buf_, batch_size_, - head_num_, size_per_head_, step, scalar); - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.attention_output_weight.kernel, AType_, n, - context_buf_, BType_, k, - &beta, - decoder_output, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); -} - -template -__global__ -void cross_attention_kernel( - T* query_buf, const T* Q_bias, - T* key_cache, const T* K_bias, - T* value_cache, const T* V_bias, - const int* length_per_sample, T* context_buf, - int batch_size, int head_num, int size_per_head, int step, const int seq_len, const T scalar) -{ - int tid = threadIdx.x; - int bid = blockIdx.x / head_num; - int head_id = blockIdx.x % head_num; - - extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; - T* sq = reinterpret_cast(s_buf); - T* logits = reinterpret_cast(&sq[size_per_head]); - - int length = __ldg(&length_per_sample[bid]); - - int qkv_id = bid * 
head_num * size_per_head + head_id * size_per_head + tid; - int qkv_bias_id = head_id * size_per_head + tid; - - if(tid < size_per_head) - sq[tid] = query_buf[qkv_id] + Q_bias[qkv_bias_id]; - __syncthreads(); - - for(int ite = 0; ite < length; ++ite) - { - int key_id = bid * (seq_len * head_num * size_per_head) + ite * (head_num * size_per_head) - + head_id * size_per_head + tid; - - T key = tid < size_per_head ? key_cache[key_id] : (T)(0.0f); - - //For the first step, we should add bias to key memory cache. - //The KV memory cache only need to be updated at the first step. - if(step == 1 && tid < size_per_head) - { - key += K_bias[head_id * size_per_head + tid]; - key_cache[key_id] = key; - } - - T val = (tid < size_per_head) ? key * sq[tid] * scalar : (T)(0.0f); - T qk = blockReduceSum(val); - if(threadIdx.x == 0) - logits[ite] = qk; - __syncthreads(); //try to remove - } - __syncthreads(); - - __shared__ float s_max_val, s_sum; - - float local_i = tid < length ? (float)logits[tid] : -1e20f; - float max_val = blockReduceMax(local_i); - if(tid == 0) - s_max_val = max_val; - __syncthreads(); - - local_i -= s_max_val; - float local_o = tid < length ? __expf(local_i) : 0.0f; - float val = blockReduceSum(local_o); - - if(tid == 0) - s_sum = val + 1e-6; - __syncthreads(); - if(tid < length) - logits[tid] = local_o / s_sum; - __syncthreads(); - - if(tid < size_per_head) - { - T sum = (T)0.0f; - for(int ite = 0; ite < length; ++ite) - { - int value_id = bid * seq_len * head_num * size_per_head + ite * head_num * size_per_head - + head_id * size_per_head + tid; - - T value = value_cache[value_id]; - - //for the first step, we should add bias to key memory cache - if(step == 1) - { - value += V_bias[head_id * size_per_head + tid]; - value_cache[value_id] = value; - } - sum += value * logits[ite]; - } - context_buf[bid * head_num * size_per_head + head_id * size_per_head + tid] = sum; - } -} - -/* attention with source sentence */ -template -void OpenDecoder::cross_multi_head_attention( - const DataType_* from_tensor, - const DataType_* memory_tensor, - DataType_* key_mem_cache, - DataType_* value_mem_cache, - DataType_* decoder_output, - const int* length, - const int seq_len, - const int step) -{ - int m = batch_size_; - int n = hidden_units_; - int k = hidden_units_; - - DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; - - //reuse the query_buf - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.cross_attention.query_weight.kernel, AType_, n, - from_tensor, BType_, k, - &beta, - query_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); - - if(step == 1) - { - m *= seq_len; - k = memory_hidden_units_; - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.cross_attention.key_weight.kernel, AType_, n, - memory_tensor, BType_, k, - &beta, - key_mem_cache, CType_, n, - computeType_, - static_cast(cublasAlgo_[1]))); - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.cross_attention.value_weight.kernel, AType_, n, - memory_tensor, BType_, k, - &beta, - value_mem_cache, CType_, n, - computeType_, - static_cast(cublasAlgo_[1]))); - k = hidden_units_; - } - - dim3 grid(batch_size_ * head_num_); - dim3 block(128); - - if(seq_len <= 64) - block.x = 64; - else if(seq_len <= 128 && seq_len > size_per_head_) - block.x = 128; - else if(seq_len > 128 && seq_len <= 256) - block.x = 256; - else if(seq_len > 256 
&& seq_len <= 512) - block.x = 512; - else - block.x = 1024; - - if(block.x < size_per_head_) - block.x = size_per_head_; - - assert(block.x <= 1024); - - DataType_ scalar = 1 / sqrtf(size_per_head_ * 1.0f); - - int shared_size = sizeof(DataType_) * (size_per_head_ + seq_len); - cross_attention_kernel<<>>( - query_buf_, param_.cross_attention.query_weight.bias, - key_mem_cache, param_.cross_attention.key_weight.bias, - value_mem_cache, param_.cross_attention.value_weight.bias, - length, context_buf_, - batch_size_, - head_num_, size_per_head_, step, seq_len, scalar); - - m = batch_size_; - n = head_num_ * size_per_head_; - k = n; - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.cross_attention.attention_output_weight.kernel, AType_, n, - context_buf_, BType_, k, - &beta, - decoder_output, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); -} - -template -__global__ -void decoder_norm1_kernel(const T* input, const T* gamma, const T* beta, T* output, int m, int n) -{ - int tid = threadIdx.x; - - __shared__ float s_mean; - __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - - float local_out = tid < n ? (float)(__ldg(&input[blockIdx.x * n + tid])) : 0.0f; - - mean = blockReduceSum(local_out); - - if(threadIdx.x == 0) - s_mean = mean / n; - __syncthreads(); - - variance = blockReduceSum(tid < n ? (local_out - s_mean) * (local_out - s_mean) : 0.0f); - - if(threadIdx.x == 0) - s_variance = rsqrtf(variance / n + 1e-6); - - __syncthreads(); - - if(tid < n) - output[blockIdx.x * n + tid] = - (T)(((local_out - s_mean) * s_variance) * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); -} - -template -__global__ -void decoder_norm2_kernel(const T* input, const T* gamma, const T* beta, const T* bias, T* output, T* norm_output, int m, int n) -{ - int tid = threadIdx.x; - - __shared__ float s_mean; - __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - - float local_out = 0.0f; - if(tid < n) - { - local_out = (float)(__ldg(&input[blockIdx.x * n + tid])); - local_out += (float)(output[blockIdx.x * n + tid]); - local_out += (float)(__ldg(&bias[tid])); - output[blockIdx.x * n + tid] = (T)local_out; - } - - mean = blockReduceSum(local_out); - if(threadIdx.x == 0) - s_mean = mean / n; - __syncthreads(); - - variance = blockReduceSum(tid < n ? (local_out - s_mean) * (local_out - s_mean) : 0.0f); - if(threadIdx.x == 0) - s_variance = rsqrtf(variance / n + 1e-6); - __syncthreads(); - - if(tid < n) - norm_output[blockIdx.x * n + tid] = - (T)((local_out - s_mean) * s_variance * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); -} - - -template -void OpenDecoder::decoder_norm1( - const DataType_* input, - const DataType_* gamma, - const DataType_* beta, - DataType_* output, - int m, int n) -{ - dim3 grid(m); - dim3 block(min(n, 1024)); - - /* For general cases, n is equal to hidden_units, e.g., 512/1024. - Since we have warp shuffle inside the code, block.x % 32 should be 0. 
- */ - if(n % 32 != 0) - block.x = 1024; - - assert(n <= 1024); - -/* should pay attention to the rsqrt precision*/ - decoder_norm1_kernel<<>>(input, gamma, beta, output, m, n); -} - -template -void OpenDecoder::decoder_norm2( - const DataType_* input, - const DataType_* gamma, - const DataType_* beta, - const DataType_* bias, - DataType_* output, - DataType_* norm_output, - int m, int n) -{ - dim3 grid(m); - dim3 block(min(n, 1024)); - - assert(n <= 1024); - - /* For general cases, n is equal to hidden_units, e.g., 512/1024. - Since we have warp shuffle inside the code, block.x % 32 should be 0. - */ - - if(n % 32 != 0) - block.x = 1024; - - /* should pay attention to the rsqrt precision*/ - decoder_norm2_kernel<<>>(input, gamma, beta, bias, output, norm_output, m, n); -} - -template -void OpenDecoder::ffn( - const DataType_* input, - DataType_* ffn_inner, - DataType_* output, - const int m, - const int inner_size, - const int n) -{ - int m1 = m, k1 = n, n1 = inner_size; - DataType_ alpha = (DataType_)1.0f; - DataType_ beta = (DataType_)0.0f; - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n1, m1, k1, - &alpha, - param_.ffn.intermediate_weight.kernel, AType_, n1, - input, BType_, k1, - &beta, - ffn_inner, CType_, n1, - computeType_, - static_cast(cublasAlgo_[2]))); - - dim3 grid(m1); - dim3 block(n1 / 4); - - assert(block.x <= 1024); - - add_bias_relu<<>>(ffn_inner, param_.ffn.intermediate_weight.bias, m1, n1); - - int m2 = m, n2 = n, k2 = inner_size; - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n2, m2, k2, - &alpha, - param_.ffn.output_weight.kernel, AType_, n2, - ffn_inner, BType_, k2, - &beta, - output, CType_, n2, - computeType_, - static_cast(cublasAlgo_[3]))); -} - -template -__global__ -void add_bias_input_kernel(T* output, const T* input, const T* bias, const int m, const int n) -{ - int id = blockIdx.x * n + threadIdx.x; - output[id] = output[id] + input[id] + __ldg(&bias[threadIdx.x]); -} - -template -void OpenDecoder::add_bias_input(DataType_* output, const DataType_* input, const int m, const int n) -{ - dim3 grid(m); - dim3 block(n); - assert(n <= 1024); - add_bias_input_kernel<<>>(output, input, param_.ffn.output_weight.bias, m, n); -} - -template void OpenDecoder::masked_multi_head_attention( - const float* from_tensor, - float* key_cache, - float* value_cache, - float* decoder_output, - const int step); - -template void OpenDecoder::masked_multi_head_attention( - const half* from_tensor, - half* key_cache, - half* value_cache, - half* decoder_output, - const int step); - -template void OpenDecoder::cross_multi_head_attention( - const float* from_tensor, - const float* memory_tensor, - float* key_mem_cache, - float* value_mem_cache, - float* decoder_output, - const int* length, - const int max_seq_len, - const int step); - -template void OpenDecoder::cross_multi_head_attention( - const half* from_tensor, - const half* memory_tensor, - half* key_mem_cache, - half* value_mem_cache, - half* decoder_output, - const int* length, - const int max_seq_len, - const int step); - -template void OpenDecoder::ffn( - const float* input, - float* ffn_inner, - float* otuput, - const int m, - const int inner_size, - const int n); - -template void OpenDecoder::ffn( - const half* input, - half* ffn_inner, - half* otuput, - const int m, - const int inner_size, - const int n); - -template void OpenDecoder::decoder_norm1( - const float* input, - const float* gamma, - const float* beta, - float* output, - int m, int n); 
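decoder_norm1, whose instantiations are removed here as part of this file's rewrite, is a per-row layer normalization: block-reduce the mean and the variance of each hidden vector, then scale and shift by gamma and beta. A minimal standalone FP32 sketch of that computation, using a plain shared-memory tree reduction for clarity (the real kernels use warp shuffles and therefore require blockDim.x to be a multiple of 32):

#include <cuda_runtime.h>

// One block per row; assumes blockDim.x is a power of two, >= n, and <= 1024.
__global__ void layer_norm_row(const float* __restrict__ in,
                               const float* __restrict__ gamma,
                               const float* __restrict__ beta,
                               float* __restrict__ out,
                               int n)
{
    const int row = blockIdx.x;
    const int tid = threadIdx.x;

    __shared__ float red[1024];
    __shared__ float s_mean, s_rstd;

    float x = (tid < n) ? in[row * n + tid] : 0.0f;

    // mean = sum(x) / n
    red[tid] = x;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (tid < s) red[tid] += red[tid + s];
        __syncthreads();
    }
    if (tid == 0) s_mean = red[0] / n;
    __syncthreads();

    // variance = sum((x - mean)^2) / n, guarded by the same 1e-6 epsilon
    // that decoder_norm1_kernel uses before rsqrtf
    float d = (tid < n) ? (x - s_mean) : 0.0f;
    red[tid] = d * d;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (tid < s) red[tid] += red[tid + s];
        __syncthreads();
    }
    if (tid == 0) s_rstd = rsqrtf(red[0] / n + 1e-6f);
    __syncthreads();

    if (tid < n)
        out[row * n + tid] = (x - s_mean) * s_rstd * gamma[tid] + beta[tid];
}

decoder_norm2 follows the same pattern but first adds the residual input and the bias before normalizing.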
- -template void OpenDecoder::decoder_norm1( - const half* input, - const half* gamma, - const half* beta, - half* output, - int m, int n); - -template void OpenDecoder::decoder_norm2( - const float* input, - const float* gamma, - const float* beta, - const float* bias, - float* output, - float* norm_output, - int m, int n); - -template void OpenDecoder::decoder_norm2( - const half* input, - const half* gamma, - const half* beta, - const half* bias, - half* output, - half* norm_output, - int m, int n); - -template void OpenDecoder::add_bias_input( - float* output, - const float* input, - const int m, - const int n); - -template void OpenDecoder::add_bias_input( - half* output, - const half* input, - const int m, - const int n); - -}//namespace FasterTransformer +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Open sourced multi-head attention + **/ + +#include "fastertransformer/open_decoder.h" + +#include "cub/cub.cuh" + +namespace fastertransformer{ + +const int WARP_SIZE = 32; +const bool ATTENION_OPT = true; +const int ATTENTION_BLOCK_SIZE = 256; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Copy_half_t = + typename std::conditional::type + >::type + >::type; + +template +using Copy_t = Copy_half_t; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + masked multi-head attention + */ +#define FINAL_MASK 0xffffffff +template +__inline__ __device__ +T warpReduceSum(T val) +{ + for(int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); + return val; +} +/* Calculate the sum of all elements in a block */ +template + __inline__ __device__ +T blockReduceSum(T val) +{ + static __shared__ T shared[32]; + // __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if(lane == 0) + shared[wid] = val; + + __syncthreads(); + + val = (threadIdx.x < (blockDim.x >> 5 )) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + + return val; +} +template +__global__ +void add_bias_relu(T* out, const T* bias, int m, int n) +{ + T val, reg_bias; + + int row_id = blockIdx.x; + int ite = n / blockDim.x; + int tid = threadIdx.x; + + for(int i = 0; i < ite; ++i) + { + reg_bias = __ldg(&bias[i * blockDim.x + tid]); + row_id = blockIdx.x; + + while(row_id < m) + { + val = out[tid + i * blockDim.x + row_id * n] + reg_bias; + out[tid + i * blockDim.x + row_id * n] = (T)(val > 0.0f ? 
val : 0.0f); + row_id += gridDim.x; + } + } +} + +template <> + __global__ +void add_bias_relu(half* out, const half* bias, int m, int n) +{ + half2 val, reg_bias; + int row_id = blockIdx.x; + int ite = n / blockDim.x / 2; + int tid = threadIdx.x; + + half2* out_ptr = (half2*) out; + const half2* bias_ptr = (half2*) bias; + for(int i = 0; i < ite; ++i) + { + reg_bias = __ldg(&bias_ptr[i * blockDim.x + tid]); + row_id = blockIdx.x; + + while(row_id < m) + { + val = out_ptr[tid + i * blockDim.x + row_id * n / 2]; + val = __hadd2(val, reg_bias); + val.x = val.x > (half)0.0f ? val.x : (half)0.0f; + val.y = val.y > (half)0.0f ? val.y : (half)0.0f; + out_ptr[tid + i * blockDim.x + row_id * n / 2] = val; + row_id += gridDim.x; + } + } +} +template + __inline__ __device__ +T warpReduceMax(T val) +{ + for(int mask = 16; mask > 0; mask >>= 1) + val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); + return val; +} +/* Calculate the maximum of all elements in a block */ +template + __inline__ __device__ +T blockReduceMax(T val) +{ + static __shared__ T shared[32]; +// __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + val = warpReduceMax(val); // get maxx in each warp + + if(lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + + __syncthreads(); + + + val = (threadIdx.x < (blockDim.x >> 5 )) ? shared[lane] : 0; + val = warpReduceMax(val); + + return val; +} + +template +__global__ +void masked_attention_kernel_opt( + T* __restrict key_buf, T* __restrict value_buf, + T* __restrict query_buf, const T* __restrict self_Q_bias, + T* __restrict key_cache, const T* __restrict self_K_bias, + T* __restrict value_cache, const T* __restrict self_V_bias, + T* __restrict context_buf, int batch_size, int head_num, const int step, const T scalar) +{ + typedef Copy_t copy_t; + const int elems_per_thread = size_per_head / WARP_SIZE; + + union Access_t + { + copy_t v; + T x[elems_per_thread]; // supported size 1,2,4 + }; + typedef struct Float_n_t + { + T x[elems_per_thread]; // supported size 1,2,4 + } float_n_t; + + __shared__ float_n_t sq[block_sz]; + + __shared__ float logits[1024]; // only use [0 ~ step-1], the step should be smaller than 1024 + + const int tid = threadIdx.x; + const int warp_num = block_sz / WARP_SIZE; + const int bid = blockIdx.x; + const int head_id = blockIdx.x % head_num; + const int warp_id = tid / WARP_SIZE; // warp_id in block + const int lane_id = tid % WARP_SIZE; // lane_id in warp + + typedef cub::BlockReduce MaxValBlockReduce; + typedef cub::BlockReduce BlockReduce; + __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; + __shared__ typename BlockReduce::TempStorage block_temp_storage; + __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; + + int qkv_id = bid * size_per_head; + int qkv_bias_id = head_id * size_per_head; + + query_buf = &query_buf[qkv_id]; + key_buf = &key_buf[qkv_id]; + value_buf = &value_buf[qkv_id]; + self_K_bias = &self_K_bias[qkv_bias_id]; + key_cache = &key_cache[qkv_id]; + self_Q_bias = &self_Q_bias[qkv_bias_id]; + self_V_bias = &self_V_bias[qkv_bias_id]; + value_cache = &value_cache[qkv_id]; + context_buf = &context_buf[qkv_id]; + + Access_t bias_r, query_buf_r; + Access_t key_val_r, key_buf_r; + Access_t value_val_r, value_buf_r; + + // each warp will have its own copy of sq + query_buf_r.v = *((copy_t *)query_buf + lane_id); + key_buf_r.v = *((copy_t *)key_buf + lane_id); + bias_r.v = *((copy_t *)self_Q_bias + lane_id); + 
float qb_r[elems_per_thread]; + for (int i = 0; i < elems_per_thread; ++i) + { + qb_r[i] = (float)query_buf_r.x[i] + (float)bias_r.x[i]; + } + + //offset for each step + int offset = batch_size * head_num * size_per_head; + bias_r.v = *((copy_t *) self_K_bias + lane_id); + for(int ite = warp_id; ite < step; ite += warp_num) + { + key_val_r.v = *((copy_t *)&key_cache[ite * offset] + lane_id); + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1) + { + for (int i = 0; i < elems_per_thread; i++) + { + key_val_r.x[i] = (float)key_buf_r.x[i] + (float)bias_r.x[i]; + } + *((copy_t *)&key_cache[ite * offset] + lane_id) = key_val_r.v; + } + float val = 0.f; + for (int i = 0; i < elems_per_thread; i++) + { + val = val + (float)key_val_r.x[i] * qb_r[i] * (float)scalar; + } + float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val); + if (lane_id == 0) + { + logits[ite] = qk; + } + } + __syncthreads(); + + __shared__ float s_max_val, s_sum; + + float local_i = -1e20f; + for(int i = tid; i < step; i += blockDim.x) + local_i = max(local_i, logits[i]); + + float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max()); + if(tid == 0) + s_max_val = max_val; + __syncthreads(); + + + float local_o = 0.0f; + for(int i = tid; i < step; i += blockDim.x) + { + logits[i] = __expf(logits[i] - s_max_val); + local_o += logits[i]; + } + float val = BlockReduce(block_temp_storage).Sum(local_o); + + if(tid == 0) + s_sum = val + 1e-6; + __syncthreads(); + + float s_sum_inverse = __fdividef(1.0f, s_sum); + for(int i = tid; i < step; i += blockDim.x) + { + logits[i] = logits[i] * s_sum_inverse; + } + __syncthreads(); + + // This optimization introduces discrepancy because of different order in FP32 summation + float sum_r[elems_per_thread] = {0.f}; + bias_r.v = *((copy_t *) self_V_bias + lane_id); + value_buf_r.v = *((copy_t *)value_buf + lane_id); + + for(int ite = warp_id; ite < step; ite += warp_num) + { + value_val_r.v = *((copy_t *)&value_cache[ite * offset] + lane_id); + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1) + { + for (int i = 0; i < elems_per_thread; i++) + { + value_val_r.x[i] = (float)value_buf_r.x[i] + (float)bias_r.x[i]; + } + *((copy_t *)&value_cache[ite * offset] + lane_id) = value_val_r.v; + } + for (int i = 0; i < elems_per_thread; ++i) + { + sum_r[i] += (float)value_val_r.x[i] * logits[ite]; + } + } + for (int i = 0; i < elems_per_thread; i++) + { + sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i]; + } + __syncthreads(); + if (warp_id == 0) + { + #pragma unroll + for (int j = 1; j < warp_num; j++) + { + for (int i = 0; i < elems_per_thread; ++i) + { + sum_r[i] = sum_r[i] + (float)sq[j * WARP_SIZE + tid].x[i]; + } + } + } + __syncthreads(); + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) + { + value_val_r.x[i] = sum_r[i]; + } + if (warp_id == 0) + { + *((copy_t *)context_buf + lane_id) = value_val_r.v; + } +} + +// only use for compile +template +__global__ +void masked_attention_kernel_opt_half2( + float* __restrict key_buf, float* __restrict value_buf, + float* __restrict query_buf, const float* __restrict self_Q_bias, + float* __restrict key_cache, const float* __restrict self_K_bias, + float* __restrict value_cache, const float* __restrict self_V_bias, + float* __restrict context_buf, int batch_size, int head_num, const int step, const float scalar) {} + +template +__global__ +void masked_attention_kernel_opt_half2( + half* __restrict key_buf, half* __restrict value_buf, + half* 
__restrict query_buf, const half* __restrict self_Q_bias, + half* __restrict key_cache, const half* __restrict self_K_bias, + half* __restrict value_cache, const half* __restrict self_V_bias, + half* __restrict context_buf, int batch_size, int head_num, const int step, const half scalar) +{ + half2* key_buf_ptr = (half2*)key_buf; + half2* value_buf_ptr = (half2*)value_buf; + half2* query_buf_ptr = (half2*)query_buf; + half2* key_cache_ptr = (half2*)key_cache; + half2* value_cache_ptr = (half2*)value_cache; + const half2* self_Q_bias_ptr = (const half2*)self_Q_bias; + const half2* self_K_bias_ptr = (const half2*)self_K_bias; + const half2* self_V_bias_ptr = (const half2*)self_V_bias; + half2* context_buf_ptr = (half2*)context_buf; + + typedef Copy_t copy_t; + const int elems_per_thread = size_per_head / 2 / WARP_SIZE; + + union Access_t + { + copy_t v; + half2 x[elems_per_thread]; // supported size 1,2,4 + }; + typedef struct Half_n_t + { + half2 x[elems_per_thread]; // supported size 1,2,4 + } half_n_t; + + __shared__ half_n_t sq[block_sz]; + + __shared__ float logits[1024]; // only use [0 ~ step-1] + + const int tid = threadIdx.x; + const int warp_num = block_sz / WARP_SIZE; + const int bid = blockIdx.x; + const int head_id = blockIdx.x % head_num; + const int warp_id = tid / WARP_SIZE; // warp_id in block + const int lane_id = tid % WARP_SIZE; // lane_id in warp + + typedef cub::BlockReduce MaxValBlockReduce; + typedef cub::BlockReduce BlockReduce; + __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; + __shared__ typename BlockReduce::TempStorage block_temp_storage; + __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; + + int qkv_id = bid * size_per_head / 2; + int qkv_bias_id = head_id * size_per_head / 2; + + query_buf_ptr = &query_buf_ptr[qkv_id]; + key_buf_ptr = &key_buf_ptr[qkv_id]; + value_buf_ptr = &value_buf_ptr[qkv_id]; + self_K_bias_ptr = &self_K_bias_ptr[qkv_bias_id]; + key_cache_ptr = &key_cache_ptr[qkv_id]; + self_Q_bias_ptr = &self_Q_bias_ptr[qkv_bias_id]; + self_V_bias_ptr = &self_V_bias_ptr[qkv_bias_id]; + value_cache_ptr = &value_cache_ptr[qkv_id]; + context_buf_ptr = &context_buf_ptr[qkv_id]; + + Access_t bias_r, query_buf_r; + Access_t key_val_r, key_buf_r; + Access_t value_val_r, value_buf_r; + + // each warp will have its own copy of sq + query_buf_r.v = *((copy_t *)query_buf_ptr + lane_id); + key_buf_r.v = *((copy_t *)key_buf_ptr + lane_id); + bias_r.v = *((copy_t *)self_Q_bias_ptr + lane_id); + half2 qb_r[elems_per_thread]; + for (int i = 0; i < elems_per_thread; ++i) + { + qb_r[i] = __hadd2(query_buf_r.x[i], bias_r.x[i]); + } + + //offset for each step + int offset = batch_size * head_num * size_per_head / 2; + bias_r.v = *((copy_t *) self_K_bias + lane_id); + for(int ite = warp_id; ite < step; ite += warp_num) + { + key_val_r.v = *((copy_t *)&key_cache_ptr[ite * offset] + lane_id); + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1) + { + for (int i = 0; i < elems_per_thread; i++) + { + key_val_r.x[i] = __hadd2(key_buf_r.x[i], bias_r.x[i]); + } + *((copy_t *)&key_cache_ptr[ite * offset] + lane_id) = key_val_r.v; + } + float val = 0.f; + for (int i = 0; i < elems_per_thread; i++) + { + half2 val2 = __hmul2(key_val_r.x[i], qb_r[i]); + val = val + (float)((val2.x + val2.y) * scalar); + } + float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val); + if (lane_id == 0) + { + logits[ite] = qk; + } + } + __syncthreads(); + + __shared__ float s_max_val, s_sum; + float local_i = -1e20f; + 
for(int i = tid; i < step; i += blockDim.x) + local_i = max(local_i, logits[i]); + + float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max()); + if(tid == 0) + s_max_val = max_val; + __syncthreads(); + + float local_o = 0.0f; + for(int i = tid; i < step; i += blockDim.x) + { + logits[i] = __expf(logits[i] - s_max_val); + local_o += logits[i]; + } + float val = BlockReduce(block_temp_storage).Sum(local_o); + + if(tid == 0) + s_sum = val + 1e-6; + __syncthreads(); + + float s_sum_inverse = __fdividef(1.0f, s_sum); + for(int i = tid; i < step; i += blockDim.x) + { + logits[i] = logits[i] * s_sum_inverse; + } + __syncthreads(); + + // This optimization introduces discrepancy because of different order in FP32 summation + half2 sum_r[elems_per_thread]; + for(int i = 0; i < elems_per_thread; i++) + { + sum_r[i].x = (half)0.f; + sum_r[i].y = (half)0.f; + } + bias_r.v = *((copy_t *) self_V_bias_ptr + lane_id); + value_buf_r.v = *((copy_t *)value_buf_ptr + lane_id); + + for(int ite = warp_id; ite < step; ite += warp_num) + { + value_val_r.v = *((copy_t *)&value_cache_ptr[ite * offset] + lane_id); + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1) + { + for (int i = 0; i < elems_per_thread; i++) + { + value_val_r.x[i] = __hadd2(value_buf_r.x[i], bias_r.x[i]); + } + *((copy_t *)&value_cache_ptr[ite * offset] + lane_id) = value_val_r.v; + } + for (int i = 0; i < elems_per_thread; ++i) + { + half2 logit2_val; + logit2_val.x = (half)logits[ite]; + logit2_val.y = (half)logits[ite]; + sum_r[i] = __hadd2(sum_r[i], __hmul2(value_val_r.x[i], logit2_val)); + } + } + for (int i = 0; i < elems_per_thread; i++) + { + sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i]; + } + __syncthreads(); + if (warp_id == 0) + { + #pragma unroll + for (int j = 1; j < warp_num; j++) + { + for (int i = 0; i < elems_per_thread; ++i) + { + sum_r[i] = __hadd2(sum_r[i], sq[j * WARP_SIZE + tid].x[i]); + } + } + } + __syncthreads(); + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) + { + value_val_r.x[i] = sum_r[i]; + } + if (warp_id == 0) + { + *((copy_t *)context_buf_ptr + lane_id) = value_val_r.v; + } +} + +template +__global__ +void masked_attention_kernel( + T* key_buf, T* value_buf, + T* query_buf, const T* self_Q_bias, + T* key_cache, const T* self_K_bias, T* value_cache, const T* self_V_bias, + T* context_buf, int batch_size, int head_num, int size_per_head, const int step, const T scalar) +{ + extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; + T* sq = reinterpret_cast(s_buf); + T* logits = reinterpret_cast(&sq[size_per_head]); + + int tid = threadIdx.x; + int bid = blockIdx.x / head_num; + int head_id = blockIdx.x % head_num; + + int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; + int qkv_bias_id = head_id * size_per_head + tid; + + if(tid < size_per_head) + sq[tid] = query_buf[qkv_id] + self_Q_bias[qkv_bias_id]; + __syncthreads(); + + //offset for each step + int offset = batch_size * head_num * size_per_head; + for(int ite = 0; ite < step; ++ite) + { + T key = tid < size_per_head ? key_cache[ite * offset + qkv_id] : (T)0.0f; + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1 && tid < size_per_head) + { + key = key_buf[qkv_id] + self_K_bias[qkv_bias_id]; + key_cache[ite * offset + qkv_id] = key; + } + + T val = (tid < size_per_head) ? 
key * sq[tid] * scalar : (T)(0.0f); + T qk = blockReduceSum(val); + if(threadIdx.x == 0) + logits[ite] = qk; + __syncthreads(); //try to remove + } + __syncthreads(); //try to remove + + __shared__ float s_max_val, s_sum; + float local_i = tid < step ? (float)logits[tid] : -1e20f; + float max_val = blockReduceMax(local_i); + if(tid == 0) + s_max_val = max_val; + __syncthreads(); + + local_i -= s_max_val; + float local_o = tid < step ? __expf(local_i) : 0.0f; + float val = blockReduceSum(local_o); + + if(tid == 0) + s_sum = val + 1e-6; + __syncthreads(); + + if(tid < step) + logits[tid] = local_o / s_sum; + __syncthreads(); + + if(tid < size_per_head) + { + T sum = (T)0.0f; + for(int ite = 0; ite < step; ++ite) + { + T value = value_cache[ite * offset + qkv_id]; + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1) + { + value = value_buf[qkv_id] + self_V_bias[qkv_bias_id]; + value_cache[ite * offset + qkv_id] = value; + } + sum += value * logits[ite]; + } + context_buf[qkv_id] = sum; + } +} + +template +__global__ +void masked_attention_kernel_v2(T* query_buf, const T* self_Q_bias, + T* key_cache, const T* self_K_bias, T* value_cache, const T* self_V_bias, + T* context_buf, int batch_size, int head_num, int size_per_head, const int step, const T scalar) +{ + extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; + T* sq = reinterpret_cast(s_buf); + T* logits = reinterpret_cast(&sq[size_per_head]); + + int tid = threadIdx.x; + int bid = blockIdx.x / head_num; + int head_id = blockIdx.x % head_num; + + int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; + int qkv_bias_id = head_id * size_per_head + tid; + + if(tid < size_per_head) + sq[tid] = query_buf[qkv_id] + self_Q_bias[qkv_bias_id]; + __syncthreads(); + + int warp_size = 32; + int offset = batch_size * head_num * size_per_head; + int warp_ite = size_per_head / warp_size; + + T qk = (T)0.0f; + + //each warp process one step + int step_id = threadIdx.x >> 5; + if(step_id < step) + { + for(int wite = 0; wite < warp_ite; ++wite) + { + T key = key_cache[step_id * offset + bid * head_num * size_per_head + head_id * size_per_head + + tid % warp_size + wite * warp_size]; + //for the last step, we should update K + bias_K to the cache + if(step_id == step - 1) + { + key += self_K_bias[bid * head_num * size_per_head + head_id * size_per_head + + tid % warp_size + wite * warp_size]; + key_cache[step_id * offset + bid * head_num * size_per_head + head_id * size_per_head + + tid % warp_size + wite * warp_size] = key; + } + qk += key * sq[tid % warp_size + wite * warp_size]; + } + + qk = warpReduceSum(qk * scalar); + if(threadIdx.x % warp_size == 0) + { + logits[step_id] = qk; + printf("step_id %d %f\n", step_id, qk); + } + + } + __syncthreads(); + + __shared__ float s_max_val, s_sum; + float local_i = tid < step ? (float)logits[tid] : -1e20f; + float max_val = blockReduceMax(local_i); + if(tid == 0) + s_max_val = max_val; + __syncthreads(); + + local_i -= s_max_val; + float local_o = tid < step ? 
__expf(local_i) : 0.0f; + float val = blockReduceSum(local_o); + + if(tid == 0) + s_sum = val; + __syncthreads(); + if(tid < step) + logits[tid] = local_o / s_sum; + __syncthreads(); + + + if(tid < size_per_head) + { + T sum = (T)0.0f; + for(int ite = 0; ite < step; ++ite) + { + T value = value_cache[ite * offset + qkv_id]; + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1) + { + value += self_V_bias[qkv_bias_id]; + value_cache[ite * offset + qkv_id] = value; + } + sum += value * logits[ite]; + } + context_buf[qkv_id] = sum; + } +} + +template +void masked_attention_dispatch( + T* key_buf, T* value_buf, + T* query_buf, const T* self_Q_bias, + T* key_cache, const T* self_K_bias, T* value_cache, const T* self_V_bias, + T* context_buf, int batch_size, int head_num, int size_per_head, const int step, cudaStream_t stream) + { + const int block_sz = ATTENTION_BLOCK_SIZE; + T scalar = (T)(1.f / sqrtf(size_per_head * 1.0f)); + + dim3 grid(batch_size * head_num); + + int cond = size_per_head * ((ATTENION_OPT)? 1:0); + switch (cond) + { + case 32: + masked_attention_kernel_opt<32, block_sz, T><<>>( + key_buf, value_buf, + query_buf, self_Q_bias, key_cache, self_K_bias, value_cache, self_V_bias, context_buf, + batch_size, head_num, step, scalar); + break; + case 64: + if(sizeof(T) == 2) + masked_attention_kernel_opt_half2<64, block_sz><<>>( + key_buf, value_buf, + query_buf, self_Q_bias, key_cache, self_K_bias, value_cache, self_V_bias, context_buf, + batch_size, head_num, step, scalar); + else + masked_attention_kernel_opt<64, block_sz, T><<>>( + key_buf, value_buf, + query_buf, self_Q_bias, + key_cache, self_K_bias, + value_cache, self_V_bias, + context_buf, + batch_size, head_num, step, scalar); + break; + case 128: + if(sizeof(T) == 2) + masked_attention_kernel_opt_half2<128, block_sz><<>>( + key_buf, value_buf, + query_buf, self_Q_bias, key_cache, self_K_bias, value_cache, self_V_bias, context_buf, + batch_size, head_num, step, scalar); + else + masked_attention_kernel_opt<128, block_sz, T><<>>( + key_buf, value_buf, + query_buf, self_Q_bias, key_cache, self_K_bias, value_cache, self_V_bias, context_buf, + batch_size, head_num, step, scalar); + break; + default: + // default path + int block_size = 128; + + //suppose size_per_head <= 128 + if(step <= 64) + block_size = 64; + else if(step <= 128 && step > size_per_head) + block_size = 128; + else if(step > 128 && step <= 256) + block_size = 256; + else if(step > 256 && step <= 512) + block_size = 512; + else + block_size = 1024; + + if((int)block_size < size_per_head) + block_size = size_per_head; + + assert(block_size <= 1024); + dim3 block(block_size); + T scalar = 1 / sqrtf(size_per_head * 1.0f); + + + int shared_size = sizeof(T) * (size_per_head + step); + masked_attention_kernel<<>>( + key_buf, value_buf, + query_buf, self_Q_bias, + key_cache, self_K_bias, + value_cache, self_V_bias, + context_buf, batch_size, + head_num, size_per_head, step, scalar); + } + } + +template +void OpenDecoder::masked_multi_head_attention( + const DataType_* from_tensor, + DataType_* key_cache_, + DataType_* value_cache_, + DataType_* decoder_output, + const int step) +{ + int m = batch_size_; + int n = hidden_units_; + int k = hidden_units_; + + DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; + + if(is_fuse_QKV == true) + { + check_cuda_error(cublasGemmBatchedEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + (const void* const*) qkv_kernel_, AType_, n, + (const void* const*) qkv_input_, 
BType_, k, + &beta, + (void* const*)qkv_buf_, CType_, n, + 3, + computeType_, + static_cast(cublasAlgo_[4]))); + } + else + { + key_buf_ = key_cache_ + (step - 1) * m * n; + value_buf_ = value_cache_ + (step - 1) * m * n; + + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.query_weight.kernel , AType_, n, + from_tensor, BType_, k, + &beta, + query_buf_, CType_, n, + computeType_, + static_cast(cublasAlgo_[0]))); + + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.key_weight.kernel, AType_, n, + from_tensor, BType_, k, + &beta, + key_buf_, CType_, n, + computeType_, + static_cast(cublasAlgo_[0]))); + + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.value_weight.kernel, AType_, n, + from_tensor, BType_, k, + &beta, + value_buf_, CType_, n, + computeType_, + static_cast(cublasAlgo_[0]))); + } + + masked_attention_dispatch( + key_buf_, value_buf_, + query_buf_, param_.self_attention.query_weight.bias, + key_cache_, param_.self_attention.key_weight.bias, + value_cache_, param_.self_attention.value_weight.bias, + context_buf_, batch_size_, + head_num_, size_per_head_, step, param_.stream); + + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.attention_output_weight.kernel, AType_, n, + context_buf_, BType_, k, + &beta, + decoder_output, CType_, n, + computeType_, + static_cast(cublasAlgo_[0]))); +} + +template +__global__ +void cross_attention_kernel_opt( + T* __restrict query_buf, const T* __restrict Q_bias, + T* __restrict key_cache, const T* __restrict K_bias, + T* __restrict value_cache, const T* __restrict V_bias, + const int* length_per_sample, T* __restrict context_buf, + int batch_size, int head_num, const int step, const int seq_len, const float scalar) +{ + typedef Copy_t copy_t; + const int elems_per_thread = size_per_head / WARP_SIZE; + union Access_t + { + copy_t v; + T x[elems_per_thread]; // supported size 1,2,4 + }; + typedef struct Float_n_t + { + float x[elems_per_thread]; // supported size 1,2,4 + } float_n_t; + + __shared__ float_n_t sq[block_sz]; + __shared__ float logits[1024]; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int warp_num = block_sz / WARP_SIZE; + + typedef cub::BlockReduce MaxValBlockReduce; + typedef cub::BlockReduce BlockReduce; + __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; + __shared__ typename BlockReduce::TempStorage block_temp_storage; + + __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; + + const int tid = threadIdx.x; + const int bid = blockIdx.x / head_num; + const int head_id = blockIdx.x % head_num; + + int length = __ldg(&length_per_sample[bid]); + + const int lane_id = tid % WARP_SIZE; + + int qkv_id = bid * head_num * size_per_head + head_id * size_per_head; + int qkv_bias_id = head_id * size_per_head; + + int key_value_id = bid * (seq_len * head_num * size_per_head) + + + head_id * size_per_head; + + query_buf = &query_buf[qkv_id]; + K_bias = &K_bias[qkv_bias_id]; + key_cache = &key_cache[key_value_id]; + Q_bias = &Q_bias[qkv_bias_id]; + V_bias = &V_bias[qkv_bias_id]; + value_cache = &value_cache[key_value_id]; + context_buf = &context_buf[qkv_id]; + + Access_t bias_r, key_val_r, query_buf_r; + + // each warp will have its own copy of sq + query_buf_r.v = *((copy_t 
*)query_buf + lane_id); + bias_r.v = *((copy_t *)Q_bias + lane_id); + float qb_r[elems_per_thread]; + for (int i = 0; i < elems_per_thread; ++i) + { + qb_r[i] = (float)query_buf_r.x[i] + (float)bias_r.x[i]; + } + + //offset for each step + int offset = head_num * size_per_head; + + bias_r.v = *((copy_t *) K_bias + lane_id); + for(int ite = warp_id; ite < length; ite += warp_num) + { + key_val_r.v = *((copy_t *)&key_cache[ite * offset] + lane_id); + + //For the first step, we should add bias to key memory cache. + //The KV memory cache only need to be updated at the first step. + if (step == 1) + { + for (int i = 0; i < elems_per_thread; i++) + { + key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; + } + *((copy_t *)&key_cache[ite * offset] + lane_id) = key_val_r.v; + } + float val = 0.f; + for (int i = 0; i < elems_per_thread; i++) + { + val = val + (float)key_val_r.x[i] * qb_r[i] * scalar; + } + float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val); + if (lane_id == 0) + { + logits[ite] = qk; + } + } + __syncthreads(); + + __shared__ float s_max_val, s_sum; + float local_i = -1e20f; + for(int i = tid; i < length; i += blockDim.x) + local_i = max(local_i, logits[i]); + + float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max()); + if(tid == 0) + s_max_val = max_val; + __syncthreads(); + + float local_o = 0.0f; + for(int i = tid; i < length; i += blockDim.x) + { + logits[i] = __expf(logits[i] - s_max_val); + local_o += logits[i]; + } + float val = BlockReduce(block_temp_storage).Sum(local_o); + + if(tid == 0) + s_sum = val + 1e-6; + __syncthreads(); + + float s_sum_inverse = __fdividef(1.0f, s_sum); + for(int i = tid; i < length; i += blockDim.x) + { + logits[i] = logits[i] * s_sum_inverse; + } + __syncthreads(); + + // This optimization introduces discrepancy because of different order in FP32 summation + float sum_r[elems_per_thread] = {0.f}; + bias_r.v = *((copy_t *) V_bias + lane_id); + for(int ite = warp_id; ite < length; ite += warp_num) + { + key_val_r.v = *((copy_t *)&value_cache[ite * offset] + lane_id); + + //For the first step, we should add bias to key memory cache. 
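+    //The cross-attention K/V cache holds the projected encoder memory, which does not
+    //change across decoding steps, so the bias is added and the cache is written back
+    //only at step == 1; every later step just reads it (the self-attention cache above,
+    //by contrast, appends one new entry per step).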
+ if(step == 1) + { + for (int i = 0; i < elems_per_thread; i++) + { + key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; + } + *((copy_t *)&value_cache[ite * offset] + lane_id) = key_val_r.v; + } + for (int i = 0; i < elems_per_thread; ++i) + { + sum_r[i] += (float)key_val_r.x[i] * logits[ite]; + } + } + for (int i = 0; i < elems_per_thread; i++) + { + sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i]; + } + __syncthreads(); + if (threadIdx.x < WARP_SIZE) + { + #pragma unroll + for (int j = 1; j < warp_num; j++) + { + for (int i = 0; i < elems_per_thread; ++i) + { + sum_r[i] = sum_r[i] + (float)sq[j * WARP_SIZE + threadIdx.x].x[i]; + } + } + } + __syncthreads(); + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) + { + key_val_r.x[i] = sum_r[i]; + } + if (threadIdx.x < WARP_SIZE) + { + *((copy_t *)context_buf + lane_id) = key_val_r.v; + } +} + +template +__global__ +void cross_attention_kernel( + T* query_buf, const T* Q_bias, + T* key_cache, const T* K_bias, + T* value_cache, const T* V_bias, + const int* length_per_sample, T* context_buf, + int batch_size, int head_num, int size_per_head, int step, const int seq_len, const T scalar) +{ + int tid = threadIdx.x; + int bid = blockIdx.x / head_num; + int head_id = blockIdx.x % head_num; + + extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; + T* sq = reinterpret_cast(s_buf); + T* logits = reinterpret_cast(&sq[size_per_head]); + + int length = __ldg(&length_per_sample[bid]); + + int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; + int qkv_bias_id = head_id * size_per_head + tid; + + if(tid < size_per_head) + sq[tid] = query_buf[qkv_id] + Q_bias[qkv_bias_id]; + __syncthreads(); + + for(int ite = 0; ite < length; ++ite) + { + int key_id = bid * (seq_len * head_num * size_per_head) + ite * (head_num * size_per_head) + + head_id * size_per_head + tid; + + T key = tid < size_per_head ? key_cache[key_id] : (T)(0.0f); + + //For the first step, we should add bias to key memory cache. + //The KV memory cache only need to be updated at the first step. + if(step == 1 && tid < size_per_head) + { + key += K_bias[head_id * size_per_head + tid]; + key_cache[key_id] = key; + } + + T val = (tid < size_per_head) ? key * sq[tid] * scalar : (T)(0.0f); + T qk = blockReduceSum(val); + if(threadIdx.x == 0) + logits[ite] = qk; + __syncthreads(); //try to remove + } + __syncthreads(); + + __shared__ float s_max_val, s_sum; + + float local_i = tid < length ? (float)logits[tid] : -1e20f; + float max_val = blockReduceMax(local_i); + if(tid == 0) + s_max_val = max_val; + __syncthreads(); + + local_i -= s_max_val; + float local_o = tid < length ? 
__expf(local_i) : 0.0f; + float val = blockReduceSum(local_o); + + if(tid == 0) + s_sum = val + 1e-6; + __syncthreads(); + if(tid < length) + logits[tid] = local_o / s_sum; + __syncthreads(); + + if(tid < size_per_head) + { + T sum = (T)0.0f; + for(int ite = 0; ite < length; ++ite) + { + int value_id = bid * seq_len * head_num * size_per_head + ite * head_num * size_per_head + + head_id * size_per_head + tid; + + T value = value_cache[value_id]; + + //for the first step, we should add bias to key memory cache + if(step == 1) + { + value += V_bias[head_id * size_per_head + tid]; + value_cache[value_id] = value; + } + sum += value * logits[ite]; + } + context_buf[bid * head_num * size_per_head + head_id * size_per_head + tid] = sum; + } +} + +template +void cross_attention_dispatch(T* query_buf, const T* Q_bias, + T* key_cache, const T* K_bias, T* value_cache, const T* V_bias, const int* length, + T* context_buf, int batch_size, int head_num, int size_per_head, int step, int seq_len, cudaStream_t stream) + { + const int block_sz = ATTENTION_BLOCK_SIZE; + float scalar = 1.f / sqrtf(size_per_head * 1.0f); + + dim3 grid(batch_size * head_num); + + int cond = size_per_head * ((ATTENION_OPT)? 1:0); + switch (cond) + { + case 32: + cross_attention_kernel_opt<<>>( + query_buf, Q_bias, key_cache, K_bias, value_cache, V_bias, length, context_buf, + batch_size, head_num, step, seq_len, scalar); + break; + case 64: + cross_attention_kernel_opt<<>>( + query_buf, Q_bias, key_cache, K_bias, value_cache, V_bias, length, context_buf, + batch_size, head_num, step, seq_len, scalar); + break; + case 128: + cross_attention_kernel_opt<<>>( + query_buf, Q_bias, key_cache, K_bias, value_cache, V_bias, length, context_buf, + batch_size, head_num, step, seq_len, scalar); + break; + default: + // default path + + int block_size = 128; + + if(seq_len <= 64) + block_size = 64; + else if(seq_len <= 128 && seq_len > size_per_head) + block_size = 128; + else if(seq_len > 128 && seq_len <= 256) + block_size = 256; + else if(seq_len > 256 && seq_len <= 512) + block_size = 512; + else + block_size = 1024; + + if(block_size < size_per_head) + block_size = size_per_head; + + assert(block_size <= 1024); + dim3 block(block_size); + + int shared_size = sizeof(T) * (size_per_head + seq_len); + cross_attention_kernel<<>>( + query_buf, Q_bias, + key_cache, K_bias, + value_cache, V_bias, + length, context_buf, + batch_size, + head_num, size_per_head, step, seq_len, scalar); + } + } + +/* attention with source sentence */ +template +void OpenDecoder::cross_multi_head_attention( + const DataType_* from_tensor, + const DataType_* memory_tensor, + DataType_* key_mem_cache, + DataType_* value_mem_cache, + DataType_* decoder_output, + const int* length, + const int seq_len, + const int step) +{ + int m = batch_size_; + int n = hidden_units_; + int k = hidden_units_; + + DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; + + //reuse the query_buf + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.cross_attention.query_weight.kernel, AType_, n, + from_tensor, BType_, k, + &beta, + query_buf_, CType_, n, + computeType_, + static_cast(cublasAlgo_[0]))); + + if(step == 1) + { + m *= seq_len; + k = memory_hidden_units_; + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.cross_attention.key_weight.kernel, AType_, n, + memory_tensor, BType_, k, + &beta, + key_mem_cache, CType_, n, + computeType_, + 
static_cast(cublasAlgo_[1]))); + + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.cross_attention.value_weight.kernel, AType_, n, + memory_tensor, BType_, k, + &beta, + value_mem_cache, CType_, n, + computeType_, + static_cast(cublasAlgo_[1]))); + k = hidden_units_; + } + + cross_attention_dispatch( + query_buf_, param_.cross_attention.query_weight.bias, + key_mem_cache, param_.cross_attention.key_weight.bias, + value_mem_cache, param_.cross_attention.value_weight.bias, + length, context_buf_, batch_size_, + head_num_, size_per_head_, step, seq_len, param_.stream); + + m = batch_size_; + n = head_num_ * size_per_head_; + k = n; + + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.cross_attention.attention_output_weight.kernel, AType_, n, + context_buf_, BType_, k, + &beta, + decoder_output, CType_, n, + computeType_, + static_cast(cublasAlgo_[0]))); +} + +template +__global__ +void decoder_norm1_kernel(const T* __restrict input, + const T* __restrict gamma, + const T* __restrict beta, + T* output, + int m, int n) +{ + int tid = threadIdx.x; + + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + + float local_out = tid < n ? (float)(__ldg(&input[blockIdx.x * n + tid])) : 0.0f; + + mean = blockReduceSum(local_out); + + if(threadIdx.x == 0) + s_mean = mean / n; + __syncthreads(); + + variance = blockReduceSum(tid < n ? (local_out - s_mean) * (local_out - s_mean) : 0.0f); + + if(threadIdx.x == 0) + s_variance = rsqrtf(variance / n + 1e-6); + + __syncthreads(); + + if(tid < n) + output[blockIdx.x * n + tid] = + (T)(((local_out - s_mean) * s_variance) * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); +} + +template <> +__global__ +void decoder_norm1_kernel(const half* __restrict input, + const half* __restrict gamma, + const half* __restrict beta, + half* output, + int m, int n) +{ + const int tid = threadIdx.x; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + float2 local_out_fp2; + + const half2* input_ptr = (const half2*)input; + const half2* gamma_ptr = (const half2*)gamma; + const half2* beta_ptr = (const half2*)beta; + half2* output_ptr = (half2*)output; + + float local_out = 0.0f; + int id = blockIdx.x * blockDim.x + tid; + if(tid < blockDim.x) + { + local_out_fp2 = __half22float2(__ldg(&input_ptr[id])); + local_out += local_out_fp2.x; + local_out += local_out_fp2.y; + } + + mean = blockReduceSum(local_out); + if(tid == 0) + s_mean = mean / n; + __syncthreads(); + + variance = blockReduceSum(tid < blockDim.x ? 
+ (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean) + (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean) + : 0.0f); + if(tid == 0) + s_variance = rsqrtf(variance / n + 1e-6); + __syncthreads(); + + if(tid < blockDim.x) + { + float2 gamma_val = __half22float2(__ldg(&gamma_ptr[tid])); + float2 beta_val = __half22float2(__ldg(&beta_ptr[tid])); + local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; + local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; + output_ptr[id] = __float22half2_rn(local_out_fp2); + } +} + +template +__global__ +void decoder_norm2_kernel(const T* __restrict input, + const T* __restrict gamma, + const T* __restrict beta, + const T* __restrict bias, + T* output, T* norm_output, + int m, int n) +{ + int tid = threadIdx.x; + + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + + float local_out = 0.0f; + if(tid < n) + { + local_out = (float)(__ldg(&input[blockIdx.x * n + tid])); + local_out += (float)(output[blockIdx.x * n + tid]); + local_out += (float)(__ldg(&bias[tid])); + output[blockIdx.x * n + tid] = (T)local_out; + } + + mean = blockReduceSum(local_out); + if(threadIdx.x == 0) + s_mean = mean / n; + __syncthreads(); + + variance = blockReduceSum(tid < n ? (local_out - s_mean) * (local_out - s_mean) : 0.0f); + if(threadIdx.x == 0) + s_variance = rsqrtf(variance / n + 1e-6); + __syncthreads(); + + if(tid < n) + norm_output[blockIdx.x * n + tid] = + (T)((local_out - s_mean) * s_variance * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); +} + +template <> +__global__ +void decoder_norm2_kernel(const half* __restrict input, + const half* __restrict gamma, + const half* __restrict beta, + const half* __restrict bias, + half* output, half* norm_output, + int m, int n) +{ + const int tid = threadIdx.x; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + float2 local_out_fp2; + + const half2* input_ptr = (const half2*)input; + const half2* gamma_ptr = (const half2*)gamma; + const half2* beta_ptr = (const half2*)beta; + const half2* bias_ptr = (const half2*)bias; + half2* output_ptr = (half2*)output; + half2* norm_output_ptr = (half2*)norm_output; + + float local_out = 0.0f; + int id = blockIdx.x * blockDim.x + tid; + if(tid < blockDim.x) + { + output_ptr[id] = __hadd2(__hadd2(output_ptr[id], __ldg(&input_ptr[id])), __ldg(&bias_ptr[tid])); + local_out_fp2 = __half22float2(output_ptr[id]); + local_out += local_out_fp2.x; + local_out += local_out_fp2.y; + } + + mean = blockReduceSum(local_out); + if(tid == 0) + s_mean = mean / n; + __syncthreads(); + + variance = blockReduceSum(tid < blockDim.x ? 
+ (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean) + (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean) + : 0.0f); + if(tid == 0) + s_variance = rsqrtf(variance / n + 1e-6); + __syncthreads(); + + if(tid < blockDim.x) + { + float2 gamma_val = __half22float2(__ldg(&gamma_ptr[tid])); + float2 beta_val = __half22float2(__ldg(&beta_ptr[tid])); + local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; + local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; + norm_output_ptr[id] = __float22half2_rn(local_out_fp2); + } +} + +template +void OpenDecoder::decoder_norm1( + const DataType_* input, + const DataType_* gamma, + const DataType_* beta, + DataType_* output, + int m, int n) +{ + dim3 grid(m); + dim3 block(min(n, 1024)); + + /* For general cases, n is equal to hidden_units, e.g., 512/1024. + Since we have warp shuffle inside the code, block.x % 32 should be 0. + */ + if(n % 32 != 0) + block.x = 1024; + + block.x = block.x / (4 / sizeof(DataType_)); // if using half, only need half of block.x + assert(block.x <= 1024); + +/* should pay attention to the rsqrt precision*/ + decoder_norm1_kernel<<>>(input, gamma, beta, output, m, n); +} + +template +void OpenDecoder::decoder_norm2( + const DataType_* input, + const DataType_* gamma, + const DataType_* beta, + const DataType_* bias, + DataType_* output, + DataType_* norm_output, + int m, int n) +{ + dim3 grid(m); + dim3 block(min(n, 1024)); + + + /* For general cases, n is equal to hidden_units, e.g., 512/1024. + Since we have warp shuffle inside the code, block.x % 32 should be 0. + */ + + if(n % 32 != 0) + block.x = 1024; + + block.x = block.x / (4 / sizeof(DataType_)); // if using half, only need half of block.x + assert(block.x <= 1024); + + /* should pay attention to the rsqrt precision*/ + decoder_norm2_kernel<<>>(input, gamma, beta, bias, output, norm_output, m, n); +} + +template +void OpenDecoder::ffn( + const DataType_* input, + DataType_* ffn_inner, + DataType_* output, + const int m, + const int inner_size, + const int n) +{ + int m1 = m, k1 = n, n1 = inner_size; + DataType_ alpha = (DataType_)1.0f; + DataType_ beta = (DataType_)0.0f; + + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n1, m1, k1, + &alpha, + param_.ffn.intermediate_weight.kernel, AType_, n1, + input, BType_, k1, + &beta, + ffn_inner, CType_, n1, + computeType_, + static_cast(cublasAlgo_[2]))); + + dim3 grid(m1); + dim3 block(n1 / 4); + + assert(block.x <= 1024); + + add_bias_relu<<>>(ffn_inner, param_.ffn.intermediate_weight.bias, m1, n1); + + int m2 = m, n2 = n, k2 = inner_size; + check_cuda_error(cublasGemmEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n2, m2, k2, + &alpha, + param_.ffn.output_weight.kernel, AType_, n2, + ffn_inner, BType_, k2, + &beta, + output, CType_, n2, + computeType_, + static_cast(cublasAlgo_[3]))); +} + +template +__global__ +void add_bias_input_kernel(T* output, const T* input, const T* bias, const int m, const int n) +{ + int id = blockIdx.x * n + threadIdx.x; + output[id] = output[id] + input[id] + __ldg(&bias[threadIdx.x]); +} + +template +void OpenDecoder::add_bias_input(DataType_* output, const DataType_* input, const int m, const int n) +{ + dim3 grid(m); + dim3 block(n); + assert(n <= 1024); + add_bias_input_kernel<<>>(output, input, param_.ffn.output_weight.bias, m, n); +} + +template void OpenDecoder::masked_multi_head_attention( + const float* from_tensor, + float* key_cache, + float* value_cache, + float* 
decoder_output, + const int step); + +template void OpenDecoder::masked_multi_head_attention( + const half* from_tensor, + half* key_cache, + half* value_cache, + half* decoder_output, + const int step); + +template void OpenDecoder::cross_multi_head_attention( + const float* from_tensor, + const float* memory_tensor, + float* key_mem_cache, + float* value_mem_cache, + float* decoder_output, + const int* length, + const int max_seq_len, + const int step); + +template void OpenDecoder::cross_multi_head_attention( + const half* from_tensor, + const half* memory_tensor, + half* key_mem_cache, + half* value_mem_cache, + half* decoder_output, + const int* length, + const int max_seq_len, + const int step); + +template void OpenDecoder::ffn( + const float* input, + float* ffn_inner, + float* otuput, + const int m, + const int inner_size, + const int n); + +template void OpenDecoder::ffn( + const half* input, + half* ffn_inner, + half* otuput, + const int m, + const int inner_size, + const int n); + +template void OpenDecoder::decoder_norm1( + const float* input, + const float* gamma, + const float* beta, + float* output, + int m, int n); + +template void OpenDecoder::decoder_norm1( + const half* input, + const half* gamma, + const half* beta, + half* output, + int m, int n); + +template void OpenDecoder::decoder_norm2( + const float* input, + const float* gamma, + const float* beta, + const float* bias, + float* output, + float* norm_output, + int m, int n); + +template void OpenDecoder::decoder_norm2( + const half* input, + const half* gamma, + const half* beta, + const half* bias, + half* output, + half* norm_output, + int m, int n); + +template void OpenDecoder::add_bias_input( + float* output, + const float* input, + const int m, + const int n); + +template void OpenDecoder::add_bias_input( + half* output, + const half* input, + const int m, + const int n); + +}//namespace FasterTransformer diff --git a/fastertransformer/cuda/topk_kernels.cu b/fastertransformer/cuda/topk_kernels.cu new file mode 100644 index 000000000..0c273fd84 --- /dev/null +++ b/fastertransformer/cuda/topk_kernels.cu @@ -0,0 +1,660 @@ +/* +* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +#include "fastertransformer/cuda/topk_kernels.cuh" +#include "cub/cub.cuh" + +namespace fastertransformer +{ + +template +__launch_bounds__(THREADBLOCK_SIZE) +__global__ +void beam_topK_kernel(const T* log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const int vocab_size, + T diversity_rate) +{ + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int thread_id = threadIdx.x; + int block_id = blockIdx.x; + TopK partial; + + #pragma unroll + for(int i = 0; i < MAX_K; ++i) + { + partial.p[i] = -1; + partial.u[i] = -FLT_MAX; + } + + #pragma unroll + for(int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) + { + int index = elem_id + block_id * vocab_size; + partial.insert(log_probs[index], index); + } + + TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); + + if (thread_id == 0) + { + int index = block_id * MAX_K; + + #pragma unroll + for(int i = 0; i < MAX_K; ++i) + { + topk_tmp_id_buf[index + i] = total.p[i]; + topk_tmp_val_buf[index + i] = total.u[i] + diversity_rate * (T)i; + } + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) +__global__ +void batch_topK_kernel(int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* id_buf) +{ + int thread_id = threadIdx.x; + int block_id = blockIdx.x; + TopK partial; + if (thread_id == 0) + { + for(int i = 0; i < MAX_K; ++i) + { + partial.p[i] = -1; + partial.u[i] = -FLT_MAX; + } + + int index = block_id * MAX_K * MAX_K; + for(int i = 0; i < MAX_K * MAX_K; i++) + { + partial.insert( (T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]); + } + + index = block_id * MAX_K; + for(int i = 0; i < MAX_K; i++) + { + id_buf[index + i] = partial.p[i]; + } + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) +__global__ +void batch_topK_kernel_v2(int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* id_buf) +{ + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int tid = threadIdx.x; + int bid = blockIdx.x; + TopK partial; + #pragma unroll + for(int i = 0; i < MAX_K; ++i) + { + partial.p[i] = -1; + partial.u[i] = -FLT_MAX; + } + + int ite = MAX_K * MAX_K / THREADBLOCK_SIZE; + #pragma unroll + for(int i = 0; i < ite; i++) + { + int index = bid * MAX_K * MAX_K + i * THREADBLOCK_SIZE + tid; + partial.insert( (T)topk_tmp_val_buf[index], topk_tmp_id_buf[index]); + } + + TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); + + if(tid == 0) + { + #pragma unroll + for(int i = 0; i < MAX_K; i++) + id_buf[bid * MAX_K + i] = total.p[i]; + } +} + +template +__global__ void topk_stage_1_opt3( + const T* __restrict log_probs, + T* tmp_log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const int k, + const int vocab_size +) +{ + typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs + const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam + const int tmp_log_buf_index = row_id * vocab_size; + const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k; + TopK_2 partial; + + for(int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) + { + int index = elem_id + tmp_log_buf_index; + tmp_log_probs[index] = log_probs[index]; + } + + + for(int ite = 0; ite < k; ite++) + { + 
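+    //Per-iteration selection: each of the k iterations finds the current block-wide
+    //maximum of this beam's slice of tmp_log_probs via a cub::BlockReduce over TopK_2
+    //candidates, has thread 0 record its (id, value) into the per-block top-k buffers,
+    //and then masks the winner with -FLT_MAX so the next iteration returns the
+    //next-largest element. Host-side sketch of the same idea (illustrative only, not
+    //part of the kernel; buf/out are placeholder names):
+    //  for (int ite = 0; ite < k; ++ite) {
+    //      auto it = std::max_element(buf.begin(), buf.end());
+    //      out[ite] = (int)std::distance(buf.begin(), it);
+    //      *it = -FLT_MAX;  // exclude the winner from the next pass
+    //  }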
partial.init(); + #pragma unroll + for(int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) + { + int index = elem_id + tmp_log_buf_index; + partial.insert(tmp_log_probs[index], index); + } + + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); + + if (tid == 0) + { + const int index = tmp_topk_buf_index + ite; + topk_tmp_id_buf[index] = total.p; + topk_tmp_val_buf[index] = total.u; + tmp_log_probs[total.p] = -FLT_MAX; + } + __syncthreads(); + } +} + +template +__global__ void topk_stage_2_opt3( + const int* __restrict topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* ids, + const int k) +{ + const int size = k * k * BLOCKS_PER_BEAM_; + const int tid = threadIdx.x; + const int batch_id = blockIdx.x; + + typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + extern __shared__ char array[]; + T *s_val = topk_tmp_val_buf + batch_id * size; + int *s_id = (int*)(array); + + TopK_2 partial; + + for(int ite = 0; ite < k; ite++) + { + partial.init(); + #pragma unroll + for(int i = tid; i < size; i+= BLOCK_SIZE_) + { + partial.insert(s_val[i], i); + } + + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); + + if(tid == 0) + { + s_id[ite] = total.p; + s_val[total.p] = -FLT_MAX; + } + __syncthreads(); + } + if(tid < k) ids[batch_id * k + tid] = topk_tmp_id_buf[batch_id * size + s_id[tid]]; +} + +template +__global__ void topk_stage_1_opt2_general( + const T* __restrict log_probs, + T* tmp_log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const int k, + const int vocab_size +) +{ + typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs + const int block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam + const int tmp_log_buf_index = row_id * vocab_size; + const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k; + TopK_2 partial; + + for(int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) + { + int index = elem_id + tmp_log_buf_index; + tmp_log_probs[index] = log_probs[index]; + } + + + for(int ite = 0; ite < k; ite++) + { + partial.init(); + #pragma unroll + for(int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) + { + int index = elem_id + tmp_log_buf_index; + partial.insert(tmp_log_probs[index], index); + } + + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); + + if (tid == 0) + { + const int index = tmp_topk_buf_index + ite; + topk_tmp_id_buf[index] = total.p; + topk_tmp_val_buf[index] = total.u; + tmp_log_probs[total.p] = -FLT_MAX; + } + __syncthreads(); + } +} + +template +__global__ void topk_stage_2_opt2_general( + const int* __restrict topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* ids, + const int k) +{ + const int size = k * k * BLOCKS_PER_BEAM; + const int tid = threadIdx.x; + const int batch_id = blockIdx.x; + + typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + extern __shared__ char array[]; + T *s_val = topk_tmp_val_buf + batch_id * size; + int *s_id = (int*)(array); + + TopK_2 partial; + + for(int ite = 0; ite < k; ite++) + { + partial.init(); + #pragma unroll + for(int i = tid; i < size; i+= BLOCK_SIZE) + 
{ + partial.insert(s_val[i], i); + } + + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); + + if(tid == 0) + { + s_id[ite] = total.p; + s_val[total.p] = -FLT_MAX; + } + __syncthreads(); + } + if(tid < k) ids[batch_id * k + tid] = topk_tmp_id_buf[batch_id * size + s_id[tid]]; +} + +#define CASE_K_DIV(K,BLOCK_SIZE_1, BLOCK_SIZE_2) \ + case K: \ + beam_topK_kernel<<>>(log_probs, \ + topk_tmp_id_buf, topk_tmp_val_buf, vocab_size, diversity_rate); \ + if (K < 10) \ + batch_topK_kernel<<>>(topk_tmp_id_buf, topk_tmp_val_buf, ids); \ + else \ + batch_topK_kernel_v2<<>>(topk_tmp_id_buf, topk_tmp_val_buf, ids); \ + break; \ + +#define CASE_K(K,BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_) \ + case K: \ + topk_stage_1_opt3<<>>( \ + log_probs, \ + temp_log_probs, \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + beam_width, vocab_size); \ + topk_stage_2_opt3<<>>( \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + ids, \ + beam_width); \ + break; \ + +template +void topK_kernelLauncher(T* log_probs, + T* temp_log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* ids, + DecodingBeamsearchArguments args, + cudaStream_t stream) +{ + const int batch_size = args.batch_size_; + const int beam_width = args.beam_width_; + const int vocab_size = args.vocab_size_; + const T diversity_rate = args.beam_search_diversity_rate_; + if(diversity_rate == 0.0f) + { + switch(beam_width) + { + CASE_K(1,128,128,8); + CASE_K(4,128,128,8); + CASE_K(10,128,128,8); + CASE_K(16,128,128,5); + CASE_K(32,256,128,1); + CASE_K(64,256,256,1); + default: + topk_stage_1_opt2_general<<>>( + log_probs, + temp_log_probs, + topk_tmp_id_buf, + topk_tmp_val_buf, + beam_width, vocab_size); + topk_stage_2_opt2_general<<>>( + topk_tmp_id_buf, + topk_tmp_val_buf, + ids, + beam_width); + break; + } + } + else + { + switch(beam_width) + { + CASE_K_DIV(1,256,256); + CASE_K_DIV(4,256,256); + CASE_K_DIV(16,256,64); + CASE_K_DIV(64,256,64); + default: + printf("[ERROR] Topk kernel does not support beamwidth = %d \n", beam_width); + exit(0); + break; + } + } + +} +#undef CASE_K +#undef CASE_K_DIV + +template void topK_kernelLauncher(float* log_probs, + float* temp_log_probs, + int* topk_tmp_id_buf, + float* topk_tmp_val_buf, + int* ids, + DecodingBeamsearchArguments args, + cudaStream_t stream); + +// Sampling kernels +template +__global__ void sampling(int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* ids, + int* sequence_length, + bool* finished_buf, + const int candidate_num, + int random_num, + const int end_id, + const int vocab_size) +{ + int tid = threadIdx.x; + int bid = blockIdx.x; + __shared__ T sum; + __shared__ T rand_num; + + if(tid < candidate_num) + { + T max_val = topk_tmp_val_buf[bid * candidate_num]; + topk_tmp_val_buf[bid * candidate_num + tid] = __expf(topk_tmp_val_buf[bid * candidate_num + tid] - max_val); + } + + if(tid == 0) + { + sum = 0.0f; + for(int i = 0; i < candidate_num; i++) + { + sum = sum + topk_tmp_val_buf[bid * candidate_num + i]; + } + + curandState_t local_state; + curand_init((T)random_num, bid, 0, &local_state); + rand_num = (T)curand_uniform(&local_state) * sum; + + ids[bid] = topk_tmp_id_buf[bid * candidate_num + candidate_num - 1] % vocab_size; + for(int i = 0; i < candidate_num; i++) + { + rand_num = rand_num - topk_tmp_val_buf[bid * candidate_num + i]; + if(rand_num <= 0.0f){ + ids[bid] = topk_tmp_id_buf[bid * candidate_num + i] % vocab_size; + break; + } + } + + sequence_length[bid] = finished_buf[bid] ? 
sequence_length[bid] : sequence_length[bid] + 1; + finished_buf[bid] = ids[bid] == end_id ? 1 : 0; + } +} + +#define CASE_K(K) \ + case K : \ + beam_topK_kernel<<>>(log_probs, \ + topk_tmp_id_buf, topk_tmp_val_buf, vocab_size, 0.0f); \ + break; \ + +template +void topK_sampling_kernel_kernelLauncher(T* log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* ids, + int* sequence_length, + bool* finished_buf, + int random_num, + DecodingSamplingArguments args, + cudaStream_t stream) +{ + const int batch_size = args.batch_size_; + const int vocab_size = args.vocab_size_; + const int candidate_num = args.candidate_num_; + const int end_id = args.end_id_; + const int block_size = 256; + switch(candidate_num) + { + CASE_K(1); + CASE_K(2); + CASE_K(4); + default: + printf("[ERROR] Topk kernel does not support candidate_num = %d \n", candidate_num); + exit(0); + break; + } + sampling <<< batch_size, candidate_num, 0, stream>>> (topk_tmp_id_buf, topk_tmp_val_buf, + ids, sequence_length, finished_buf, + candidate_num, random_num, end_id, vocab_size); +} + +template void topK_sampling_kernel_kernelLauncher(float* log_probs, + int* topk_tmp_id_buf, + float* topk_tmp_val_buf, + int* ids, + int* sequence_length, + bool* finished_buf, + int random_num, + DecodingSamplingArguments args, + cudaStream_t stream); + +__global__ void init_topp_id_val(int* topp_id_val_buf, + int* topp_offset_buf, + const int batch_size, + const int vocab_size) +{ + int tid = threadIdx.x; + int bid = blockIdx.x; + + if(bid == 0) + { + for(int i = tid; i < batch_size + 1; i+= blockDim.x) + { + topp_offset_buf[i] = i * vocab_size; + } + } + + while(tid < vocab_size) + { + topp_id_val_buf[bid * vocab_size + tid] = tid; + tid += blockDim.x; + } +} + + +void init_topp_id_val_kernel_kernelLauncher(int* topp_id_val_buf, + int* topp_offset_buf, + const int batch_size, + const int vocab_size, + cudaStream_t stream) +{ + init_topp_id_val<<>>(topp_id_val_buf, + topp_offset_buf, + batch_size, + vocab_size); +} + +// Sampling kernels +template +__global__ void top_p_sampling(T* sorted_log_probs, + int* sorted_id_vals, + int* ids, + int* sequence_length, + bool* finished_buf, + const int vocab_size, + const int random_num, + const float prob_threshold, + const int end_id) +{ + int tid = threadIdx.x; + curandState_t local_state; + curand_init((T)random_num, tid, 0, &local_state); + T rand_num = (T)curand_uniform(&local_state) * prob_threshold; + ids[tid] = sorted_id_vals[vocab_size - 1]; + + for(int i = tid * vocab_size; i < tid * vocab_size + vocab_size; i++) + { + rand_num = rand_num - sorted_log_probs[i]; + if(rand_num <= 0) + { + ids[tid] = sorted_id_vals[i]; + break; + } + } + + sequence_length[tid] = finished_buf[tid] ? sequence_length[tid] : sequence_length[tid] + 1; + finished_buf[tid] = ids[tid] == end_id ? 
1 : 0; +} + +template +__global__ void sort_kernel(const T* log_probs, + const int* id_vals, + T* sorted_log_probs, + int* sorted_id_vals, + const int vocab_size) +{ + typedef cub::BlockRadixSort BlockRadixSort; + __shared__ typename BlockRadixSort::TempStorage temp_storage; + // Obtain a segment of consecutive items that are blocked across threads + T thread_keys[32]; + int thread_values[32]; + + int tid = threadIdx.x; + int bid = blockIdx.x; + for(int i = 0; i < 32; i++) + { + int index = tid + 256 * i + bid * vocab_size; + thread_keys[i] = log_probs[index]; + thread_values[i] = id_vals[index]; + } + BlockRadixSort(temp_storage).SortDescending(thread_keys, thread_values); + + for(int i = 0; i < 32; i++) + { + int index = tid + 256 * i + bid * vocab_size; + sorted_log_probs[index] = thread_keys[i]; + sorted_id_vals[index] = thread_values[i]; + } +} + +template +void topP_sampling_kernel_kernelLauncher(const T* log_probs, + const int* id_vals, + T* sorted_log_probs, + int* sorted_id_vals, + int* topp_offset_buf, + void* temp_storage, + bool* finished_buf, + int step, + DecodingSamplingArguments args, + int* output_ids, + int* sequence_length, + cudaStream_t stream) +{ + // sort_kernel<<>>(log_probs, + // id_vals, + // sorted_log_probs, + // sorted_id_vals, + // vocab_size); + cub::DeviceSegmentedRadixSort::SortPairsDescending(temp_storage, + args.temp_storage_size_, + log_probs, + sorted_log_probs, + id_vals, + sorted_id_vals, + args.vocab_size_ * args.batch_size_, + args.batch_size_, + topp_offset_buf, topp_offset_buf + 1); + + + top_p_sampling<<<1, args.batch_size_, 0, stream>>>(sorted_log_probs, + sorted_id_vals, + output_ids + (step - 1) * args.batch_size_, + sequence_length, + finished_buf, + args.vocab_size_, + step, + args.probability_threshold_, + args.end_id_); +} + +template void topP_sampling_kernel_kernelLauncher(const float* log_probs, + const int* id_vals, + float* sorted_log_probs, + int* sorted_id_vals, + int* topp_offset_buf, + void* temp_storage, + bool* finished_buf, + int step, + DecodingSamplingArguments args, + int* output_ids, + int* sequence_length, + cudaStream_t stream); + +} // end of namespace fastertransformer \ No newline at end of file diff --git a/fastertransformer/cuda/topk_kernels.cuh b/fastertransformer/cuda/topk_kernels.cuh new file mode 100644 index 000000000..9b9b19492 --- /dev/null +++ b/fastertransformer/cuda/topk_kernels.cuh @@ -0,0 +1,161 @@ +/* +* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ +#pragma once +#include +#include +#include +#include +#include +#include "fastertransformer/arguments.h" +#include "fastertransformer/cuda/cuda_kernels.h" +#include + +namespace fastertransformer{ + +#define DO_SPLIT_SMALL_TOP_K_SOFTMAX +static const int SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256; +static const int SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS = 128; +static const int MAX_K = 4; + +template +struct TopK +{ + int p[MAX_K]; + T u[MAX_K]; + + __device__ __forceinline__ void insert(T elem, int elem_id) + { + if (elem > u[MAX_K-1] || (p[MAX_K-1] == -1) || ((elem == u[MAX_K-1]) && (elem_id < p[MAX_K-1]))) + //if (elem > u[MAX_K-1] || ((elem == u[MAX_K-1]) && (elem_id < p[MAX_K-1]))) + { + u[MAX_K-1] = elem; + p[MAX_K-1] = elem_id; + } + + for(int k = MAX_K - 2; k >= 0; --k) + { + if ((u[k+1] > u[k]) || (p[k] == -1) || ((u[k+1] == u[k])&&(p[k+1] < p[k]))) + //if ((u[k+1] > u[k]) || ((u[k+1] == u[k])&&(p[k+1] < p[k]))) + { + T u2 = u[k]; + int p2 = p[k]; + u[k] = u[k+1]; + p[k] = p[k+1]; + u[k+1] = u2; + p[k+1] = p2; + } + } + } + + __device__ __forceinline__ void init() + { + #pragma unroll + for(int i = 0; i < MAX_K; i++) + { + p[i] = -1; + u[i] = -FLT_MAX; + } + } +}; + +template +__device__ __forceinline__ TopK reduce_topk_op(const TopK& a, const TopK& b) +{ + TopK res = a; + for(int i = 0; i < MAX_K; ++i) + res.insert(b.u[i], b.p[i]); + return res; +} + +template +struct TopK_2 +{ + int p = -1; + T u = -FLT_MAX; + + __device__ __forceinline__ void insert(T elem, int elem_id) + { + if(elem > u) + { + u = elem; + p = elem_id; + } + } + + __device__ __forceinline__ void init() + { + u = -FLT_MAX; + p = -1; + } +}; + +template +__device__ __forceinline__ TopK_2 reduce_topk_op_2(const TopK_2& a, const TopK_2& b) +{ + return a.u > b.u ? a : b; +} + +template +void topK_kernelLauncher(T* log_probs, + T* temp_log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* ids, + DecodingBeamsearchArguments args, + cudaStream_t stream); + +template +void topK_softMax(const T* log_probs, + const float* bias, + const bool* finished, + T* cum_log_probs, + int* ids, + void * tmp_storage, + DecodingBeamsearchArguments args, + cudaStream_t stream); + +/* *************************** end of BeamSearch kernel *********************************** */ + +/* ********************************** Sampling kernel *********************************** */ + +template +void topK_sampling_kernel_kernelLauncher(T* log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* ids, + int* sequence_length, + bool* finished_buf, + int random_num, + DecodingSamplingArguments args, + cudaStream_t stream); + +template +void topP_sampling_kernel_kernelLauncher(const T* log_probs, + const int* id_vals, + T* sorted_log_probs, + int* sorted_id_vals, + int* topp_offset_buf, + void* temp_storage, + bool* finished_buf, + int step, + DecodingSamplingArguments args, + int* output_ids, + int* sequence_length, + cudaStream_t stream); + +/* *************************** end of Sampling kernel *********************************** */ + +}//namespace fastertransformer diff --git a/fastertransformer/decoding_beamsearch.h b/fastertransformer/decoding_beamsearch.h new file mode 100644 index 000000000..0b7b0a5e5 --- /dev/null +++ b/fastertransformer/decoding_beamsearch.h @@ -0,0 +1,449 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Decoder transformer + **/ + +#pragma once + +#include "fastertransformer/common.h" +#include "fastertransformer/allocator.h" +#include "fastertransformer/open_decoder.h" +#include "fastertransformer/cuda/cuda_kernels.h" +#include "fastertransformer/arguments.h" +#include + +namespace fastertransformer +{ + +template +class DecodingBeamsearch +{ +private: + typedef DecoderTransformerTraits Traits_; + typedef typename Traits_::DataType DataType_; + const IAllocator &allocator_; + struct DecodingBeamsearchArguments args_; + + const cudaDataType_t computeType_ = Traits_::computeType; + const cudaDataType_t AType_ = Traits_::AType; + const cudaDataType_t BType_ = Traits_::BType; + const cudaDataType_t CType_ = Traits_::CType; + int cublasAlgo_[1] = {20}; + + OpenDecoder *decoder_; + DataType_ **K_cache_; + DataType_ **V_cache_; + DataType_ **K_mem_cache_; + DataType_ **V_mem_cache_; + DataType_ *from_tensor_[2]; + DataType_ *decoder_buf_; + DataType_ *decoder_normed_result_buf_; + float *logits_buf_; + float *cum_log_buf_; + int *word_ids_buf_; + int *topk_ids_buf_; + bool *finished_buf_; + void *buf_; + int *finished_count_buf_; + bool *h_finished_buf_; + float *topk_val_buf_; + float *temp_storage_; + float *temp_log_probs_buf_; + + bool is_fuse_topk_softMax_; + +public: + DecodingBeamsearch(const IAllocator &allocator, const int batch_size, + const int beam_width, const int seq_len, + const int head_num, const int size_per_head, + const int vocab_size, const int decoder_layers, + const int memory_hidden_units, const int memory_max_seq_len, + const int start_id, const int end_id, + const float beam_search_diversity_rate=-0.0f, + const bool is_fuse_topk_softMax=false) : allocator_(allocator), + is_fuse_topk_softMax_(is_fuse_topk_softMax) + { +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + args_.batch_size_ = batch_size; + args_.beam_width_ = beam_width; + args_.seq_len_ = seq_len; + args_.head_num_ = head_num; + args_.size_per_head_ = size_per_head; + args_.hidden_units_ = head_num * size_per_head; + args_.decoder_layers_ = decoder_layers; + args_.vocab_size_ = vocab_size; + args_.start_id_ = start_id; + args_.end_id_ = end_id; + args_.beam_search_diversity_rate_ = beam_search_diversity_rate; + + K_cache_ = new DataType_ *[2]; + V_cache_ = new DataType_ *[2]; + + K_mem_cache_ = new DataType_ *[args_.decoder_layers_]; + V_mem_cache_ = new DataType_ *[args_.decoder_layers_]; + + decoder_ = new OpenDecoder(batch_size * beam_width, memory_max_seq_len, + head_num, size_per_head, memory_hidden_units); + + int from_tensor_size = args_.batch_size_ * args_.beam_width_ * args_.hidden_units_; // type T + int decoder_workspace_size = decoder_->getWorkspaceSize(); // type T + int decoder_normed_result_buffer_size = args_.batch_size_ * args_.beam_width_ * args_.hidden_units_; // type T + int cache_size = args_.batch_size_ * args_.beam_width_ * args_.seq_len_ * args_.hidden_units_; // type T + int mem_cache_size = args_.batch_size_ * args_.beam_width_ * memory_max_seq_len * args_.hidden_units_; // type T + + int logits_buf_size = args_.batch_size_ * args_.beam_width_ * 
args_.vocab_size_; // type float + int temp_log_probs_buf_size = args_.batch_size_ * args_.beam_width_ * args_.vocab_size_; // type float + int cum_log_buf_size = args_.batch_size_ * args_.beam_width_; // type float + int word_ids_buf_size = args_.batch_size_ * args_.beam_width_; //type int + int finished_buf_size = args_.batch_size_ * args_.beam_width_; //type bool + int finished_count_size = (int)(ceil(1 / 32.)) * 32; // type int + + int topk_ids_buf_size = args_.batch_size_ * args_.beam_width_ * (ceil)((args_.beam_width_ * args_.vocab_size_ * 1.0) / 1024.0); // type int + int topk_val_buf_size = args_.batch_size_ * args_.beam_width_ * args_.beam_width_; // type float + int storage_size_per_beam = 2 * args_.beam_width_ + SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS * (2 * MAX_K + 2); + args_.temp_storage_size_ = args_.batch_size_ * args_.beam_width_ * storage_size_per_beam; // type float + + // prevent memory misalinged address + logits_buf_size = (int)(ceil(logits_buf_size / 4.)) * 4; + temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4; + cum_log_buf_size = (int)(ceil(cum_log_buf_size / 4.)) * 4; + word_ids_buf_size = (int)(ceil(word_ids_buf_size / 4.)) * 4; + finished_buf_size = (int)(ceil(finished_buf_size / 32.)) * 32; + topk_ids_buf_size = (int)(ceil(topk_ids_buf_size / 4.)) * 4; + topk_val_buf_size = (int)(ceil(topk_val_buf_size / 4.)) * 4; + args_.temp_storage_size_ = (int)(ceil(args_.temp_storage_size_ / 4.)) * 4; + + int datatype_buf_size = from_tensor_size * 2 + decoder_workspace_size + + (cache_size * 4 + mem_cache_size * 2) * args_.decoder_layers_ + decoder_normed_result_buffer_size; + + buf_ = reinterpret_cast(allocator_.malloc( + sizeof(DataType_) * datatype_buf_size + + sizeof(float) * (logits_buf_size + temp_log_probs_buf_size + cum_log_buf_size) + + sizeof(int) * word_ids_buf_size + + sizeof(bool) * finished_buf_size + + sizeof(int) * topk_ids_buf_size + + sizeof(float) * topk_val_buf_size + + sizeof(float) * args_.temp_storage_size_ + // should be always float + sizeof(int) * finished_count_size )); + + from_tensor_[0] = (DataType_ *)buf_; + from_tensor_[1] = (DataType_ *)(from_tensor_[0] + from_tensor_size); + + for (int i = 0; i < args_.decoder_layers_; ++i) + { + K_mem_cache_[i] = from_tensor_[1] + from_tensor_size + i * mem_cache_size * 2; + V_mem_cache_[i] = from_tensor_[1] + from_tensor_size + i * mem_cache_size * 2 + mem_cache_size; + } + + /* We use two-way buffer since we have to update KV buf at the end of each step. 
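+       K_cache_ / V_cache_ are ping-pong buffers selected by (step & 0x1): the decoder
+       reads/writes one copy during the current step, and update_KV_cache_kernelLauncher
+       then gathers the entries chosen by parent_ids into the other copy, so the beam
+       reordering does not have to be performed in place.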
*/ + K_cache_[0] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 0 * cache_size * args_.decoder_layers_; + K_cache_[1] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 1 * cache_size * args_.decoder_layers_; + V_cache_[0] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 2 * cache_size * args_.decoder_layers_; + V_cache_[1] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 3 * cache_size * args_.decoder_layers_; + + decoder_buf_ = V_cache_[1] + cache_size * args_.decoder_layers_; + decoder_normed_result_buf_ = (decoder_buf_ + decoder_workspace_size); + logits_buf_ = (float *)(decoder_normed_result_buf_ + decoder_normed_result_buffer_size); + temp_log_probs_buf_ = (float *)(logits_buf_ + logits_buf_size); + cum_log_buf_ = (float *)(temp_log_probs_buf_ + temp_log_probs_buf_size); + word_ids_buf_ = (int *)(cum_log_buf_ + cum_log_buf_size); + finished_buf_ = (bool *)(word_ids_buf_ + word_ids_buf_size); + topk_ids_buf_ = (int *)(finished_buf_ + finished_buf_size); + topk_val_buf_ = (float*)(topk_ids_buf_ + topk_ids_buf_size); + temp_storage_ = (float*)(topk_val_buf_ + topk_val_buf_size); + finished_count_buf_ = (int *)(temp_storage_ + args_.temp_storage_size_); + + h_finished_buf_ = new bool[finished_buf_size]; + + FILE *fd = fopen("decoding_gemm_config.in", "r"); + int err = 0; + if (fd == NULL) + printf("[WARNING] decoding_gemm_config.in is not found\n"); + else + { + err = fscanf(fd, "%d", &cublasAlgo_[0]); + fclose(fd); + } + if (err != 1) + { + printf("[WARNING] decoding loading GEMM algorithms error, using default GEMM algorithms!\n"); + if (Traits_::OpType == OperationType::FP32) + { + cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT; + } + else + { + cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + } + } + else + { + // check that the gemm_config setting is runnable + if (Traits_::OpType == OperationType::FP32) + { + if (cublasAlgo_[0] > CUBLAS_GEMM_ALGO23 || cublasAlgo_[0] < CUBLAS_GEMM_DEFAULT) + { + // the algorithm is not for FP32 + printf("[ERROR] cuBLAS Algorithm %d is not used in FP32. \n", (int)cublasAlgo_[0]); + exit(-1); + } + } + else + { + if (cublasAlgo_[0] > CUBLAS_GEMM_ALGO15_TENSOR_OP || cublasAlgo_[0] < CUBLAS_GEMM_DEFAULT_TENSOR_OP) + { + // the algorithm is not for FP16 + printf("[ERROR] cuBLAS Algorithm %d is not used in FP16. \n", (int)cublasAlgo_[0]); + exit(-1); + } + } + } + } + + void forward(const DecoderInitParam *param, + DecodingInitParam decoding_params) + { + +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + const int m = args_.batch_size_ * args_.beam_width_; + const int k = args_.hidden_units_; + const int n = args_.vocab_size_; + + /* + sequence_length initialize to 0 + finished: false + word_ids: start_id_ + cum_log_probs (for eacm beam, the first element is 0). e.g., [0 -inf -inf -inf][0 -inf -inf -inf] + */ + + init_kernelLauncher(finished_buf_, decoding_params.sequence_length, word_ids_buf_, cum_log_buf_, + args_.start_id_, args_.batch_size_, args_.beam_width_, decoding_params.stream); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + + /* + User can check the init by init_kernel_check. + init_kernel_check will compare the results of GPU and CPU. + Note that init_kernel_check contains init and uses do not need to call it again. 
+ */ + // init_kernel_check(finished_buf_, decoding_params.sequence_length, word_ids_buf_, cum_log_buf_, + // start_id_, batch_size_, beam_width_, decoding_params.stream); +#endif + + int cache_size = m * args_.seq_len_ * args_.hidden_units_; // type T + + for (int step = 1; step <= args_.seq_len_; ++step) + { + //we use two-way buffer + int kv_cache_id = step & 0x1; + + embedding_lookup_sine_position_encoding_kernel_launcher(from_tensor_[0], + decoding_params.embedding_table, + decoding_params.position_encoding_table + (step - 1) * args_.hidden_units_, + word_ids_buf_, + m, + args_.hidden_units_, + decoding_params.stream); + + int from_id, out_id; + for (int layer = 0; layer < args_.decoder_layers_; ++layer) + { + /* + For the first layer (layer-0), from_id is 0. We also stored the embedding lookup + result in from_tensor_[0] + */ + from_id = layer & 0x1; + out_id = 1 - from_id; + + /* + We use one decoder_ object to process multiple decoder layers. + + At the beginning of each decoder layer, we initialize the decoder object + with corresponding weights and decoder_buf_. + + The decoder_buf_ is reused. + */ + decoder_->initialize(param[layer], decoder_buf_); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + decoder_->forward(from_tensor_[from_id], decoding_params.memory_tensor, + K_cache_[kv_cache_id] + layer * cache_size, + V_cache_[kv_cache_id] + layer * cache_size, + K_mem_cache_[layer], V_mem_cache_[layer], + decoding_params.memory_sequence_length, from_tensor_[out_id], step); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + decoder_->decoder_norm1(from_tensor_[out_id], decoding_params.layernorm.gamma, + decoding_params.layernorm.beta, decoder_normed_result_buf_, m, k); + + float alpha = (float)1.0f; + float beta = (float)0.0f; + + check_cuda_error(cublasGemmEx(decoding_params.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + decoding_params.embedding_kernel, AType_, n, + decoder_normed_result_buf_, BType_, k, + &beta, + logits_buf_, CUDA_R_32F, n, + CUDA_R_32F, + static_cast(cublasAlgo_[0]))); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + // Beamsearch + if(is_fuse_topk_softMax_ == true) + { + topK_softMax(logits_buf_, + decoding_params.embedding_bias, + finished_buf_, + cum_log_buf_, + word_ids_buf_, + reinterpret_cast(temp_storage_), + args_, + decoding_params.stream); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + update_kernelLauncher_v2(finished_buf_, + decoding_params.parent_ids + (step - 1) * m, + decoding_params.sequence_length, + word_ids_buf_, + decoding_params.output_ids + (step - 1) * m, + finished_count_buf_, + args_, + decoding_params.stream); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + else + { + update_logits(logits_buf_, decoding_params.embedding_bias, args_.end_id_, finished_buf_, m, n, decoding_params.stream); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + + /* + User can check the update_logits by update_logits_kernel_check. + update_logits_kernel_check will compare the results of GPU and CPU. + Note that update_logits_kernel_check contains update_logits and uses do not need to call it again. 
+ */ + // update_logits_kernel_check(logits_buf_, decoding_params.embedding_bias, args_.end_id_, finished_buf_, m, n, decoding_params.stream); +#endif + + /* adding cum_log_buf_ to logits_buf_ */ + broadcast_kernelLauncher(logits_buf_, cum_log_buf_, args_.batch_size_, + args_.beam_width_, args_.vocab_size_, decoding_params.stream); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + + /* + User can check the broadcast_kernel by broadcast_kernel_check. + broadcast_kernel_check will compare the results of GPU and CPU. + Note that broadcast_kernel_check contains broadcast_kernelLauncher and uses do not need to call it again. + */ + // broadcast_kernel_check(logits_buf_, cum_log_buf_, batch_size_, beam_width_, vocab_size_, decoding_params.stream); +#endif + + topK_kernelLauncher(logits_buf_, + temp_log_probs_buf_, + topk_ids_buf_, + topk_val_buf_, + word_ids_buf_, + args_, + decoding_params.stream); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + update_kernelLauncher(logits_buf_, cum_log_buf_, topk_ids_buf_, + finished_buf_, + decoding_params.parent_ids + (step - 1) * m, + decoding_params.sequence_length, + word_ids_buf_, + decoding_params.output_ids + (step - 1) * m, + args_.batch_size_, args_.beam_width_, args_.vocab_size_, + decoding_params.stream, args_.end_id_, finished_count_buf_); + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + update_KV_cache_kernelLauncher(K_cache_, V_cache_, + decoding_params.parent_ids + (step - 1) * m, + args_.batch_size_, args_.beam_width_, args_.hidden_units_, step, + cache_size, args_.decoder_layers_, decoding_params.stream); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + + /* + User can check the update_KV_cache by update_KV_cache_kernel_check. + update_KV_cache_kernel_check will compare the results of GPU and CPU. + Note that update_KV_cache_kernel_check contains update_KV_cache and uses do not need to call it again. + */ + // update_KV_cache_kernel_check(K_cache_, V_cache_, decoding_params.parent_ids + (step - 1) * batch_size_ * beam_width_, batch_size_, beam_width_, hidden_units_, step, cache_size, decoder_layers_, decoding_params.stream); +#endif + + // TODO + // Find a better method to check the is_finished + cudaMemcpy(h_finished_buf_, finished_buf_, sizeof(bool) * m, cudaMemcpyDeviceToHost); + int sum = 0; + for(int i = 0; i < m; i++){ + sum += (int)h_finished_buf_[i]; + } + if(sum == m) break; + } // end for decoding step for llop + } // end of forward + + virtual ~DecodingBeamsearch() + { + delete [] K_cache_; + delete [] V_cache_; + delete [] K_mem_cache_; + delete [] V_mem_cache_; + delete [] h_finished_buf_; + delete decoder_; + allocator_.free(buf_); + } +}; + +} //namespace fastertransformer diff --git a/fastertransformer/decoding_opennmt.h b/fastertransformer/decoding_opennmt.h deleted file mode 100644 index a4079dcce..000000000 --- a/fastertransformer/decoding_opennmt.h +++ /dev/null @@ -1,387 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * Decoder transformer - **/ - -#pragma once - -#include "fastertransformer/common.h" -#include "fastertransformer/allocator.h" -#include "fastertransformer/open_decoder.h" -#include "fastertransformer/cuda/cuda_kernels.h" -#include "fastertransformer/beamsearch_opennmt.h" -#include - -namespace fastertransformer -{ - -template -class DecodingInitParam -{ -public: - /* weights for masked_multi_head_attention */ - const T *embedding_table; - const T *embedding_kernel; - const float *embedding_bias; - - const T *memory_tensor; - const int *memory_sequence_length; - - LayerNormWeight layernorm; - - int *output_ids; - int *parent_ids; - int *sequence_length; - cublasHandle_t cublas_handle; - cudaStream_t stream; -}; - -template -class DecodingOpenNMT -{ -private: - typedef DecoderTransformerTraits Traits_; - typedef typename Traits_::DataType DataType_; - const IAllocator &allocator_; - - const cudaDataType_t computeType_ = Traits_::computeType; - const cudaDataType_t AType_ = Traits_::AType; - const cudaDataType_t BType_ = Traits_::BType; - const cudaDataType_t CType_ = Traits_::CType; - int cublasAlgo_[1] = {20}; - - int batch_size_; - int beam_width_; - int seq_len_; - int head_num_; - int size_per_head_; - int hidden_units_; - int decoder_layers_; - int vocab_size_; - OpenDecoder *decoder_; - DataType_ **K_cache_; - DataType_ **V_cache_; - DataType_ **K_mem_cache_; - DataType_ **V_mem_cache_; - DataType_ *from_tensor_[2]; - DataType_ *decoder_buf_; - DataType_ *decoder_normed_result_buf_; - float *logits_buf_; - float *cum_log_buf_; - int *word_ids_buf_; - int *topk_ids_buf_; - bool *finished_buf_; - void *buf_; - int start_id_; - int end_id_; - int *finished_count_buf_; - bool *h_finished_buf_; - -public: - DecodingOpenNMT(const IAllocator &allocator, const int batch_size, - const int beam_width, const int seq_len, - const int head_num, const int size_per_head, - const int vocab_size, const int decoder_layers, - const int memory_hidden_units, const int memory_max_seq_len, - const int start_id, const int end_id) : allocator_(allocator), batch_size_(batch_size), beam_width_(beam_width), - seq_len_(seq_len), head_num_(head_num), size_per_head_(size_per_head), - vocab_size_(vocab_size), decoder_layers_(decoder_layers), - start_id_(start_id), end_id_(end_id) - { -#ifndef NDEBUG - PRINT_FUNC_NAME_(); -#endif - K_cache_ = new DataType_ *[2]; - V_cache_ = new DataType_ *[2]; - - K_mem_cache_ = new DataType_ *[decoder_layers_]; - V_mem_cache_ = new DataType_ *[decoder_layers_]; - - hidden_units_ = head_num_ * size_per_head_; - decoder_ = new OpenDecoder(allocator, batch_size * beam_width, memory_max_seq_len, - head_num, size_per_head, memory_hidden_units); - - int from_tensor_size = batch_size_ * beam_width_ * hidden_units_; // type T - int decoder_workspace_size = decoder_->getWorkspaceSize(); // type T - int decoder_normed_result_buffer_size = batch_size_ * beam_width_ * hidden_units_; // type T - int cache_size = batch_size_ * beam_width_ * seq_len_ * hidden_units_; // type T - - int logits_buf_size = batch_size_ * beam_width_ * vocab_size_; // type float - int 
cum_log_buf_size = batch_size_ * beam_width_; // type float - int word_ids_buf_size = batch_size_ * beam_width_; //type int - int finished_buf_size = batch_size_ * beam_width_; //type bool - int finished_count_size = (int)(ceil(1 / 4.)) * 4; // type int - - //type int - int topk_ids_buf_size = batch_size_ * beam_width_ * (ceil)((beam_width_ * vocab_size_ * 1.0) / 1024.0); - // prevent memory misalinged address - cum_log_buf_size = (int)(ceil(cum_log_buf_size / 4.)) * 4; - word_ids_buf_size = (int)(ceil(word_ids_buf_size / 4.)) * 4; - finished_buf_size = (int)(ceil(finished_buf_size / 32.)) * 32; - topk_ids_buf_size = (int)(ceil(topk_ids_buf_size / 4.)) * 4; - - - int datatype_buf_size = from_tensor_size * 2 + decoder_workspace_size + - cache_size * 6 * decoder_layers_ + decoder_normed_result_buffer_size; - - buf_ = reinterpret_cast(allocator_.malloc( - sizeof(DataType_) * datatype_buf_size + - sizeof(float) * (logits_buf_size + cum_log_buf_size) + - sizeof(int) * word_ids_buf_size + - sizeof(bool) * finished_buf_size + - sizeof(int) * topk_ids_buf_size + - sizeof(int) * finished_count_size )); - - from_tensor_[0] = (DataType_ *)buf_; - from_tensor_[1] = (DataType_ *)(from_tensor_[0] + from_tensor_size); - - for (int i = 0; i < decoder_layers_; ++i) - { - K_mem_cache_[i] = from_tensor_[1] + from_tensor_size + i * cache_size * 2; - V_mem_cache_[i] = from_tensor_[1] + from_tensor_size + i * cache_size * 2 + cache_size; - } - - /* We use two-way buffer since we have to update KV buf at the end of each step. */ - K_cache_[0] = V_mem_cache_[decoder_layers - 1] + cache_size + 0 * cache_size * decoder_layers_; - K_cache_[1] = V_mem_cache_[decoder_layers - 1] + cache_size + 1 * cache_size * decoder_layers_; - V_cache_[0] = V_mem_cache_[decoder_layers - 1] + cache_size + 2 * cache_size * decoder_layers_; - V_cache_[1] = V_mem_cache_[decoder_layers - 1] + cache_size + 3 * cache_size * decoder_layers_; - - decoder_buf_ = V_cache_[1] + cache_size * decoder_layers_; - decoder_normed_result_buf_ = (decoder_buf_ + decoder_workspace_size); - logits_buf_ = (float *)(decoder_normed_result_buf_ + decoder_normed_result_buffer_size); - cum_log_buf_ = (float *)(logits_buf_ + logits_buf_size); - word_ids_buf_ = (int *)(cum_log_buf_ + cum_log_buf_size); - finished_buf_ = (bool *)(word_ids_buf_ + word_ids_buf_size); - topk_ids_buf_ = (int *)(finished_buf_ + finished_buf_size); - finished_count_buf_ = (int *)(topk_ids_buf_ + topk_ids_buf_size); - - h_finished_buf_ = new bool[finished_buf_size]; - - FILE *fd = fopen("decoding_gemm_config.in", "r"); - int err = 0; - if (fd == NULL) - printf("[WARNING] decoding_gemm_config.in is not found\n"); - else - { - err = fscanf(fd, "%d%*d%*d", &cublasAlgo_[0]); - fclose(fd); - } - if (err != 1) - { - printf("[WARNING] decoding loading GEMM algorithms error, using default GEMM algorithms!\n"); - if (Traits_::OpType == OperationType::FP32) - { - cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT; - } - else - { - cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - } - } - else - { - // check that the gemm_config setting is runnable - if (Traits_::OpType == OperationType::FP32) - { - if (cublasAlgo_[0] > CUBLAS_GEMM_ALGO23 || cublasAlgo_[0] < CUBLAS_GEMM_DEFAULT) - { - // the algorithm is not for FP32 - printf("[ERROR] cuBLAS Algorithm %d is not used in FP32. 
\n", (int)cublasAlgo_[0]); - exit(-1); - } - } - else - { - if (cublasAlgo_[0] > CUBLAS_GEMM_ALGO15_TENSOR_OP || cublasAlgo_[0] < CUBLAS_GEMM_DEFAULT_TENSOR_OP) - { - // the algorithm is not for FP16 - printf("[ERROR] cuBLAS Algorithm %d is not used in FP16. \n", (int)cublasAlgo_[0]); - exit(-1); - } - } - } - } - - void forward(const DecoderInitParam *param, - DecodingInitParam decoding_params) - { - -#ifndef NDEBUG - PRINT_FUNC_NAME_(); -#endif - int m = batch_size_ * beam_width_; - int k = hidden_units_; - int n = vocab_size_; - - /* - sequence_length initialize to 0 - finished: false - word_ids: start_id_ - cum_log_probs (for eacm beam, the first element is 0). e.g., [0 -inf -inf -inf][0 -inf -inf -inf] - */ - -#ifdef NDEBUG - init(finished_buf_, decoding_params.sequence_length, word_ids_buf_, cum_log_buf_, - start_id_, batch_size_, beam_width_, decoding_params.stream); -#else - init(finished_buf_, decoding_params.sequence_length, word_ids_buf_, cum_log_buf_, - start_id_, batch_size_, beam_width_, decoding_params.stream); - - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); - - /* - User can check the init by init_kernel_check. - init_kernel_check will compare the results of GPU and CPU. - Note that init_kernel_check contains init and uses do not need to call it again. - */ - // init_kernel_check(finished_buf_, decoding_params.sequence_length, word_ids_buf_, cum_log_buf_, - // start_id_, batch_size_, beam_width_, decoding_params.stream); -#endif - - int cache_size = batch_size_ * beam_width_ * seq_len_ * hidden_units_; // type T - - for (int step = 1; step <= seq_len_; ++step) - { - //we use two-way buffer - int kv_cache_id = step & 0x1; - -#ifdef NDEBUG - embedding_lookup(decoding_params.embedding_table, word_ids_buf_, from_tensor_[0], - batch_size_, beam_width_, hidden_units_, decoding_params.stream); -#else - embedding_lookup(decoding_params.embedding_table, word_ids_buf_, from_tensor_[0], - batch_size_, beam_width_, hidden_units_, decoding_params.stream); - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); - - /* - User can check the embedding_lookup by embedding_lookup_kernel_check. - embedding_lookup_kernel_check will compare the results of GPU and CPU. - Note that embedding_lookup_kernel_check contains embedding_lookup and uses do not need to call it again. - */ - // embedding_lookup_kernel_check(decoding_params.embedding_table, word_ids_buf_, from_tensor_[0], - // batch_size_, beam_width_, hidden_units_, vocab_size_, decoding_params.stream); -#endif - - sine_position_encoder(from_tensor_[0], step, m, hidden_units_, decoding_params.stream); - - int from_id, out_id; - for (int layer = 0; layer < decoder_layers_; ++layer) - { - /* - For the first layer (layer-0), from_id is 0. We also stored the embedding lookup - result in from_tensor_[0] - */ - from_id = layer & 0x1; - out_id = 1 - from_id; - - /* - We use one decoder_ object to process multiple decoder layers. - - At the beginning of each decoder layer, we initialize the decoder object - with corresponding weights and decoder_buf_. - - The decoder_buf_ is reused. 
- */ - decoder_->initialize(param[layer], decoder_buf_); - -#ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -#endif - decoder_->forward(from_tensor_[from_id], decoding_params.memory_tensor, - K_cache_[kv_cache_id] + layer * cache_size, - V_cache_[kv_cache_id] + layer * cache_size, - K_mem_cache_[layer], V_mem_cache_[layer], - decoding_params.memory_sequence_length, from_tensor_[out_id], step); - -#ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -#endif - } - decoder_->decoder_norm1(from_tensor_[out_id], decoding_params.layernorm.gamma, - decoding_params.layernorm.beta, decoder_normed_result_buf_, m, k); - - float alpha = (float)1.0f; - float beta = (float)0.0f; - - check_cuda_error(cublasGemmEx(decoding_params.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - decoding_params.embedding_kernel, AType_, n, - decoder_normed_result_buf_, BType_, k, - &beta, - logits_buf_, CUDA_R_32F, n, - CUDA_R_32F, - static_cast(cublasAlgo_[0]))); - -#ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -#endif - -#ifdef NDEBUG - update_logits(logits_buf_, decoding_params.embedding_bias, end_id_, finished_buf_, m, n, decoding_params.stream); -#else - update_logits(logits_buf_, decoding_params.embedding_bias, end_id_, finished_buf_, m, n, decoding_params.stream); - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); - - /* - User can check the update_logits by update_logits_kernel_check. - update_logits_kernel_check will compare the results of GPU and CPU. - Note that update_logits_kernel_check contains update_logits and uses do not need to call it again. - */ - // update_logits_kernel_check(logits_buf_, decoding_params.embedding_bias, end_id_, finished_buf_, m, n, decoding_params.stream); -#endif - BeamSearch_OpenNMT( - logits_buf_, cum_log_buf_, finished_buf_, - K_cache_, - V_cache_, - decoding_params.parent_ids + (step - 1) * batch_size_ * beam_width_, - decoding_params.sequence_length, - word_ids_buf_, - topk_ids_buf_, - decoding_params.output_ids + (step - 1) * batch_size_ * beam_width_, - batch_size_, beam_width_, vocab_size_, hidden_units_, step, cache_size, decoder_layers_, decoding_params.stream, - end_id_, - finished_count_buf_); - -#ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -#endif - - // TODO - // Find a better method to check the is_finished - cudaMemcpy(h_finished_buf_, finished_buf_, sizeof(bool) * batch_size_ * beam_width_, cudaMemcpyDeviceToHost); - int sum = 0; - for(int i = 0; i < batch_size_ * beam_width_; i++){ - sum += (int)h_finished_buf_[i]; - } - if(sum == batch_size_ * beam_width_) break; - } - } - - virtual ~DecodingOpenNMT() {} -}; - -} //namespace fastertransformer diff --git a/fastertransformer/decoding_sampling.h b/fastertransformer/decoding_sampling.h new file mode 100644 index 000000000..cdb14c9a9 --- /dev/null +++ b/fastertransformer/decoding_sampling.h @@ -0,0 +1,422 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Decoder transformer + **/ + +#pragma once + +#include "fastertransformer/common.h" +#include "fastertransformer/allocator.h" +#include "fastertransformer/open_decoder.h" +#include "fastertransformer/cuda/cuda_kernels.h" +#include "fastertransformer/arguments.h" +#include + +namespace fastertransformer +{ + +template +class DecodingSampling +{ +private: + typedef DecoderTransformerTraits Traits_; + typedef typename Traits_::DataType DataType_; + const IAllocator &allocator_; + struct DecodingSamplingArguments args_; + + const cudaDataType_t computeType_ = Traits_::computeType; + const cudaDataType_t AType_ = Traits_::AType; + const cudaDataType_t BType_ = Traits_::BType; + const cudaDataType_t CType_ = Traits_::CType; + int cublasAlgo_[1] = {20}; + + OpenDecoder *decoder_; + DataType_ **K_cache_; + DataType_ **V_cache_; + DataType_ **K_mem_cache_; + DataType_ **V_mem_cache_; + DataType_ *from_tensor_[2]; + DataType_ *decoder_buf_; + DataType_ *decoder_normed_result_buf_; + float *logits_buf_; + float *cum_log_buf_; + int *word_ids_buf_; + bool *finished_buf_; + int *topk_ids_buf_; + float *topk_val_buf_; + void *buf_; + // int start_id_; + // int end_id_; + int *finished_count_buf_; + bool *h_finished_buf_; + + int *topp_id_vals_buf_; + float *topp_sorted_log_prob_buf_; + int *topp_sorted_id_vals_buf_; + int *topp_offset_buf_; + + void *temp_storage_; + // size_t temp_storage_size_; + + +public: + DecodingSampling(const IAllocator &allocator, const int batch_size, + const int seq_len, + const int head_num, const int size_per_head, + const int vocab_size, const int decoder_layers, + const int memory_hidden_units, const int memory_max_seq_len, + const int start_id, const int end_id, + const int candidate_num=0, + const float probability_threshold=0.0) : allocator_(allocator) + { + args_.batch_size_ = batch_size; + args_.seq_len_ = seq_len; + args_.head_num_ = head_num; + args_.size_per_head_ = size_per_head; + args_.hidden_units_ = head_num * size_per_head; + args_.decoder_layers_ = decoder_layers; + args_.vocab_size_ = vocab_size; + args_.candidate_num_ = candidate_num; + args_.probability_threshold_ = probability_threshold; + args_.start_id_ = start_id; + args_.end_id_ = end_id; + + if(args_.candidate_num_ == 0 && args_.probability_threshold_ == 0.0) + { + printf("[ERROR] Candidate_num for topk is 0 and probability threshold for top p is 0.0 \n"); + exit(-1); + } + else if(args_.candidate_num_ != 0 && args_.probability_threshold_ != 0.0) + { + printf("[ERROR] Candidate_num for topk is not 0 and probability threshold for top p is not 0.0 \n"); + exit(-1); + } +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + K_cache_ = new DataType_ *[1]; + V_cache_ = new DataType_ *[1]; + + K_mem_cache_ = new DataType_ *[args_.decoder_layers_]; + V_mem_cache_ = new DataType_ *[args_.decoder_layers_]; + + decoder_ = new OpenDecoder(batch_size, memory_max_seq_len, + head_num, size_per_head, memory_hidden_units); + + int from_tensor_size = args_.batch_size_ * args_.hidden_units_; // type T + int decoder_workspace_size = decoder_->getWorkspaceSize(); // type T + int decoder_normed_result_buffer_size = args_.batch_size_ * args_.hidden_units_; // type T + int cache_size = args_.batch_size_ * args_.seq_len_ * args_.hidden_units_; // type T + int mem_cache_size = args_.batch_size_ * memory_max_seq_len * args_.hidden_units_; // type T + + int logits_buf_size = args_.batch_size_ * args_.vocab_size_; 
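+
+        // Usage note (illustrative values only): exactly one sampling mode can be
+        // active, as enforced by the candidate_num_/probability_threshold_ checks
+        // above. For example, a hypothetical caller would construct
+        //     DecodingSampling<OperationType::FP32> decoding(allocator, batch_size,
+        //         seq_len, head_num, size_per_head, vocab_size, decoder_layers,
+        //         memory_hidden_units, memory_max_seq_len, start_id, end_id,
+        //         4 /* candidate_num: top-k with k=4 */, 0.0 /* top-p disabled */);
+        // while candidate_num = 0 together with e.g. probability_threshold = 0.75
+        // selects top-p (nucleus) sampling instead.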
// type float
+        int cum_log_buf_size = args_.batch_size_; // type float
+        int word_ids_buf_size = args_.batch_size_; //type int
+        int finished_buf_size = args_.batch_size_; //type bool
+        int finished_count_size = (int)(ceil(1 / 32.)) * 32; // type int
+
+        int topk_ids_buf_size = args_.batch_size_ * args_.candidate_num_; // type int
+        int topk_val_buf_size = args_.batch_size_ * args_.candidate_num_; // type float
+        int topp_id_vals_buf_size = args_.batch_size_ * args_.vocab_size_;
+        int topp_sorted_log_prob_buf_size = args_.batch_size_ * args_.vocab_size_;
+        int topp_sorted_id_vals_buf_size = args_.batch_size_ * args_.vocab_size_;
+
+        // prevent misaligned memory addresses
+        logits_buf_size = (int)(ceil(logits_buf_size / 4.)) * 4;
+        cum_log_buf_size = (int)(ceil(cum_log_buf_size / 4.)) * 4;
+        word_ids_buf_size = (int)(ceil(word_ids_buf_size / 4.)) * 4;
+        finished_buf_size = (int)(ceil(finished_buf_size / 32.)) * 32;
+        topk_ids_buf_size = (int)(ceil(topk_ids_buf_size / 4.)) * 4;
+        topk_val_buf_size = (int)(ceil(topk_val_buf_size / 4.)) * 4;
+        topp_id_vals_buf_size = (int)(ceil(topp_id_vals_buf_size / 4.)) * 4;
+        topp_sorted_log_prob_buf_size = (int)(ceil(topp_sorted_log_prob_buf_size / 4.)) * 4;
+        topp_sorted_id_vals_buf_size = (int)(ceil(topp_sorted_id_vals_buf_size / 4.)) * 4;
+
+        args_.temp_storage_size_ = get_topp_sort_temp_storage_size(logits_buf_,
+                                                                   topp_id_vals_buf_,
+                                                                   topp_sorted_log_prob_buf_,
+                                                                   topp_sorted_id_vals_buf_,
+                                                                   topp_offset_buf_,
+                                                                   args_.batch_size_,
+                                                                   args_.vocab_size_);
+
+        int topp_offset_buf_size = args_.batch_size_ + 1;
+        args_.temp_storage_size_ = (int)(ceil(args_.temp_storage_size_ / 4.)) * 4;
+        topp_offset_buf_size = (int)(ceil(topp_offset_buf_size / 4.)) * 4;
+
+        int datatype_buf_size = from_tensor_size * 2 + decoder_workspace_size +
+                                (cache_size * 4 + mem_cache_size * 2) * args_.decoder_layers_ + decoder_normed_result_buffer_size;
+
+        buf_ = reinterpret_cast(allocator_.malloc(
+            sizeof(DataType_) * datatype_buf_size +
+            sizeof(float) * (logits_buf_size + cum_log_buf_size) +
+            sizeof(int) * word_ids_buf_size +
+            sizeof(bool) * finished_buf_size +
+            sizeof(int) * finished_count_size +
+            sizeof(int) * topk_ids_buf_size +
+            sizeof(float) * topk_val_buf_size +
+            sizeof(int) * (topp_id_vals_buf_size + topp_sorted_id_vals_buf_size + topp_offset_buf_size) +
+            sizeof(float) * topp_sorted_log_prob_buf_size +
+            args_.temp_storage_size_ ));
+
+        from_tensor_[0] = (DataType_ *)buf_;
+        from_tensor_[1] = (DataType_ *)(from_tensor_[0] + from_tensor_size);
+
+        for (int i = 0; i < args_.decoder_layers_; ++i)
+        {
+            K_mem_cache_[i] = from_tensor_[1] + from_tensor_size + i * mem_cache_size * 2;
+            V_mem_cache_[i] = from_tensor_[1] + from_tensor_size + i * mem_cache_size * 2 + mem_cache_size;
+        }
+
+        /* Unlike beam search, sampling performs no cache reordering between steps,
+           so a single K/V cache buffer per layer (K_cache_[0] and V_cache_[0]) is enough.
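+
+           For reference, buf_ is carved up in exactly the order the pointers are
+           assigned below (sizes as computed above):
+
+             from_tensor_[0] | from_tensor_[1]
+             | K_mem_cache_[0], V_mem_cache_[0], ..., K_mem_cache_[L-1], V_mem_cache_[L-1]
+             | K_cache_[0] | V_cache_[0] | decoder_buf_ | decoder_normed_result_buf_
+             | logits | cum_log | word_ids | finished | top-k ids | top-k vals
+             | finished_count | top-p id vals | top-p sorted ids | top-p offsets
+             | top-p sorted log-probs | temp_storage_ (workspace for the top-p sort)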
*/ + K_cache_[0] = V_mem_cache_[args_.decoder_layers_ - 1] + mem_cache_size + 0 * cache_size * args_.decoder_layers_; + V_cache_[0] = V_mem_cache_[args_.decoder_layers_ - 1] + mem_cache_size + 1 * cache_size * args_.decoder_layers_; + + decoder_buf_ = V_cache_[0] + cache_size * args_.decoder_layers_; + decoder_normed_result_buf_ = (decoder_buf_ + decoder_workspace_size); + logits_buf_ = (float *)(decoder_normed_result_buf_ + decoder_normed_result_buffer_size); + cum_log_buf_ = (float *)(logits_buf_ + logits_buf_size); + word_ids_buf_ = (int *)(cum_log_buf_ + cum_log_buf_size); + finished_buf_ = (bool *)(word_ids_buf_ + word_ids_buf_size); + topk_ids_buf_ = (int *)(finished_buf_ + finished_buf_size); + topk_val_buf_ = (float*)(topk_ids_buf_ + topk_ids_buf_size); + finished_count_buf_ = (int *)(topk_val_buf_ + topk_val_buf_size); + topp_id_vals_buf_ = (int*)(finished_count_buf_ + finished_count_size); + topp_sorted_id_vals_buf_ = (int*)(topp_id_vals_buf_ + topp_id_vals_buf_size); + topp_offset_buf_ = (int*)(topp_sorted_id_vals_buf_ + topp_sorted_id_vals_buf_size); + topp_sorted_log_prob_buf_ = (float*)(topp_offset_buf_ + topp_offset_buf_size); + temp_storage_ = (void*)(topp_sorted_log_prob_buf_ + topp_sorted_log_prob_buf_size); + + h_finished_buf_ = new bool[finished_buf_size]; + + FILE *fd = fopen("decoding_gemm_config.in", "r"); + int err = 0; + if (fd == NULL) + printf("[WARNING] decoding_gemm_config.in is not found\n"); + else + { + err = fscanf(fd, "%d", &cublasAlgo_[0]); + fclose(fd); + } + if (err != 1) + { + printf("[WARNING] decoding loading GEMM algorithms error, using default GEMM algorithms!\n"); + if (Traits_::OpType == OperationType::FP32) + { + cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT; + } + else + { + cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + } + } + else + { + // check that the gemm_config setting is runnable + if (Traits_::OpType == OperationType::FP32) + { + if (cublasAlgo_[0] > CUBLAS_GEMM_ALGO23 || cublasAlgo_[0] < CUBLAS_GEMM_DEFAULT) + { + // the algorithm is not for FP32 + printf("[ERROR] cuBLAS Algorithm %d is not used in FP32. \n", (int)cublasAlgo_[0]); + exit(-1); + } + } + else + { + if (cublasAlgo_[0] > CUBLAS_GEMM_ALGO15_TENSOR_OP || cublasAlgo_[0] < CUBLAS_GEMM_DEFAULT_TENSOR_OP) + { + // the algorithm is not for FP16 + printf("[ERROR] cuBLAS Algorithm %d is not used in FP16. 
\n", (int)cublasAlgo_[0]); + exit(-1); + } + } + } + } + + void forward(const DecoderInitParam *param, + DecodingInitParam decoding_params) + { + +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + const int m = args_.batch_size_; + const int k = args_.hidden_units_; + const int n = args_.vocab_size_; + + /* + sequence_length initialize to 0 + finished: false + word_ids: start_id_ + cum_log_buf_: useless, keep it to reuse the kernel of decoding_opennmt.h + */ + + if(args_.candidate_num_ != 0) + { + init_kernelLauncher(finished_buf_, decoding_params.sequence_length, word_ids_buf_, cum_log_buf_, + args_.start_id_, args_.batch_size_, 1, decoding_params.stream); + } + else if(args_.probability_threshold_ != 0.0) + { + topp_initialization_kernelLauncher(finished_buf_, + decoding_params.sequence_length, + word_ids_buf_, + topp_id_vals_buf_, + topp_offset_buf_, + args_, + decoding_params.stream); + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + int cache_size = args_.batch_size_ * args_.seq_len_ * args_.hidden_units_; // type T + + for (int step = 1; step <= args_.seq_len_; ++step) + { + embedding_lookup_sine_position_encoding_kernel_launcher(from_tensor_[0], + decoding_params.embedding_table, + decoding_params.position_encoding_table + (step - 1) * args_.hidden_units_, + word_ids_buf_, + args_.batch_size_, + args_.hidden_units_, + decoding_params.stream); + + int from_id, out_id; + for (int layer = 0; layer < args_.decoder_layers_; ++layer) + { + /* + For the first layer (layer-0), from_id is 0. We also stored the embedding lookup + result in from_tensor_[0] + */ + from_id = layer & 0x1; + out_id = 1 - from_id; + + /* + We use one decoder_ object to process multiple decoder layers. + + At the beginning of each decoder layer, we initialize the decoder object + with corresponding weights and decoder_buf_. + + The decoder_buf_ is reused. 
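+
+               As a concrete example of the ping-pong scheme above
+               (from_id = layer & 0x1, out_id = 1 - from_id):
+
+                 layer 0 reads from_tensor_[0] and writes from_tensor_[1]
+                 layer 1 reads from_tensor_[1] and writes from_tensor_[0]
+                 layer 2 reads from_tensor_[0] and writes from_tensor_[1]
+                 ...
+
+               so two activation buffers cover any number of layers, while
+               decoder_buf_ only holds per-layer scratch and is simply overwritten
+               by the next initialize() call.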
+ */ + decoder_->initialize(param[layer], decoder_buf_); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + decoder_->forward(from_tensor_[from_id], decoding_params.memory_tensor, + K_cache_[0] + layer * cache_size, + V_cache_[0] + layer * cache_size, + K_mem_cache_[layer], V_mem_cache_[layer], + decoding_params.memory_sequence_length, from_tensor_[out_id], step); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + decoder_->decoder_norm1(from_tensor_[out_id], decoding_params.layernorm.gamma, + decoding_params.layernorm.beta, decoder_normed_result_buf_, m, k); + + float alpha = (float)1.0f; + float beta = (float)0.0f; + + check_cuda_error(cublasGemmEx(decoding_params.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + decoding_params.embedding_kernel, AType_, n, + decoder_normed_result_buf_, BType_, k, + &beta, + logits_buf_, CUDA_R_32F, n, + CUDA_R_32F, + static_cast(cublasAlgo_[0]))); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + if(args_.candidate_num_ != 0) + { + // top k sampling + update_logits_without_softmax(logits_buf_, decoding_params.embedding_bias, args_.end_id_, finished_buf_, m, n, decoding_params.stream); + topK_sampling_kernel_kernelLauncher(logits_buf_, + topk_ids_buf_, + topk_val_buf_, + decoding_params.output_ids + (step - 1) * args_.batch_size_, + decoding_params.sequence_length, + finished_buf_, + step, // used as random number + args_, + decoding_params.stream); + } + else if(args_.probability_threshold_ != 0.0) + { + // top p sampling + update_logits_without_log(logits_buf_, decoding_params.embedding_bias, args_.end_id_, finished_buf_, m, n, decoding_params.stream); + topP_sampling_kernel_kernelLauncher(logits_buf_, + topp_id_vals_buf_, + topp_sorted_log_prob_buf_, + topp_sorted_id_vals_buf_, + topp_offset_buf_, + temp_storage_, + finished_buf_, + step, + args_, + decoding_params.output_ids, + decoding_params.sequence_length, + decoding_params.stream); + } + + + word_ids_buf_ = decoding_params.output_ids + (step - 1) * args_.batch_size_; + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + // TODO + // Find a better method to check the is_finished + cudaMemcpy(h_finished_buf_, finished_buf_, sizeof(bool) * args_.batch_size_ , cudaMemcpyDeviceToHost); + int sum = 0; + for(int i = 0; i < args_.batch_size_ ; i++){ + sum += (int)h_finished_buf_[i]; + } + if(sum == args_.batch_size_ ) break; + } + } + + virtual ~DecodingSampling() + { + delete [] K_cache_; + delete [] V_cache_; + delete [] K_mem_cache_; + delete [] V_mem_cache_; + delete [] h_finished_buf_; + delete decoder_; + allocator_.free(buf_); + } +}; + +} //namespace fastertransformer diff --git a/fastertransformer/open_decoder.h b/fastertransformer/open_decoder.h index c9fc785df..5b5c7ce25 100644 --- a/fastertransformer/open_decoder.h +++ b/fastertransformer/open_decoder.h @@ -63,7 +63,6 @@ class OpenDecoder { private: typedef DecoderTransformerTraits Traits_; - const IAllocator &allocator_; typedef typename Traits_::DataType DataType_; DecoderInitParam param_; @@ -71,7 +70,7 @@ class OpenDecoder const cudaDataType_t AType_ = Traits_::AType; const cudaDataType_t BType_ = Traits_::BType; const cudaDataType_t CType_ = Traits_::CType; - int cublasAlgo_[4]; + int cublasAlgo_[5]; int batch_size_; int max_seq_len_; @@ -82,13 +81,19 @@ class OpenDecoder DataType_ *norm_from_tensor_buf_, *query_buf_, *context_buf_, 
*masked_output_buf_; DataType_ *norm_masked_output_buf_, *cross_output_buf_, *norm_cross_output_buf_, *ffn_inner_buf_; + DataType_ *key_buf_, *value_buf_; + + DataType_** qkv_kernel_; + DataType_** qkv_input_; + DataType_** qkv_buf_; + + bool is_fuse_QKV; public: - OpenDecoder(const IAllocator &allocator, - int batch_size, int seq_len, + OpenDecoder(int batch_size, int seq_len, int head_num, int size_per_head, int memory_hidden_units) : - allocator_(allocator), batch_size_(batch_size), + batch_size_(batch_size), max_seq_len_(seq_len), head_num_(head_num), size_per_head_(size_per_head), memory_hidden_units_(memory_hidden_units) @@ -111,34 +116,33 @@ class OpenDecoder { // first number is a setting for gemm in Decoding, which computes the embedding output. // so we need to skip the number - int tmp; - err = fscanf(fd, "%d%d%d%d%d%*d%*d", &tmp, &cublasAlgo_[0], &cublasAlgo_[1], &cublasAlgo_[2], &cublasAlgo_[3]); + float split_time, fused_time; + err = fscanf(fd, "%*d %*f %d %f %d %*f %d %*f %d %*f %d %f", &cublasAlgo_[0], &split_time, &cublasAlgo_[1], + &cublasAlgo_[2], &cublasAlgo_[3], &cublasAlgo_[4], &fused_time); + is_fuse_QKV = fused_time < split_time * 3 ? true : false; fclose(fd); } - if (err != 5) + if (err != 7) { printf("[WARNING] decoder loading GEMM algorithms error, using default GEMM algorithms!\n"); + int default_algo; if (Traits_::OpType == OperationType::FP32) { - cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT; - cublasAlgo_[1] = CUBLAS_GEMM_DEFAULT; - cublasAlgo_[2] = CUBLAS_GEMM_DEFAULT; - cublasAlgo_[3] = CUBLAS_GEMM_DEFAULT; + default_algo = CUBLAS_GEMM_DEFAULT; } else { - cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - cublasAlgo_[1] = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - cublasAlgo_[2] = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - cublasAlgo_[3] = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + default_algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; } + for(int i = 0; i < 5; i++) cublasAlgo_[i] = default_algo; + is_fuse_QKV = false; } else { // check that the gemm_config setting is runnable if (Traits_::OpType == OperationType::FP32) { - for (int i = 0; i < 4; i++) + for (int i = 0; i < 5; i++) { if (cublasAlgo_[i] > CUBLAS_GEMM_ALGO23 || cublasAlgo_[i] < CUBLAS_GEMM_DEFAULT) { @@ -150,7 +154,7 @@ class OpenDecoder } else { - for (int i = 0; i < 4; i++) + for (int i = 0; i < 5; i++) { if (cublasAlgo_[i] > CUBLAS_GEMM_ALGO15_TENSOR_OP || cublasAlgo_[i] < CUBLAS_GEMM_DEFAULT_TENSOR_OP) { @@ -166,7 +170,7 @@ class OpenDecoder int getWorkspaceSize() { int buf_size = batch_size_ * hidden_units_; - return 12 * buf_size; + return 13 * buf_size + sizeof(DataType_*) * 9; } void initialize(DecoderInitParam param, DataType_ *buf) @@ -178,14 +182,31 @@ class OpenDecoder int buf_size = batch_size_ * hidden_units_; norm_from_tensor_buf_ = buf; query_buf_ = buf + buf_size; //store the query values (from_tensor * Q) in both masked and multi-head attention - context_buf_ = buf + 2 * buf_size; //store the context result (softmax(qk)v) in both masked and multi-head attention + key_buf_ = buf + 2 * buf_size; + value_buf_ = buf + 3 * buf_size; + context_buf_ = buf + 4 * buf_size; //store the context result (softmax(qk)v) in both masked and multi-head attention + + masked_output_buf_ = buf + 5 * buf_size; //masked_attention_output + norm_masked_output_buf_ = buf + 6 * buf_size; //norm(masked_attention_output) + + cross_output_buf_ = buf + 7 * buf_size; //mutli-head attention_output + norm_cross_output_buf_ = buf + 8 * buf_size; //norm(multi-head attention_output) + ffn_inner_buf_ = buf + 9 * buf_size; //4 buf size to store inner product - 
masked_output_buf_ = buf + 3 * buf_size; //masked_attention_output - norm_masked_output_buf_ = buf + 4 * buf_size; //norm(masked_attention_output) + qkv_kernel_ = (DataType_**)(ffn_inner_buf_ + 4 * buf_size); + qkv_input_ = qkv_kernel_ + 3; + qkv_buf_ = qkv_input_ + 3; + + if(is_fuse_QKV == true) + { + const DataType_* hA[] {param_.self_attention.query_weight.kernel, + param_.self_attention.key_weight.kernel, + param_.self_attention.value_weight.kernel, + norm_from_tensor_buf_, norm_from_tensor_buf_, norm_from_tensor_buf_, + query_buf_, key_buf_, value_buf_}; + cudaMemcpyAsync((void*)qkv_kernel_, hA, sizeof(DataType_*) * 9, cudaMemcpyHostToDevice, param_.stream); + } - cross_output_buf_ = buf + 5 * buf_size; //mutli-head attention_output - norm_cross_output_buf_ = buf + 6 * buf_size; //norm(multi-head attention_output) - ffn_inner_buf_ = buf + 7 * buf_size; //4 buf size to store inner product } void forward(const DataType_ *from_tensor, const DataType_ *memory_tensor, @@ -196,8 +217,8 @@ class OpenDecoder #ifndef NDEBUG PRINT_FUNC_NAME_(); #endif - int m = batch_size_; - int n = hidden_units_; + const int m = batch_size_; + const int n = hidden_units_; try { @@ -299,6 +320,18 @@ class OpenDecoder ~OpenDecoder() { + norm_from_tensor_buf_ = nullptr; + query_buf_ = nullptr; + key_buf_ = nullptr; + value_buf_ = nullptr; + context_buf_ = nullptr; + + masked_output_buf_ = nullptr; + norm_masked_output_buf_ = nullptr; + + cross_output_buf_ = nullptr; + norm_cross_output_buf_ = nullptr; + ffn_inner_buf_ = nullptr; } }; } //namespace fastertransformer diff --git a/fastertransformer/tf_op/CMakeLists.txt b/fastertransformer/tf_op/CMakeLists.txt index 2227178dd..a9115ff08 100644 --- a/fastertransformer/tf_op/CMakeLists.txt +++ b/fastertransformer/tf_op/CMakeLists.txt @@ -15,33 +15,28 @@ cmake_minimum_required(VERSION 3.8) set(tf_bert_transformer_files bert_transformer_op.cc - bert_transformer_op.cu.cc - ../cuda/open_attention.cu - ../cuda/cuda_kernels.cu ) set(tf_decoder_files decoder_op.cc - decoder_op.cu.cc - ../cuda/open_decoder.cu - ../cuda/cuda_kernels.cu ) -set(tf_decoding_files - decoding_op.cc - decoding_op.cu.cc - ../cuda/open_decoder.cu - ../cuda/cuda_kernels.cu - ../cuda/open_attention.cu - ../cuda/decoding_kernel_check.cpp +set(tf_decoding_beamsearch_files + decoding_beamsearch_op.cc +) + +set(tf_decoding_sampling_files + decoding_sampling_op.cc ) add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) add_definitions(-DGOOGLE_CUDA=1) add_library(tf_fastertransformer SHARED ${tf_bert_transformer_files}) -target_link_libraries(tf_fastertransformer PRIVATE -lcublas -lcudart -ltensorflow_framework) +target_link_libraries(tf_fastertransformer PRIVATE -lcublas -lcudart -ltensorflow_framework encoder) add_library(tf_decoder SHARED ${tf_decoder_files}) -target_link_libraries(tf_decoder PRIVATE -lcublas -lcudart -ltensorflow_framework) -add_library(tf_decoding SHARED ${tf_decoding_files}) -target_link_libraries(tf_decoding PRIVATE -lcublas -lcudart -ltensorflow_framework) +target_link_libraries(tf_decoder PUBLIC -lcublas -lcudart -ltensorflow_framework decoder encoder) +add_library(tf_decoding_beamsearch SHARED ${tf_decoding_beamsearch_files}) +target_link_libraries(tf_decoding_beamsearch PRIVATE -lcublas -lcudart -ltensorflow_framework decoder decoding) +add_library(tf_decoding_sampling SHARED ${tf_decoding_sampling_files}) +target_link_libraries(tf_decoding_sampling PRIVATE -lcublas -lcudart -ltensorflow_framework decoder decoding) diff --git a/fastertransformer/tf_op/bert_transformer_op.cc 
b/fastertransformer/tf_op/bert_transformer_op.cc index 0e6370baf..0f2bd8b17 100644 --- a/fastertransformer/tf_op/bert_transformer_op.cc +++ b/fastertransformer/tf_op/bert_transformer_op.cc @@ -13,15 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#define EIGEN_USE_GPU + #include "fastertransformer/faster_transformer.h" -#include "fastertransformer/tf_op/bert_transformer_op.h" #include "fastertransformer/tf_op/common_op.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/platform/logging.h" -#include +#include "fastertransformer/cuda/cuda_kernels.h" + namespace tensorflow { namespace @@ -32,12 +30,12 @@ using GPUDevice = Eigen::GpuDevice; REGISTER_OP("BertTransformer") .Input("from_tensor: T") .Input("to_tensor: T") - .Input("attr_kernel_q: T") - .Input("attr_kernel_k: T") - .Input("attr_kernel_v: T") - .Input("attr_bias_q: T") - .Input("attr_bias_k: T") - .Input("attr_bias_v: T") + .Input("attr_q_kernel: T") + .Input("attr_q_bias: T") + .Input("attr_k_kernel: T") + .Input("attr_k_bias: T") + .Input("attr_v_kernel: T") + .Input("attr_v_bias: T") .Input("attr_mask: T") .Input("attr_output_kernel: T") .Input("attr_output_bias: T") @@ -49,53 +47,14 @@ REGISTER_OP("BertTransformer") .Input("output_bias: T") .Input("output_layernorm_beta: T") .Input("output_layernorm_gamma: T") + .Input("sequence_id_offset: int32") // shape: [valid_word_num] .Output("output: T") .Attr("T: {float, half}") - .Attr("from_seq_len: int >= 1") - .Attr("to_seq_len: int >= 1") .Attr("head_num: int >= 1") .Attr("size_per_head: int >= 1") + .Attr("remove_padding: bool = true") .SetShapeFn([](shape_inference::InferenceContext *c) { - int from_seq_len, to_seq_len, head_num, size_per_head; - c->GetAttr("from_seq_len", &from_seq_len); - c->GetAttr("to_seq_len", &to_seq_len); - c->GetAttr("head_num", &head_num); - c->GetAttr("size_per_head", &size_per_head); - int rank = c->Rank(c->input(0)); - if (rank != 2 && rank != 3) - { - return errors::InvalidArgument("[@BertTransformer::ShapeInference] " - "invalid rank (from_tensor@input[0]): ", - rank, - ", should be 2 or 3"); - } - // calculate batch size - shape_inference::DimensionOrConstant from_len_dim((int64)from_seq_len); - shape_inference::DimensionHandle output_dim1; - shape_inference::DimensionHandle batch_dim; - shape_inference::ShapeHandle input0; - - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input0)); - if (rank == 3) - { // embedding_output, [batch_size, seq_len, hidden_size] - batch_dim = c->Dim(c->input(0), 0); - } - else - { // should be 2, transformer's output, [batch_size*seq_len, hidden_size] - shape_inference::DimensionHandle tmp; - TF_RETURN_IF_ERROR(c->Divide(c->Dim(c->input(0), 0), from_len_dim, - true, &tmp)); - batch_dim = tmp; - } - - TF_RETURN_IF_ERROR(c->Multiply(batch_dim, from_len_dim, &output_dim1)); - - VLOG(2) << "[@BertTransformer::ShapeInference] batch_size: " - << c->Value(shape_inference::DimensionOrConstant(batch_dim)) - << ", output shape: [" << c->Value(shape_inference::DimensionOrConstant(output_dim1)) - << "," << head_num * size_per_head << "]\n"; - - c->set_output(0, c->MakeShape({output_dim1, head_num * size_per_head})); + c->set_output(0, c->input(0)); return Status::OK(); }); template @@ -104,40 +63,31 @@ class BertTransformerOp : public CommonOp public: explicit 
BertTransformerOp(OpKernelConstruction *context) : CommonOp(context) { - OP_REQUIRES_OK(context, context->GetAttr("from_seq_len", &from_seq_len_)); - OP_REQUIRES_OK(context, context->GetAttr("to_seq_len", &to_seq_len_)); OP_REQUIRES_OK(context, context->GetAttr("head_num", &head_num_)); OP_REQUIRES_OK(context, context->GetAttr("size_per_head", &size_per_head_)); - OP_REQUIRES(context, (from_seq_len_ == to_seq_len_), - errors::InvalidArgument("Only support from_seq_len == to_seq_len")); + OP_REQUIRES_OK(context, context->GetAttr("remove_padding", &remove_padding_)); } void Compute(OpKernelContext *context) override { int rank = (int)context->input(0).dims(); - if (rank != 2 && rank != 3) - { - OP_REQUIRES(context, false, - errors::InvalidArgument("[@BertTransformer::Compute] " - "invalid rank (from_tensor@input[0]): ", - rank, - ", should be 2 or 3")); - } - else if (rank == 3) - { // [batch_size, from_seq_len, hidden_size] - batch_size_ = (int)context->input(0).dim_size(0); - } - else - { // [batch_size * from_seq_len, hidden_size] - batch_size_ = (int)context->input(0).dim_size(0) / from_seq_len_; - } + OP_REQUIRES(context, rank==3 || rank == 2, + errors::InvalidArgument("Invalid rank. The rank of from tensor should be 3 or 2 \ + ([batch size, sequence length, hidden dimension] or [valid_word_num, hidden_dimension])")); + OP_REQUIRES(context, context->input(8).dims() == 4, + errors::InvalidArgument("Invalid rank. The rank of attention mask should be 4 " \ + "([batch_size, 1, seq_len, seq_len])")); - VLOG(2) << "[@BertTransformer::Compute] getting batch size: " - << batch_size_ << "\n"; + batch_size_ = (int)context->input(8).dim_size(0); + from_seq_len_ = (int)context->input(8).dim_size(2); + to_seq_len_ = (int)context->input(8).dim_size(3); + OP_REQUIRES(context, (from_seq_len_ == to_seq_len_), + errors::InvalidArgument("Only support from_seq_len == to_seq_len")); typedef BertEncoderTransformerTraits EncoderTraits_; BertEncoderTransformer *encoder_transformer_; - fastertransformer::Allocator allocator_(context); + const cudaStream_t &stream = context->eigen_device().stream(); + fastertransformer::Allocator allocator_(context, stream); try { encoder_transformer_ = new BertEncoderTransformer(allocator_, @@ -151,18 +101,19 @@ class BertTransformerOp : public CommonOp { OP_REQUIRES(context, false, errors::Internal(error.what())); } - - OP_REQUIRES(context, context->num_inputs() == 19, errors::InvalidArgument("Less input arguments")); - + OP_REQUIRES(context, context->num_inputs() == 20, errors::InvalidArgument("Less input arguments")); + const int hidden_units = head_num_ * size_per_head_; EncoderInitParam param; //init param here + param.stream = stream; param.cublas_handle = this->get_cublas_handler(); + check_cuda_error(cublasSetStream(param.cublas_handle, param.stream)); this->get_tensor(context, 0, ¶m.from_tensor); this->get_tensor(context, 1, ¶m.to_tensor); this->get_tensor(context, 2, ¶m.self_attention.query_weight.kernel); - this->get_tensor(context, 3, ¶m.self_attention.key_weight.kernel); - this->get_tensor(context, 4, ¶m.self_attention.value_weight.kernel); - this->get_tensor(context, 5, ¶m.self_attention.query_weight.bias); - this->get_tensor(context, 6, ¶m.self_attention.key_weight.bias); + this->get_tensor(context, 3, ¶m.self_attention.query_weight.bias); + this->get_tensor(context, 4, ¶m.self_attention.key_weight.kernel); + this->get_tensor(context, 5, ¶m.self_attention.key_weight.bias); + this->get_tensor(context, 6, ¶m.self_attention.value_weight.kernel); 
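+    // Note: inputs 2-7 are bound following the op's new interleaved declaration
+    // above (attr_q_kernel, attr_q_bias, attr_k_kernel, attr_k_bias, attr_v_kernel,
+    // attr_v_bias); graphs written against the previous grouped kernel/bias order
+    // have to reorder their feeds accordingly.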
this->get_tensor(context, 7, ¶m.self_attention.value_weight.bias); this->get_tensor(context, 8, ¶m.attr_mask); this->get_tensor(context, 9, ¶m.self_attention.attention_output_weight.kernel); @@ -176,24 +127,48 @@ class BertTransformerOp : public CommonOp this->get_tensor(context, 17, ¶m.ffn_layernorm.beta); this->get_tensor(context, 18, ¶m.ffn_layernorm.gamma); + int valid_word_num; + if(remove_padding_ == true) + { + valid_word_num = (int)context->input(19).dim_size(0); + param.sequence_id_offset = reinterpret_cast(context->input(19).flat().data()); + OP_REQUIRES(context, param.sequence_id_offset != nullptr, errors::InvalidArgument("sequence_id_offset is null")); + } + else + { + param.sequence_id_offset = nullptr; + valid_word_num = batch_size_ * from_seq_len_; + } + param.valid_word_num = valid_word_num; + Tensor *output = nullptr; OP_REQUIRES_OK( context, - context->allocate_output(0, {batch_size_ * from_seq_len_, head_num_ * size_per_head_}, &output)); - + context->allocate_output(0, context->input(0).shape(), &output)); param.transformer_out = reinterpret_cast(output->flat().data()); - OP_REQUIRES_OK( - context, - functor::BertTransformerOpFunctor::Compute( - context, - param, - encoder_transformer_)); + try + { + encoder_transformer_->initialize(param); + encoder_transformer_->forward(); + } + catch(std::runtime_error& error) + { + std::cout << errors::Internal(error.what()); + exit(-1); + } + catch(...) + { + std::cout << errors::Internal("Runtime error"); + exit(-1); + } + delete encoder_transformer_; } private: int batch_size_, from_seq_len_, to_seq_len_, head_num_, size_per_head_; + bool remove_padding_; typedef TFTraits traits_; typedef typename traits_::DataType DataType_; }; @@ -209,5 +184,226 @@ REGISTER_GPU(Eigen::half); #undef REGISTER_GPU #endif + +/* ******************************** Build mask and Remove padding *************************** */ + +REGISTER_OP("BuildMaskRemovePadding") + .Input("from_tensor: T") // shape: [batch_size, max_seq_len, hidden_dim] + .Input("sequence_length: int32") // shape: [batch_size] + .Output("output: T") // shpae: [valid_word_num, hidden_dim] + .Output("sequence_id_offset: int32") // shape: [valid_word_num] + .Attr("T: {float, half}") + .SetShapeFn([](shape_inference::InferenceContext *c) { + + assert(c->Rank(c->input(0)) == 3); + assert(c->Rank(c->input(1)) == 1); + c->set_output(0, c->MakeShape({shape_inference::InferenceContext::kUnknownDim, c->Dim(c->input(0), 2)})); + c->set_output(1, c->MakeShape({shape_inference::InferenceContext::kUnknownDim})); + + return Status::OK(); + }); + +template +class BuildMaskRemovePaddingOp : public CommonOp +{ +public: + explicit BuildMaskRemovePaddingOp(OpKernelConstruction *context) : CommonOp(context) + { + } + + void Compute(OpKernelContext *context) override + { + OP_REQUIRES(context, context->num_inputs() == 2, errors::InvalidArgument("Less input arguments")); + OP_REQUIRES(context, context->input(0).dims()==3, + errors::InvalidArgument("Invalid rank. The rank of from tensor should be 3 \ + ([batch size, sequence length, hidden dimension])")); + OP_REQUIRES(context, context->input(1).dims()==1, + errors::InvalidArgument("Invalid rank. 
The rank of sequence_id_offset should be 1 \ + ([batch_size])")); + const int batch_size = (int)context->input(0).dim_size(0); + const int max_seq_len = (int)context->input(0).dim_size(1); + const int hidden_dim = (int)context->input(0).dim_size(2); + + const DataType_* input_ptr = reinterpret_cast(context->input(0).flat().data()); + const int* sequence_length = reinterpret_cast(context->input(1).flat().data()); + OP_REQUIRES(context, input_ptr != nullptr, errors::InvalidArgument("input_ptr is null")); + OP_REQUIRES(context, sequence_length != nullptr, errors::InvalidArgument("sequence_length is null")); + + Tensor buf; + long long int buf_size = (long long int)(ceil((batch_size * max_seq_len + 1) * sizeof(int) / 4.) * 4); + tensorflow::Status status = context->allocate_temp(DT_UINT8, TensorShape{buf_size}, &buf); + if (status != tensorflow::Status::OK()) + throw std::runtime_error("TF error: context->allocate_temp failed"); + + int* tmp_sequence_id_offset = (int*)buf.flat().data(); + int* d_valid_word_num = tmp_sequence_id_offset + batch_size * max_seq_len; + + const cudaStream_t &stream = context->eigen_device().stream(); + + try + { + build_sequence_length_padding_offset_kernelLauncher(sequence_length, + batch_size, max_seq_len, d_valid_word_num, tmp_sequence_id_offset, stream); + } + catch(std::runtime_error& error) + { + std::cout << errors::Internal(error.what()); + exit(-1); + } + catch(...) + { + std::cout << errors::Internal("Runtime error"); + exit(-1); + } + + int* h_valid_word_num = new int[1]; + cudaMemcpyAsync(h_valid_word_num, d_valid_word_num, sizeof(int), cudaMemcpyDeviceToHost, stream); + const int valid_word_num = h_valid_word_num[0]; + delete h_valid_word_num; + + Tensor *output = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output(0, {valid_word_num, hidden_dim}, &output)); + DataType_* output_ptr = reinterpret_cast(output->flat().data()); + + Tensor *sequence_id_offset_buf = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output(1, {valid_word_num}, &sequence_id_offset_buf)); + int* sequence_id_offset = reinterpret_cast(sequence_id_offset_buf->flat().data()); + + try + { + remove_sequence_length_padding_kernelLauncher(input_ptr, output_ptr, + tmp_sequence_id_offset, + sequence_id_offset, + valid_word_num, hidden_dim, + stream); + } + catch(std::runtime_error& error) + { + std::cout << errors::Internal(error.what()); + exit(-1); + } + catch(...) 
+ { + std::cout << errors::Internal("Runtime error"); + exit(-1); + } + } +private: + typedef TFTraits traits_; + typedef typename traits_::DataType DataType_; +}; + +#ifdef GOOGLE_CUDA + +#define REGISTER_GPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("BuildMaskRemovePadding").Device(DEVICE_GPU).TypeConstraint("T"), \ + BuildMaskRemovePaddingOp) +REGISTER_GPU(float); +REGISTER_GPU(Eigen::half); +#undef REGISTER_GPU + +#endif + +/* ******************************** Rebuild padding *************************** */ + +REGISTER_OP("RebuildPadding") + .Input("from_tensor: T") // shape: [valid_word_num, hidden_dim] + .Input("sequence_id_offset: int32") // shape: [valid_word_num] + .Input("atten_mask: T") // shape: [batch_size, 1, seq_len, seq_len] + .Output("output: T") // shape: [batch_size, seq_len, hidden_dim] + .Attr("T: {float, half}") + .SetShapeFn([](shape_inference::InferenceContext *c) { + + assert(c->Rank(c->input(0)) == 2); + assert(c->Rank(c->input(1)) == 1); + assert(c->Rank(c->input(2)) == 4); + shape_inference::DimensionHandle batch_size = c->Dim(c->input(2), 0); + shape_inference::DimensionHandle seq_len = c->Dim(c->input(2), 2); + shape_inference::DimensionHandle hidden_dim = c->Dim(c->input(0), 1); + c->set_output(0, c->MakeShape({batch_size, seq_len, hidden_dim})); + return Status::OK(); + }); + +template +class RebuildPaddingOp : public CommonOp +{ +public: + explicit RebuildPaddingOp(OpKernelConstruction *context) : CommonOp(context) + { + } + + void Compute(OpKernelContext *context) override + { + OP_REQUIRES(context, context->num_inputs() == 3, errors::InvalidArgument("Less input arguments")); + + OP_REQUIRES(context, context->input(0).dims()==2, + errors::InvalidArgument("Invalid rank. The rank of from tensor should be 2 " \ + "([valid_word_num, hidden_dimension])")); + OP_REQUIRES(context, context->input(1).dims()==1, + errors::InvalidArgument("Invalid rank. The rank of sequence_id_offset should be 1 " \ + "([vaoid_word_num])")); + OP_REQUIRES(context, context->input(2).dims()==4, + errors::InvalidArgument("Invalid rank. The rank of attention mask should be 4 " \ + "([batch_size, 1, seq_len, seq_len])")); + + const int batch_size = (int)context->input(2).dim_size(0); + const int seq_len = (int)context->input(2).dim_size(2); + const int hidden_dim = (int)context->input(0).dim_size(1); + const int valid_word_num = (int)context->input(1).dim_size(0); + + const DataType_* input_ptr = reinterpret_cast(context->input(0).flat().data()); + const int* sequence_id_offset = reinterpret_cast(context->input(1).flat().data()); + OP_REQUIRES(context, input_ptr != nullptr, errors::InvalidArgument("input_ptr is null")); + OP_REQUIRES(context, sequence_id_offset != nullptr, errors::InvalidArgument("sequence_id_offset is null")); + + Tensor *output = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output(0, {batch_size, seq_len, hidden_dim}, &output)); + DataType_* output_ptr = reinterpret_cast(output->flat().data()); + + const cudaStream_t &stream = context->eigen_device().stream(); + cudaMemsetAsync(output_ptr, 0, sizeof(DataType_) * batch_size * seq_len * hidden_dim, stream); + try + { + rebuild_sequence_length_padding_kernelLauncher(input_ptr, output_ptr, + sequence_id_offset, + valid_word_num, hidden_dim, + stream); + } + catch(std::runtime_error& error) + { + std::cout << errors::Internal(error.what()); + exit(-1); + } + catch(...) 
+ { + std::cout << errors::Internal("Runtime error"); + exit(-1); + } + } +private: + typedef TFTraits traits_; + typedef typename traits_::DataType DataType_; +}; + +#ifdef GOOGLE_CUDA + +#define REGISTER_GPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("RebuildPadding").Device(DEVICE_GPU).TypeConstraint("T"), \ + RebuildPaddingOp) +REGISTER_GPU(float); +REGISTER_GPU(Eigen::half); +#undef REGISTER_GPU + +#endif + } //namespace } //namespace tensorflow + diff --git a/fastertransformer/tf_op/bert_transformer_op.cu.cc b/fastertransformer/tf_op/bert_transformer_op.cu.cc deleted file mode 100644 index 20ff89592..000000000 --- a/fastertransformer/tf_op/bert_transformer_op.cu.cc +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifdef GOOGLE_CUDA -#define EIGEN_USE_GPU -#include "fastertransformer/tf_op/bert_transformer_op.h" -#include "fastertransformer/common.h" -#include "fastertransformer/faster_transformer.h" -#include "tensorflow/core/framework/op.h" -#include -#include -namespace tensorflow -{ -using GPUDevice = Eigen::GpuDevice; -using namespace fastertransformer; - -namespace functor -{ -template -struct BertTransformerOpFunctor -{ - typedef typename TFTraits::DataType DataType_; - static Status Compute(OpKernelContext *context, - EncoderInitParam param, - BertEncoderTransformer::OpType, - cuda::OpenMultiHeadAttention > > *encoder_transformer) - { - const cudaStream_t &stream = context->eigen_device().stream(); - param.stream = stream; - try - { - check_cuda_error(cublasSetStream(param.cublas_handle, stream)); - encoder_transformer->initialize(param); - encoder_transformer->forward(); - return Status::OK(); - } - catch(std::runtime_error& error) - { - return errors::Internal(error.what()); - } - catch(...) - { - return errors::Internal("Runtime error"); - } - } -}; -} //namespace functor - -template struct functor::BertTransformerOpFunctor; -template struct functor::BertTransformerOpFunctor; -} //namespace tensorflow -#endif diff --git a/fastertransformer/tf_op/bert_transformer_op.h b/fastertransformer/tf_op/bert_transformer_op.h deleted file mode 100644 index 9570af0c9..000000000 --- a/fastertransformer/tf_op/bert_transformer_op.h +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#ifndef TENSORFLOW_CORE_KERNELS_MULTIHEADATTR_OP_H_ -#define TENSORFLOW_CORE_KERNELS_MULTIHEADATTR_OP_H_ - -#include "fastertransformer/common.h" -#include "fastertransformer/faster_transformer.h" -#include "fastertransformer/tf_op/tf_traits.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include -using namespace fastertransformer; -namespace tensorflow -{ - namespace functor - { - template - struct BertTransformerOpFunctor - { - typedef typename TFTraits::DataType DataType_; - static Status Compute(OpKernelContext *context, - EncoderInitParam param, - BertEncoderTransformer::OpType, - cuda::OpenMultiHeadAttention > > *encoder_transformer); - }; - } //namespace functor -} //namespace tensorflow -#endif diff --git a/fastertransformer/tf_op/common_op.h b/fastertransformer/tf_op/common_op.h index bbae1296b..48d71ed18 100644 --- a/fastertransformer/tf_op/common_op.h +++ b/fastertransformer/tf_op/common_op.h @@ -13,12 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "fastertransformer/open_decoder.h" -#include "fastertransformer/tf_op/decoder_op.h" + +#pragma once + +#ifndef TENSORFLOW_COMMON_OP_H +#define TENSORFLOW_COMMON_OP_H + +#include "fastertransformer/common.h" +#include "fastertransformer/tf_op/tf_traits.h" + #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + #include namespace tensorflow @@ -35,11 +47,11 @@ class CommonOp : public OpKernel explicit CommonOp(OpKernelConstruction *context) : OpKernel(context) { try { - check_cuda_error(cublasCreate(&cublas_handle_)); + check_cuda_error(cublasCreate(&cublas_handle_)); } catch(std::runtime_error& error) { - OP_REQUIRES(context, false, errors::Internal(error.what())); + OP_REQUIRES(context, false, errors::Internal(error.what())); } }; @@ -59,3 +71,5 @@ class CommonOp : public OpKernel } //namespace } //namespace tensorflow + +#endif diff --git a/fastertransformer/tf_op/decoder_op.cc b/fastertransformer/tf_op/decoder_op.cc index b0ff9af1e..3ac76e909 100644 --- a/fastertransformer/tf_op/decoder_op.cc +++ b/fastertransformer/tf_op/decoder_op.cc @@ -13,14 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + +#define EIGEN_USE_GPU + #include "fastertransformer/open_decoder.h" -#include "fastertransformer/tf_op/decoder_op.h" #include "fastertransformer/tf_op/common_op.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/register_types.h" -#include namespace tensorflow { @@ -30,38 +27,38 @@ using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; REGISTER_OP("Decoder") - .Input("from_tensor: T") - .Input("memory_tensor: T") - .Input("memory_sequence_length: int32") - .Input("self_beta: T") - .Input("self_gamma: T") - .Input("self_q_kernel: T") - .Input("self_q_bias: T") - .Input("self_k_kernel: T") - .Input("self_k_bias: T") - .Input("self_v_kernel: T") - .Input("self_v_bias: T") - .Input("self_output_kernel: T") - .Input("self_output_bias: T") - .Input("cross_beta: T") - .Input("cross_gamma: T") - .Input("cross_q_kernel: T") - .Input("cross_q_bias: T") - .Input("cross_k_kernel: T") - .Input("cross_k_bias: T") - .Input("cross_v_kernel: T") - .Input("cross_v_bias: T") - .Input("cross_output_kernel: T") - .Input("cross_output_bias: T") - .Input("ffn_beta: T") - .Input("ffn_gamma: T") - .Input("ffn_kernel1: T") - .Input("ffn_bias1: T") - .Input("ffn_kernel2: T") - .Input("ffn_bias2: T") - .Input("old_self_cache: T") - .Input("old_mem_cache: T") - .Input("pseudo_input: T") // pseudo input, used to prevent the parallel execution for OP and TF + .Input("from_tensor: T") // # 0 + .Input("memory_tensor: T") // # 1 + .Input("memory_sequence_length: int32") // # 2 + .Input("self_beta: T") // # 3 + .Input("self_gamma: T") // # 4 + .Input("self_q_kernel: T") // # 5 + .Input("self_q_bias: T") // # 6 + .Input("self_k_kernel: T") // # 7 + .Input("self_k_bias: T") // # 8 + .Input("self_v_kernel: T") // # 9 + .Input("self_v_bias: T") // # 10 + .Input("self_output_kernel: T") // # 11 + .Input("self_output_bias: T") // # 12 + .Input("cross_beta: T") // # 13 + .Input("cross_gamma: T") // # 14 + .Input("cross_q_kernel: T") // # 15 + .Input("cross_q_bias: T") // # 16 + .Input("cross_k_kernel: T") // # 17 + .Input("cross_k_bias: T") // # 18 + .Input("cross_v_kernel: T") // # 19 + .Input("cross_v_bias: T") // # 20 + .Input("cross_output_kernel: T") // # 21 + .Input("cross_output_bias: T") // # 22 + .Input("ffn_beta: T") // # 23 + .Input("ffn_gamma: T") // # 24 + .Input("ffn_kernel1: T") // # 25 + .Input("ffn_bias1: T") // # 26 + .Input("ffn_kernel2: T") // # 27 + .Input("ffn_bias2: T") // # 28 + .Input("old_self_cache: T") // # 29 + .Input("old_mem_cache: T") // # 30 + .Input("pseudo_input: T") // # 31, pseudo input, used to prevent the parallel execution for OP and TF .Output("decoder_output: T") .Output("new_self_cache: T") .Output("new_mem_cache: T") @@ -94,11 +91,12 @@ class DecoderOp : public CommonOp typedef DecoderTransformerTraits DecoderTraits_; OpenDecoder *decoder_; - fastertransformer::Allocator allocator_(context); + const cudaStream_t &stream = context->eigen_device().stream(); + fastertransformer::Allocator allocator_(context, stream); try { - decoder_ = new OpenDecoder(allocator_, batch_size_, - max_seq_len_, head_num_, size_per_head_, memory_hidden_dim_); + decoder_ = new OpenDecoder(batch_size_,max_seq_len_, + head_num_, size_per_head_, memory_hidden_dim_); } catch (std::runtime_error &error) { @@ -112,15 +110,13 @@ class DecoderOp : public CommonOp context->allocate_output(0, {batch_size_, 1, head_num_ * size_per_head_}, 
&decoder_output_tensor)); DataType_ *decoder_output = reinterpret_cast(decoder_output_tensor->flat().data()); - Tensor self_cache_tensor = context->mutable_input(29, true); + Tensor self_cache_tensor = context->input(29); context->set_output(1, self_cache_tensor); - DataType_ *self_cache; - self_cache = reinterpret_cast(self_cache_tensor.flat().data()); + DataType_ *self_cache = reinterpret_cast(self_cache_tensor.flat().data()); - Tensor memory_cache_tensor = context->mutable_input(30, true); + Tensor memory_cache_tensor = context->input(30); context->set_output(2, memory_cache_tensor); - DataType_ *memory_cache; - memory_cache = reinterpret_cast(memory_cache_tensor.flat().data()); + DataType_ *memory_cache = reinterpret_cast(memory_cache_tensor.flat().data()); const DataType_ *from_tensor = reinterpret_cast(context->input(0).flat().data()); const DataType_ *memory_tensor = reinterpret_cast(context->input(1).flat().data()); @@ -132,7 +128,10 @@ class DecoderOp : public CommonOp DecoderInitParam params; params.cublas_handle = this->get_cublas_handler(); + params.stream = stream; + check_cuda_error(cublasSetStream(params.cublas_handle, params.stream)); + const int hidden_units = head_num_ * size_per_head_; this->get_tensor(context, 3, ¶ms.self_layernorm.beta); this->get_tensor(context, 4, ¶ms.self_layernorm.gamma); @@ -165,23 +164,31 @@ class DecoderOp : public CommonOp const int step = (int)context->input(29).dim_size(1); DataType_ *K_cache = self_cache; - DataType_ *V_cache = self_cache + batch_size_ * step * head_num_ * size_per_head_; + DataType_ *V_cache = self_cache + batch_size_ * step * hidden_units; DataType_ *K_mem_cache = memory_cache; - DataType_ *V_mem_cache = memory_cache + batch_size_ * max_seq_len_ * head_num_ * size_per_head_; + DataType_ *V_mem_cache = memory_cache + batch_size_ * max_seq_len_ * hidden_units; const int decoder_buffer_size = decoder_->getWorkspaceSize() * sizeof(DataType_); DataType_ *decoder_buffer = (DataType_ *)allocator_.malloc(decoder_buffer_size); - OP_REQUIRES_OK( - context, - functor::DecoderOpFunctor::DynamicDecode( - context, - params, - decoder_, decoder_buffer, - from_tensor, memory_tensor, - K_cache, V_cache, - K_mem_cache, V_mem_cache, - memory_sequence_length, - decoder_output, step)); + try + { + decoder_->initialize(params, decoder_buffer); + decoder_->forward(from_tensor, memory_tensor, + K_cache, V_cache, + K_mem_cache, V_mem_cache, + memory_sequence_length, decoder_output, step); + + } + catch(std::runtime_error& error) + { + std::cout << errors::Internal(error.what()); + exit(-1); + } + catch(...) + { + std::cout << errors::Internal("Runtime error"); + exit(-1); + } allocator_.free(decoder_buffer); delete decoder_; @@ -191,6 +198,7 @@ class DecoderOp : public CommonOp int head_num_, size_per_head_; typedef TFTraits traits_; typedef typename traits_::DataType DataType_; + }; #ifdef GOOGLE_CUDA diff --git a/fastertransformer/tf_op/decoder_op.cu.cc b/fastertransformer/tf_op/decoder_op.cu.cc deleted file mode 100644 index 385f41aa9..000000000 --- a/fastertransformer/tf_op/decoder_op.cu.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 -* - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifdef GOOGLE_CUDA -#define EIGEN_USE_GPU -#include "fastertransformer/tf_op/decoder_op.h" -#include "fastertransformer/beamsearch_opennmt.h" -#include "fastertransformer/common.h" -#include "fastertransformer/open_decoder.h" -#include "tensorflow/core/framework/op.h" -#include -#include -namespace tensorflow -{ -using GPUDevice = Eigen::GpuDevice; -using namespace fastertransformer; - -namespace functor -{ -template -struct DecoderOpFunctor -{ - typedef typename TFTraits::DataType DataType_; - static Status DynamicDecode(OpKernelContext *context, - DecoderInitParam params, - OpenDecoder::OpType> *decoder, DataType_ *decoder_buffer, - const DataType_ *from_tensor, const DataType_ *memory_tensor, - DataType_ *key_cache, DataType_ *value_cache, - DataType_ *key_mem_cache, DataType_ *value_mem_cache, - const int* memory_sequence_length, - DataType_ *decoder_output, const int step) - { - const cudaStream_t &stream = context->eigen_device().stream(); - params.stream = stream; - try - { - check_cuda_error(cublasSetStream(params.cublas_handle, stream)); - decoder->initialize(params, decoder_buffer); - decoder->forward(from_tensor, memory_tensor, - key_cache, value_cache, - key_mem_cache, value_mem_cache, - memory_sequence_length, decoder_output, step); - - return Status::OK(); - } - catch(std::runtime_error& error) - { - return errors::Internal(error.what()); - } - catch(...) - { - return errors::Internal("Runtime error"); - } - } -}; -} //namespace functor - -template struct functor::DecoderOpFunctor; -template struct functor::DecoderOpFunctor; -} //namespace tensorflow -#endif diff --git a/fastertransformer/tf_op/decoder_op.h b/fastertransformer/tf_op/decoder_op.h deleted file mode 100644 index b8f12131c..000000000 --- a/fastertransformer/tf_op/decoder_op.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#ifndef TENSORFLOW_CORE_KERNELS_DECODER_OP_H_ -#define TENSORFLOW_CORE_KERNELS_DECODER_OP_H_ - -#include "fastertransformer/common.h" -#include "fastertransformer/open_decoder.h" -#include "fastertransformer/tf_op/tf_traits.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include - -using namespace fastertransformer; -namespace tensorflow -{ - namespace functor - { - template - struct DecoderOpFunctor - { - typedef typename TFTraits::DataType DataType_; - static Status DynamicDecode(OpKernelContext *context, - DecoderInitParam param, - OpenDecoder::OpType> *decoder, DataType_ *decoder_buffer, - const DataType_ *from_tensor, const DataType_ *memory_tensor, - DataType_ *key_cache, DataType_ *value_cache, - DataType_ *key_mem_cache, DataType_ *value_mem_cache, - const int* memory_sequence_length, - DataType_ *decoder_output, const int step); - }; - } //namespace functor -} //namespace tensorflow -#endif diff --git a/fastertransformer/tf_op/decoding_beamsearch_op.cc b/fastertransformer/tf_op/decoding_beamsearch_op.cc new file mode 100644 index 000000000..91e6de750 --- /dev/null +++ b/fastertransformer/tf_op/decoding_beamsearch_op.cc @@ -0,0 +1,273 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#define EIGEN_USE_GPU + +#include "fastertransformer/open_decoder.h" +#include "fastertransformer/decoding_beamsearch.h" +#include "fastertransformer/common.h" + +#include "fastertransformer/tf_op/common_op.h" +#include "fastertransformer/tf_op/tf_traits.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow +{ +namespace +{ +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +REGISTER_OP("Decoding") + .Input("memory_tensor: T") // 0 + .Input("memory_sequence_length: int32") // 1 + .Input("self_beta: T") // 2 + .Input("self_gamma: T") // 3 + .Input("self_q_kernel: T") // 4 + .Input("self_q_bias: T") // 5 + .Input("self_k_kernel: T") // 6 + .Input("self_k_bias: T") // 7 + .Input("self_v_kernel: T") // 8 + .Input("self_v_bias: T") // 9 + .Input("self_output_kernel: T") // 10 + .Input("self_output_bias: T") // 11 + .Input("cross_beta: T") // 12 + .Input("cross_gamma: T") // 13 + .Input("cross_q_kernel: T") // 14 + .Input("cross_q_bias: T") // 15 + .Input("cross_k_kernel: T") // 16 + .Input("cross_k_bias: T") // 17 + .Input("cross_v_kernel: T") // 18 + .Input("cross_v_bias: T") // 19 + .Input("cross_output_kernel: T") // 20 + .Input("cross_output_bias: T") // 21 + .Input("ffn_beta: T") // 22 + .Input("ffn_gamma: T") // 23 + .Input("ffn_kernel1: T") // 24 + .Input("ffn_bias1: T") // 25 + .Input("ffn_kernel2: T") // 26 + .Input("ffn_bias2: T") // 27 + .Input("decoding_beta: T") // 28 + .Input("decoding_gamma: T") // 29 + .Input("embedding_table: T") // 30 + .Input("embedding_kernel: T") // 31 + .Input("embedding_bias: float32") // 32 + .Input("position_encoding_table: T") // 33 + .Output("output_ids: int32") + .Output("parent_ids: int32") + .Output("sequence_lengths: int32") + .Attr("T: {float, half}") + .Attr("beam_width: int >= 1") + .Attr("max_seq_len: int >= 1") + .Attr("head_num: int >= 1") + .Attr("size_per_head: int >= 1") + .Attr("num_layer: int >= 1") + .Attr("start_id: int >= 0") + .Attr("end_id: int >= 0") + .Attr("beam_search_diversity_rate: float = 0.0") + .SetShapeFn([](shape_inference::InferenceContext *c) { + int beam_width, max_seq_len; + c->GetAttr("beam_width", &beam_width); + c->GetAttr("max_seq_len", &max_seq_len); + + int rank = c->Rank(c->input(0)); + assert(rank == 3); + + // calculate batch size + shape_inference::DimensionOrConstant max_seq_dim((int64)max_seq_len); + shape_inference::DimensionHandle output_dim; + shape_inference::DimensionHandle batchxbeam_dim; + + batchxbeam_dim = c->Dim(c->input(0), 0); + TF_RETURN_IF_ERROR(c->Multiply(batchxbeam_dim, max_seq_dim, &output_dim)); + + c->set_output(0, c->MakeShape({output_dim})); + c->set_output(1, c->MakeShape({output_dim})); + c->set_output(2, c->MakeShape({batchxbeam_dim})); + return Status::OK(); + }); +template +class DecodingOp : public CommonOp +{ +public: + explicit DecodingOp(OpKernelConstruction *context) : CommonOp(context) + { + OP_REQUIRES_OK(context, context->GetAttr("beam_width", &beam_width_)); + OP_REQUIRES_OK(context, context->GetAttr("max_seq_len", &max_seq_len_)); + OP_REQUIRES_OK(context, context->GetAttr("head_num", &head_num_)); + OP_REQUIRES_OK(context, 
context->GetAttr("size_per_head", &size_per_head_)); + OP_REQUIRES_OK(context, context->GetAttr("num_layer", &num_layer_)); + OP_REQUIRES_OK(context, context->GetAttr("start_id", &start_id_)); + OP_REQUIRES_OK(context, context->GetAttr("end_id", &end_id_)); + OP_REQUIRES_OK(context, context->GetAttr("beam_search_diversity_rate", &beam_search_diversity_rate_)); + } + + void Compute(OpKernelContext *context) override + { + assert((int)(context->input(0).dims()) == 3); + batch_size_ = (int)context->input(0).dim_size(0) / beam_width_; + const int memory_max_seq_len = (int)context->input(0).dim_size(1); + const int memory_hidden_dim = (int)context->input(0).dim_size(2); + const int vocab_size = (int)context->input(30).dim_size(0); + + DecodingInitParam decoding_params; + decoding_params.cublas_handle = this->get_cublas_handler(); + Tensor *output_ids = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output(0, {max_seq_len_, batch_size_ * beam_width_}, &output_ids)); + + Tensor *parent_ids = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output(1, {max_seq_len_, batch_size_ * beam_width_}, &parent_ids)); + + Tensor *sequence_length = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output(2, {batch_size_ * beam_width_}, &sequence_length)); + + decoding_params.output_ids = reinterpret_cast(output_ids->flat().data()); + decoding_params.parent_ids = reinterpret_cast(parent_ids->flat().data()); + decoding_params.sequence_length = reinterpret_cast(sequence_length->flat().data()); + + check_cuda_error(cudaMemset(decoding_params.output_ids, 0, sizeof(int) * max_seq_len_ * batch_size_ * beam_width_)); + check_cuda_error(cudaMemset(decoding_params.parent_ids, 0, sizeof(int) * max_seq_len_ * batch_size_ * beam_width_)); + check_cuda_error(cudaMemset(decoding_params.sequence_length, 0, sizeof(int) * batch_size_ * beam_width_)); + + typedef DecoderTransformerTraits DecodingTraits_; + DecodingBeamsearch *decoding_beamsearch_; + const cudaStream_t &stream = context->eigen_device().stream(); + decoding_params.stream = stream; + fastertransformer::Allocator allocator_(context, stream); + try + { + decoding_beamsearch_ = new DecodingBeamsearch( + allocator_, batch_size_, beam_width_, + max_seq_len_, head_num_, size_per_head_, + vocab_size, num_layer_, + memory_hidden_dim, memory_max_seq_len, + start_id_, end_id_, + beam_search_diversity_rate_); + } + catch (std::runtime_error &error) + { + OP_REQUIRES(context, false, errors::Internal(error.what())); + } + + OP_REQUIRES(context, context->num_inputs() == 34, errors::InvalidArgument("[ERROR] Less or more input arguments")); + + this->get_tensor(context, 0, &decoding_params.memory_tensor); + decoding_params.memory_sequence_length = reinterpret_cast(context->input(1).flat().data()); + OP_REQUIRES(context, decoding_params.memory_sequence_length != nullptr, errors::InvalidArgument("memory_sequence_length")); + + DecoderInitParam *params = new DecoderInitParam[num_layer_]; + const int hidden_unit = size_per_head_ * head_num_; + for (int i = 0; i < num_layer_; i++) + { + params[i].stream = stream; + params[i].cublas_handle = this->get_cublas_handler(); + check_cuda_error(cublasSetStream(params[i].cublas_handle, params[i].stream)); + + this->get_tensor(context, 2, ¶ms[i].self_layernorm.beta, i * hidden_unit); + this->get_tensor(context, 3, ¶ms[i].self_layernorm.gamma, i * hidden_unit); + + this->get_tensor(context, 4, ¶ms[i].self_attention.query_weight.kernel, i * hidden_unit * hidden_unit); + this->get_tensor(context, 5, 
¶ms[i].self_attention.query_weight.bias, i * hidden_unit); + this->get_tensor(context, 6, ¶ms[i].self_attention.key_weight.kernel, i * hidden_unit * hidden_unit); + this->get_tensor(context, 7, ¶ms[i].self_attention.key_weight.bias, i * hidden_unit); + this->get_tensor(context, 8, ¶ms[i].self_attention.value_weight.kernel, i * hidden_unit * hidden_unit); + this->get_tensor(context, 9, ¶ms[i].self_attention.value_weight.bias, i * hidden_unit); + + this->get_tensor(context, 10, ¶ms[i].self_attention.attention_output_weight.kernel, i * hidden_unit * hidden_unit); + this->get_tensor(context, 11, ¶ms[i].self_attention.attention_output_weight.bias, i * hidden_unit); + this->get_tensor(context, 12, ¶ms[i].cross_layernorm.beta, i * hidden_unit); + this->get_tensor(context, 13, ¶ms[i].cross_layernorm.gamma, i * hidden_unit); + this->get_tensor(context, 14, ¶ms[i].cross_attention.query_weight.kernel, i * hidden_unit * hidden_unit); + this->get_tensor(context, 15, ¶ms[i].cross_attention.query_weight.bias, i * hidden_unit); + this->get_tensor(context, 16, ¶ms[i].cross_attention.key_weight.kernel, i * memory_hidden_dim * hidden_unit); + this->get_tensor(context, 17, ¶ms[i].cross_attention.key_weight.bias, i * hidden_unit); + this->get_tensor(context, 18, ¶ms[i].cross_attention.value_weight.kernel, i * memory_hidden_dim * hidden_unit); + this->get_tensor(context, 19, ¶ms[i].cross_attention.value_weight.bias, i * hidden_unit); + this->get_tensor(context, 20, ¶ms[i].cross_attention.attention_output_weight.kernel, i * hidden_unit * hidden_unit); + this->get_tensor(context, 21, ¶ms[i].cross_attention.attention_output_weight.bias, i * hidden_unit); + this->get_tensor(context, 22, ¶ms[i].ffn_layernorm.beta, i * hidden_unit); + this->get_tensor(context, 23, ¶ms[i].ffn_layernorm.gamma, i * hidden_unit); + this->get_tensor(context, 24, ¶ms[i].ffn.intermediate_weight.kernel, i * hidden_unit * hidden_unit * 4); + this->get_tensor(context, 25, ¶ms[i].ffn.intermediate_weight.bias, i * hidden_unit * 4); + this->get_tensor(context, 26, ¶ms[i].ffn.output_weight.kernel, i * hidden_unit * hidden_unit * 4); + this->get_tensor(context, 27, ¶ms[i].ffn.output_weight.bias, i * hidden_unit); + } + + this->get_tensor(context, 28, &decoding_params.layernorm.beta); + this->get_tensor(context, 29, &decoding_params.layernorm.gamma); + this->get_tensor(context, 30, &decoding_params.embedding_table); + this->get_tensor(context, 31, &decoding_params.embedding_kernel); + + + decoding_params.embedding_bias = reinterpret_cast(context->input(32).flat().data()); + OP_REQUIRES(context, decoding_params.embedding_bias != nullptr, errors::InvalidArgument("embedding_bias")); + this->get_tensor(context, 33, &decoding_params.position_encoding_table); + + try + { + decoding_beamsearch_->forward(params, decoding_params); + } + catch(std::runtime_error& error) + { + std::cout << errors::Internal(error.what()); + exit(-1); + } + catch(...) 
+ { + std::cout << errors::Internal("Runtime error"); + exit(-1); + } + + delete decoding_beamsearch_; + delete [] params; + } + +private: + int batch_size_, beam_width_, max_seq_len_; + int head_num_, size_per_head_, num_layer_; + int start_id_, end_id_; + float beam_search_diversity_rate_; + typedef TFTraits traits_; + typedef typename traits_::DataType DataType_; +}; + +#ifdef GOOGLE_CUDA + +#define REGISTER_GPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("Decoding").Device(DEVICE_GPU).TypeConstraint("T"), \ + DecodingOp) +REGISTER_GPU(float); +REGISTER_GPU(Eigen::half); +#undef REGISTER_GPU + +#endif +} //namespace +} //namespace tensorflow diff --git a/fastertransformer/tf_op/decoding_op.cc b/fastertransformer/tf_op/decoding_op.cc deleted file mode 100644 index 1316ed72b..000000000 --- a/fastertransformer/tf_op/decoding_op.cc +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "fastertransformer/open_decoder.h" -#include "fastertransformer/tf_op/decoding_op.h" -#include "fastertransformer/tf_op/common_op.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/register_types.h" -#include -namespace tensorflow -{ -namespace -{ -using CPUDevice = Eigen::ThreadPoolDevice; -using GPUDevice = Eigen::GpuDevice; - -REGISTER_OP("Decoding") - .Input("memory_tensor: T") - .Input("memory_sequence_length: int32") - .Input("self_beta: T") - .Input("self_gamma: T") - .Input("self_q_kernel: T") - .Input("self_q_bias: T") - .Input("self_k_kernel: T") - .Input("self_k_bias: T") - .Input("self_v_kernel: T") - .Input("self_v_bias: T") - .Input("self_output_kernel: T") - .Input("self_output_bias: T") - .Input("cross_beta: T") - .Input("cross_gamma: T") - .Input("cross_q_kernel: T") - .Input("cross_q_bias: T") - .Input("cross_k_kernel: T") - .Input("cross_k_bias: T") - .Input("cross_v_kernel: T") - .Input("cross_v_bias: T") - .Input("cross_output_kernel: T") - .Input("cross_output_bias: T") - .Input("ffn_beta: T") - .Input("ffn_gamma: T") - .Input("ffn_kernel1: T") - .Input("ffn_bias1: T") - .Input("ffn_kernel2: T") - .Input("ffn_bias2: T") - .Input("embedding_table: T") - .Input("decoding_beta: T") - .Input("decoding_gamma: T") - .Input("embedding_kernel: T") - .Input("embedding_bias: float32") - .Output("output_ids: int32") - .Output("parent_ids: int32") - .Output("sequence_lengths: int32") - .Attr("T: {float, half}") - .Attr("batch_size: int >= 1") - .Attr("beam_width: int >= 1") - .Attr("max_seq_len: int >= 1") - .Attr("head_num: int >= 1") - .Attr("size_per_head: int >= 1") - .Attr("num_layer: int >= 1") - .Attr("memory_hidden_dim: int >= 1") - .Attr("vocab_size: int >= 1") - .Attr("start_id: int >= 0") - .Attr("end_id: int >= 0") - .SetShapeFn([](shape_inference::InferenceContext *c) { - int batch_size, beam_width, max_seq_len; - c->GetAttr("batch_size", 
&batch_size); - c->GetAttr("beam_width", &beam_width); - c->GetAttr("max_seq_len", &max_seq_len); - c->set_output(0, c->MakeShape({batch_size * beam_width * max_seq_len})); - c->set_output(1, c->MakeShape({batch_size * beam_width * max_seq_len})); - c->set_output(2, c->MakeShape({batch_size * beam_width})); - return Status::OK(); - }); -template -class DecodingOp : public CommonOp -{ -public: - explicit DecodingOp(OpKernelConstruction *context) : CommonOp(context) - { - OP_REQUIRES_OK(context, context->GetAttr("batch_size", &batch_size_)); - OP_REQUIRES_OK(context, context->GetAttr("beam_width", &beam_width_)); - OP_REQUIRES_OK(context, context->GetAttr("max_seq_len", &max_seq_len_)); - OP_REQUIRES_OK(context, context->GetAttr("head_num", &head_num_)); - OP_REQUIRES_OK(context, context->GetAttr("size_per_head", &size_per_head_)); - OP_REQUIRES_OK(context, context->GetAttr("num_layer", &num_layer_)); - OP_REQUIRES_OK(context, context->GetAttr("vocab_size", &vocab_size_)); - OP_REQUIRES_OK(context, context->GetAttr("start_id", &start_id_)); - OP_REQUIRES_OK(context, context->GetAttr("end_id", &end_id_)); - } - - void Compute(OpKernelContext *context) override - { - // input(0): memory_tensor: [batch_size * beam_width, memory_max_seq_len, memory_hidden_dim] - assert((int)(context->input(0).dims()) == 3); - const int memory_max_seq_len = (int)context->input(0).dim_size(1); - const int memory_hidden_dim_ = (int)context->input(0).dim_size(2); - - DecodingInitParam decoding_params; - decoding_params.cublas_handle = this->get_cublas_handler(); - Tensor *output_ids = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output(0, {max_seq_len_, batch_size_ * beam_width_}, &output_ids)); - - Tensor *parent_ids = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output(1, {max_seq_len_, batch_size_ * beam_width_}, &parent_ids)); - - Tensor *sequence_length = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output(2, {batch_size_ * beam_width_}, &sequence_length)); - - decoding_params.output_ids = reinterpret_cast(output_ids->flat().data()); - decoding_params.parent_ids = reinterpret_cast(parent_ids->flat().data()); - decoding_params.sequence_length = reinterpret_cast(sequence_length->flat().data()); - - check_cuda_error(cudaMemset(decoding_params.output_ids, 0, sizeof(int) * max_seq_len_ * batch_size_ * beam_width_)); - check_cuda_error(cudaMemset(decoding_params.parent_ids, 0, sizeof(int) * max_seq_len_ * batch_size_ * beam_width_)); - check_cuda_error(cudaMemset(decoding_params.sequence_length, 0, sizeof(int) * batch_size_ * beam_width_)); - - typedef DecoderTransformerTraits DecodingTraits_; - DecodingOpenNMT *decoding_opennmt_; - fastertransformer::Allocator allocator_(context); - try - { - decoding_opennmt_ = new DecodingOpenNMT( - allocator_, batch_size_, beam_width_, - max_seq_len_, head_num_, size_per_head_, - vocab_size_, num_layer_, - memory_hidden_dim_, memory_max_seq_len, - start_id_, end_id_); - } - catch (std::runtime_error &error) - { - OP_REQUIRES(context, false, errors::Internal(error.what())); - } - - OP_REQUIRES(context, context->num_inputs() == 33, errors::InvalidArgument("[ERROR] Less or more input arguments")); - - this->get_tensor(context, 0, &decoding_params.memory_tensor); - decoding_params.memory_sequence_length = reinterpret_cast(context->input(1).flat().data()); - OP_REQUIRES(context, decoding_params.memory_sequence_length != nullptr, errors::InvalidArgument("memory_sequence_length")); - - DecoderInitParam *params = new DecoderInitParam[num_layer_]; - 
const int hidden_unit = size_per_head_ * head_num_; - for (int i = 0; i < num_layer_; i++) - { - params[i].cublas_handle = this->get_cublas_handler(); - this->get_tensor(context, 2, ¶ms[i].self_layernorm.beta, i * hidden_unit); - this->get_tensor(context, 3, ¶ms[i].self_layernorm.gamma, i * hidden_unit); - this->get_tensor(context, 4, ¶ms[i].self_attention.query_weight.kernel, i * hidden_unit * hidden_unit); - this->get_tensor(context, 5, ¶ms[i].self_attention.query_weight.bias, i * hidden_unit); - this->get_tensor(context, 6, ¶ms[i].self_attention.key_weight.kernel, i * hidden_unit * hidden_unit); - this->get_tensor(context, 7, ¶ms[i].self_attention.key_weight.bias, i * hidden_unit); - this->get_tensor(context, 8, ¶ms[i].self_attention.value_weight.kernel, i * hidden_unit * hidden_unit); - this->get_tensor(context, 9, ¶ms[i].self_attention.value_weight.bias, i * hidden_unit); - this->get_tensor(context, 10, ¶ms[i].self_attention.attention_output_weight.kernel, i * hidden_unit * hidden_unit); - this->get_tensor(context, 11, ¶ms[i].self_attention.attention_output_weight.bias, i * hidden_unit); - this->get_tensor(context, 12, ¶ms[i].cross_layernorm.beta, i * hidden_unit); - this->get_tensor(context, 13, ¶ms[i].cross_layernorm.gamma, i * hidden_unit); - this->get_tensor(context, 14, ¶ms[i].cross_attention.query_weight.kernel, i * hidden_unit * hidden_unit); - this->get_tensor(context, 15, ¶ms[i].cross_attention.query_weight.bias, i * hidden_unit); - this->get_tensor(context, 16, ¶ms[i].cross_attention.key_weight.kernel, i * memory_hidden_dim_ * hidden_unit); - this->get_tensor(context, 17, ¶ms[i].cross_attention.key_weight.bias, i * hidden_unit); - this->get_tensor(context, 18, ¶ms[i].cross_attention.value_weight.kernel, i * memory_hidden_dim_ * hidden_unit); - this->get_tensor(context, 19, ¶ms[i].cross_attention.value_weight.bias, i * hidden_unit); - this->get_tensor(context, 20, ¶ms[i].cross_attention.attention_output_weight.kernel, i * hidden_unit * hidden_unit); - this->get_tensor(context, 21, ¶ms[i].cross_attention.attention_output_weight.bias, i * hidden_unit); - this->get_tensor(context, 22, ¶ms[i].ffn_layernorm.beta, i * hidden_unit); - this->get_tensor(context, 23, ¶ms[i].ffn_layernorm.gamma, i * hidden_unit); - this->get_tensor(context, 24, ¶ms[i].ffn.intermediate_weight.kernel, i * hidden_unit * hidden_unit * 4); - this->get_tensor(context, 25, ¶ms[i].ffn.intermediate_weight.bias, i * hidden_unit * 4); - this->get_tensor(context, 26, ¶ms[i].ffn.output_weight.kernel, i * hidden_unit * hidden_unit * 4); - this->get_tensor(context, 27, ¶ms[i].ffn.output_weight.bias, i * hidden_unit); - } - - this->get_tensor(context, 28, &decoding_params.layernorm.beta); - this->get_tensor(context, 29, &decoding_params.layernorm.gamma); - this->get_tensor(context, 30, &decoding_params.embedding_table); - this->get_tensor(context, 31, &decoding_params.embedding_kernel); - - decoding_params.embedding_bias = reinterpret_cast(context->input(32).flat().data()); - OP_REQUIRES(context, decoding_params.embedding_bias != nullptr, errors::InvalidArgument("memory_sequence_length")); - - OP_REQUIRES_OK( - context, - functor::DecodingOpFunctor::DynamicDecode( - context, - num_layer_, - params, - decoding_opennmt_, - max_seq_len_, - decoding_params)); - - delete decoding_opennmt_; - delete params; - } - -private: - int batch_size_, beam_width_, max_seq_len_; - int head_num_, size_per_head_, num_layer_; - int memory_hidden_dim_, vocab_size_, start_id_, end_id_; - typedef TFTraits traits_; - typedef typename 
traits_::DataType DataType_; -}; - -#ifdef GOOGLE_CUDA - -#define REGISTER_GPU(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("Decoding").Device(DEVICE_GPU).TypeConstraint("T"), \ - DecodingOp) -REGISTER_GPU(float); -REGISTER_GPU(Eigen::half); -#undef REGISTER_GPU - -#endif -} //namespace -} //namespace tensorflow diff --git a/fastertransformer/tf_op/decoding_op.cu.cc b/fastertransformer/tf_op/decoding_op.cu.cc deleted file mode 100644 index a1bd0571a..000000000 --- a/fastertransformer/tf_op/decoding_op.cu.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifdef GOOGLE_CUDA -#define EIGEN_USE_GPU -#include "fastertransformer/tf_op/decoding_op.h" -#include "fastertransformer/decoding_opennmt.h" -#include "fastertransformer/common.h" -#include "fastertransformer/open_decoder.h" -#include "tensorflow/core/framework/op.h" -#include -#include -namespace tensorflow -{ -using GPUDevice = Eigen::GpuDevice; -using namespace fastertransformer; - -namespace functor -{ -template -struct DecodingOpFunctor -{ - typedef typename TFTraits::DataType DataType_; - static Status DynamicDecode(OpKernelContext *context, - const int num_layers, - DecoderInitParam *params, - DecodingOpenNMT::OpType> *decoding_opennmt, - const int max_seq_len, - DecodingInitParam decoding_params) - { - const cudaStream_t &stream = context->eigen_device().stream(); - try - { - decoding_params.stream = stream; - for(int i = 0; i < num_layers; ++i) - { - params[i].stream = stream; - check_cuda_error(cublasSetStream(params[i].cublas_handle, stream)); - } - decoding_opennmt->forward(params, decoding_params); - - return Status::OK(); - } - catch(std::runtime_error& error) - { - return errors::Internal(error.what()); - } - catch(...) - { - return errors::Internal("Runtime error"); - } - } -}; -} //namespace functor - -template struct functor::DecodingOpFunctor; -template struct functor::DecodingOpFunctor; -} //namespace tensorflow -#endif diff --git a/fastertransformer/tf_op/decoding_op.h b/fastertransformer/tf_op/decoding_op.h deleted file mode 100644 index 78db420f8..000000000 --- a/fastertransformer/tf_op/decoding_op.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#ifndef TENSORFLOW_CORE_KERNELS_DECODING_OP_H_ -#define TENSORFLOW_CORE_KERNELS_DECODING_OP_H_ - -#include "fastertransformer/common.h" -#include "fastertransformer/open_decoder.h" -#include "fastertransformer/decoding_opennmt.h" -#include "fastertransformer/tf_op/tf_traits.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include - -using namespace fastertransformer; -namespace tensorflow -{ - namespace functor - { - template - struct DecodingOpFunctor - { - typedef typename TFTraits::DataType DataType_; - static Status DynamicDecode( - OpKernelContext *context, - const int num_layers, - DecoderInitParam *params, - DecodingOpenNMT::OpType> *decoding_opennmt, - const int max_seq_len, - DecodingInitParam decoding_params); - }; - } //namespace functor -} //namespace tensorflow -#endif diff --git a/fastertransformer/tf_op/decoding_sampling_op.cc b/fastertransformer/tf_op/decoding_sampling_op.cc new file mode 100644 index 000000000..d323799d3 --- /dev/null +++ b/fastertransformer/tf_op/decoding_sampling_op.cc @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#define EIGEN_USE_GPU + +#include "fastertransformer/open_decoder.h" +#include "fastertransformer/decoding_sampling.h" +#include "fastertransformer/tf_op/common_op.h" + +namespace tensorflow +{ +namespace +{ +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +REGISTER_OP("DecodingSampling") + .Input("memory_tensor: T") // 0 + .Input("memory_sequence_length: int32") // 1 + .Input("self_beta: T") // 2 + .Input("self_gamma: T") // 3 + .Input("self_q_kernel: T") // 4 + .Input("self_q_bias: T") // 5 + .Input("self_k_kernel: T") // 6 + .Input("self_k_bias: T") // 7 + .Input("self_v_kernel: T") // 8 + .Input("self_v_bias: T") // 9 + .Input("self_output_kernel: T") // 10 + .Input("self_output_bias: T") // 11 + .Input("cross_beta: T") // 12 + .Input("cross_gamma: T") // 13 + .Input("cross_q_kernel: T") // 14 + .Input("cross_q_bias: T") // 15 + .Input("cross_k_kernel: T") // 16 + .Input("cross_k_bias: T") // 17 + .Input("cross_v_kernel: T") // 18 + .Input("cross_v_bias: T") // 19 + .Input("cross_output_kernel: T") // 20 + .Input("cross_output_bias: T") // 21 + .Input("ffn_beta: T") // 22 + .Input("ffn_gamma: T") // 23 + .Input("ffn_kernel1: T") // 24 + .Input("ffn_bias1: T") // 25 + .Input("ffn_kernel2: T") // 26 + .Input("ffn_bias2: T") // 27 + .Input("decoding_beta: T") // 28 + .Input("decoding_gamma: T") // 29 + .Input("embedding_table: T") // 30 + .Input("embedding_kernel: T") // 31 + .Input("embedding_bias: float32") // 32 + .Input("position_encoding_table: T") // 33 + .Output("output_ids: int32") + .Output("sequence_lengths: int32") + .Attr("T: {float, half}") + .Attr("max_seq_len: int >= 1") + .Attr("candidate_num: int >= 0") + .Attr("probability_threshold: float = 0.0") + .Attr("head_num: int >= 1") + .Attr("size_per_head: int >= 1") + .Attr("num_layer: int >= 1") + .Attr("start_id: int >= 0") + .Attr("end_id: int >= 0") + .SetShapeFn([](shape_inference::InferenceContext *c) { + int max_seq_len; + c->GetAttr("max_seq_len", &max_seq_len); + + int rank = c->Rank(c->input(0)); + assert(rank == 3); + + shape_inference::DimensionOrConstant max_seq_dim((int64)max_seq_len); + shape_inference::DimensionHandle output_dim; + shape_inference::DimensionHandle batch_dim; + + batch_dim = c->Dim(c->input(0), 0); + TF_RETURN_IF_ERROR(c->Multiply(batch_dim, max_seq_dim, &output_dim)); + + c->set_output(0, c->MakeShape({output_dim})); + c->set_output(1, c->MakeShape({batch_dim})); + return Status::OK(); + + }); +template +class DecodingSamplingOp : public CommonOp +{ +public: + explicit DecodingSamplingOp(OpKernelConstruction *context) : CommonOp(context) + { + OP_REQUIRES_OK(context, context->GetAttr("max_seq_len", &max_seq_len_)); + OP_REQUIRES_OK(context, context->GetAttr("candidate_num", &candidate_num_)); + OP_REQUIRES_OK(context, context->GetAttr("probability_threshold", &probability_threshold_)); + OP_REQUIRES_OK(context, context->GetAttr("head_num", &head_num_)); + OP_REQUIRES_OK(context, context->GetAttr("size_per_head", &size_per_head_)); + OP_REQUIRES_OK(context, context->GetAttr("num_layer", &num_layer_)); + OP_REQUIRES_OK(context, context->GetAttr("start_id", &start_id_)); + OP_REQUIRES_OK(context, context->GetAttr("end_id", &end_id_)); + } + + void Compute(OpKernelContext *context) override + { + // input(0): memory_tensor: [batch_size * memory_max_seq_len, memory_hidden_dim] + assert((int)(context->input(0).dims()) == 3); + batch_size_ = (int)context->input(0).dim_size(0); + const int memory_max_seq_len = (int)context->input(0).dim_size(1); + const int 
memory_hidden_dim = (int)context->input(0).dim_size(2); + const int vocab_size = (int)context->input(30).dim_size(0); + + DecodingInitParam decoding_params; + decoding_params.cublas_handle = this->get_cublas_handler(); + Tensor *output_ids = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output(0, {max_seq_len_, batch_size_}, &output_ids)); + + Tensor *sequence_length = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output(1, {batch_size_}, &sequence_length)); + + decoding_params.output_ids = reinterpret_cast(output_ids->flat().data()); + decoding_params.sequence_length = reinterpret_cast(sequence_length->flat().data()); + + check_cuda_error(cudaMemset(decoding_params.output_ids, 0, sizeof(int) * max_seq_len_ * batch_size_)); + check_cuda_error(cudaMemset(decoding_params.sequence_length, 0, sizeof(int) * batch_size_)); + + typedef DecoderTransformerTraits DecodingTraits_; + DecodingSampling *decoding_sampling_; + const cudaStream_t &stream = context->eigen_device().stream(); + decoding_params.stream = stream; + fastertransformer::Allocator allocator_(context, stream); + try + { + decoding_sampling_ = new DecodingSampling( + allocator_, batch_size_, + max_seq_len_, head_num_, size_per_head_, + vocab_size, num_layer_, + memory_hidden_dim, memory_max_seq_len, + start_id_, end_id_, + candidate_num_, probability_threshold_); + } + catch (std::runtime_error &error) + { + OP_REQUIRES(context, false, errors::Internal(error.what())); + } + + OP_REQUIRES(context, context->num_inputs() == 34, errors::InvalidArgument("[ERROR] Less or more input arguments")); + + this->get_tensor(context, 0, &decoding_params.memory_tensor); + decoding_params.memory_sequence_length = reinterpret_cast(context->input(1).flat().data()); + OP_REQUIRES(context, decoding_params.memory_sequence_length != nullptr, errors::InvalidArgument("memory_sequence_length")); + + DecoderInitParam *params = new DecoderInitParam[num_layer_]; + const int hidden_unit = size_per_head_ * head_num_; + for (int i = 0; i < num_layer_; i++) + { + params[i].stream = stream; + params[i].cublas_handle = this->get_cublas_handler(); + check_cuda_error(cublasSetStream(params[i].cublas_handle, params[i].stream)); + + this->get_tensor(context, 2, ¶ms[i].self_layernorm.beta, i * hidden_unit); + this->get_tensor(context, 3, ¶ms[i].self_layernorm.gamma, i * hidden_unit); + + this->get_tensor(context, 4, ¶ms[i].self_attention.query_weight.kernel, i * hidden_unit * hidden_unit); + this->get_tensor(context, 5, ¶ms[i].self_attention.query_weight.bias, i * hidden_unit); + this->get_tensor(context, 6, ¶ms[i].self_attention.key_weight.kernel, i * hidden_unit * hidden_unit); + this->get_tensor(context, 7, ¶ms[i].self_attention.key_weight.bias, i * hidden_unit); + this->get_tensor(context, 8, ¶ms[i].self_attention.value_weight.kernel, i * hidden_unit * hidden_unit); + this->get_tensor(context, 9, ¶ms[i].self_attention.value_weight.bias, i * hidden_unit); + + this->get_tensor(context, 10, ¶ms[i].self_attention.attention_output_weight.kernel, i * hidden_unit * hidden_unit); + this->get_tensor(context, 11, ¶ms[i].self_attention.attention_output_weight.bias, i * hidden_unit); + this->get_tensor(context, 12, ¶ms[i].cross_layernorm.beta, i * hidden_unit); + this->get_tensor(context, 13, ¶ms[i].cross_layernorm.gamma, i * hidden_unit); + this->get_tensor(context, 14, ¶ms[i].cross_attention.query_weight.kernel, i * hidden_unit * hidden_unit); + this->get_tensor(context, 15, ¶ms[i].cross_attention.query_weight.bias, i * hidden_unit); + 
this->get_tensor(context, 16, ¶ms[i].cross_attention.key_weight.kernel, i * memory_hidden_dim * hidden_unit); + this->get_tensor(context, 17, ¶ms[i].cross_attention.key_weight.bias, i * hidden_unit); + this->get_tensor(context, 18, ¶ms[i].cross_attention.value_weight.kernel, i * memory_hidden_dim * hidden_unit); + this->get_tensor(context, 19, ¶ms[i].cross_attention.value_weight.bias, i * hidden_unit); + this->get_tensor(context, 20, ¶ms[i].cross_attention.attention_output_weight.kernel, i * hidden_unit * hidden_unit); + this->get_tensor(context, 21, ¶ms[i].cross_attention.attention_output_weight.bias, i * hidden_unit); + this->get_tensor(context, 22, ¶ms[i].ffn_layernorm.beta, i * hidden_unit); + this->get_tensor(context, 23, ¶ms[i].ffn_layernorm.gamma, i * hidden_unit); + this->get_tensor(context, 24, ¶ms[i].ffn.intermediate_weight.kernel, i * hidden_unit * hidden_unit * 4); + this->get_tensor(context, 25, ¶ms[i].ffn.intermediate_weight.bias, i * hidden_unit * 4); + this->get_tensor(context, 26, ¶ms[i].ffn.output_weight.kernel, i * hidden_unit * hidden_unit * 4); + this->get_tensor(context, 27, ¶ms[i].ffn.output_weight.bias, i * hidden_unit); + } + + this->get_tensor(context, 28, &decoding_params.layernorm.beta); + this->get_tensor(context, 29, &decoding_params.layernorm.gamma); + this->get_tensor(context, 30, &decoding_params.embedding_table); + this->get_tensor(context, 31, &decoding_params.embedding_kernel); + + decoding_params.embedding_bias = reinterpret_cast(context->input(32).flat().data()); + OP_REQUIRES(context, decoding_params.embedding_bias != nullptr, errors::InvalidArgument("embedding_bias")); + this->get_tensor(context, 33, &decoding_params.position_encoding_table); + + try + { + decoding_sampling_->forward(params, decoding_params); + } + catch(std::runtime_error& error) + { + std::cout << errors::Internal(error.what()); + exit(-1); + } + catch(...) + { + std::cout << errors::Internal("Runtime error"); + exit(-1); + } + + delete decoding_sampling_; + delete params; + } + +private: + int batch_size_, max_seq_len_, candidate_num_; + int head_num_, size_per_head_, num_layer_; + int start_id_, end_id_; + float probability_threshold_; + typedef TFTraits traits_; + typedef typename traits_::DataType DataType_; +}; + +#ifdef GOOGLE_CUDA + +#define REGISTER_GPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("DecodingSampling").Device(DEVICE_GPU).TypeConstraint("T"), \ + DecodingSamplingOp) +REGISTER_GPU(float); +REGISTER_GPU(Eigen::half); +#undef REGISTER_GPU + +#endif +} //namespace +} //namespace tensorflow diff --git a/fastertransformer/th_op/CMakeLists.txt b/fastertransformer/th_op/CMakeLists.txt new file mode 100644 index 000000000..bd1d4425f --- /dev/null +++ b/fastertransformer/th_op/CMakeLists.txt @@ -0,0 +1,88 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
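+#
+# Usage sketch for the optional PyTorch targets defined below (illustrative
+# only; the .so locations are an assumption and depend on the build directory
+# and LIBRARY_OUTPUT_DIRECTORY):
+#
+#   import torch
+#   import th_fastertransformer                               # BUILD_THE target, if on sys.path
+#   torch.classes.load_library("ths_fastertransformer.so")    # BUILD_THS target
+#   torch.ops.load_library("ths_fastertransformer_op.so")     # BUILD_THSOP target
+#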
+cmake_minimum_required(VERSION 3.13) + +set(th_fastertransformer_ext_files + ft_ext.cc + encoder_ext.cc + decoder_ext.cc + decoding_ext.cc + utils.cu +) + +set(th_fastertransformer_ths_files + ft_ths_op.cc + encoder_ths_op.cc + decoder_ths_op.cc + decoding_ths_op.cc + utils.cu +) + +set(th_fastertransformer_ths_f_files + ft_ths_op_f.cc + encoder_ths_op_f.cc +) + +add_definitions(-DTORCH_CUDA=1) + +# hack for bugs in torch +if(TARGET torch_cpu) + set_target_properties(torch_cpu PROPERTIES + INTERFACE_COMPILE_OPTIONS "") +endif() +if(TARGET torch_cuda) + set_target_properties(torch_cuda PROPERTIES + INTERFACE_COMPILE_OPTIONS "") + set(NEW_TORCH_CUDA_LINK_VAR) + get_target_property(OLD_TORCH_CUDA_LINK_VAR torch_cuda INTERFACE_LINK_LIBRARIES) + foreach (TMPVAR ${OLD_TORCH_CUDA_LINK_VAR}) + string(REPLACE "/usr/local/cuda" "${CUDA_TOOLKIT_ROOT_DIR}" TMPVAR ${TMPVAR}) + list(APPEND NEW_TORCH_CUDA_LINK_VAR ${TMPVAR}) + endforeach(TMPVAR) + set_target_properties(torch_cuda PROPERTIES + INTERFACE_LINK_LIBRARIES "${NEW_TORCH_CUDA_LINK_VAR}") +endif() +if(TARGET torch) + set(NEW_TORCH_LINK_VAR) + get_target_property(OLD_TORCH_LINK_VAR torch INTERFACE_LINK_LIBRARIES) + foreach (TMPVAR ${OLD_TORCH_LINK_VAR}) + string(REPLACE "/usr/local/cuda" "${CUDA_TOOLKIT_ROOT_DIR}" TMPVAR ${TMPVAR}) + list(APPEND NEW_TORCH_LINK_VAR ${TMPVAR}) + endforeach(TMPVAR) + set_target_properties(torch PROPERTIES + INTERFACE_LINK_LIBRARIES "${NEW_TORCH_LINK_VAR}") +endif() + + +if(BUILD_THE) + set(LIB_NAME_1 "th_fastertransformer") + add_library(${LIB_NAME_1} SHARED ${th_fastertransformer_ext_files}) + set_target_properties(${LIB_NAME_1} PROPERTIES + PREFIX "" + SUFFIX ${PY_SUFFIX} + LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) + target_link_libraries(${LIB_NAME_1} "${TORCH_LIBRARIES}" "${TORCH_LINK}" encoder decoder decoding) +endif() + +if(BUILD_THS) + set(LIB_NAME_2 "ths_fastertransformer") + add_library(${LIB_NAME_2} SHARED ${th_fastertransformer_ths_files}) + target_link_libraries(${LIB_NAME_2} "${TORCH_LIBRARIES}" encoder decoder decoding) +endif() + +if(BUILD_THSOP) + set(LIB_NAME_3 "ths_fastertransformer_op") + add_library(${LIB_NAME_3} SHARED ${th_fastertransformer_ths_f_files}) + target_link_libraries(${LIB_NAME_3} "${TORCH_LIBRARIES}" encoder) +endif() diff --git a/fastertransformer/th_op/decoder_ext.cc b/fastertransformer/th_op/decoder_ext.cc new file mode 100644 index 000000000..27fb8f13a --- /dev/null +++ b/fastertransformer/th_op/decoder_ext.cc @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "fastertransformer/th_op/decoder_ext.h" + +namespace torch_ext { +using torch::Tensor; + +FasterTransformerDecoder::FasterTransformerDecoder( + int head_num, + int head_size, + Tensor self_layernorm_gamma, + Tensor self_layernorm_beta, + Tensor self_kernel_q, + Tensor self_kernel_k, + Tensor self_kernel_v, + Tensor self_bias_q, + Tensor self_bias_k, + Tensor self_bias_v, + Tensor self_output_kernel, + Tensor self_output_bias, + Tensor cross_layernorm_gamma, + Tensor cross_layernorm_beta, + Tensor cross_kernel_q, + Tensor cross_kernel_k, + Tensor cross_kernel_v, + Tensor cross_bias_q, + Tensor cross_bias_k, + Tensor cross_bias_v, + Tensor cross_output_kernel, + Tensor cross_output_bias, + Tensor ffn_layernorm_gamma, + Tensor ffn_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias) +: _st(self_layernorm_gamma.scalar_type()) +{ + CHECK_INPUT(self_layernorm_gamma, _st); // hidden_dim + CHECK_INPUT(self_layernorm_beta, _st); // hidden_dim + CHECK_INPUT(self_kernel_q, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_kernel_k, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_kernel_v, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_bias_q, _st); // hidden_dim + CHECK_INPUT(self_bias_k, _st); // hidden_dim + CHECK_INPUT(self_bias_v, _st); // hidden_dim + CHECK_INPUT(self_output_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_output_bias, _st); // hidden_dim + CHECK_INPUT(cross_layernorm_gamma, _st); // hidden_dim + CHECK_INPUT(cross_layernorm_beta, _st); // hidden_dim + CHECK_INPUT(cross_kernel_q, _st); // hidden_dim, hidden_dim + CHECK_INPUT(cross_kernel_k, _st); // mem_hidden_dim, hidden_dim + CHECK_INPUT(cross_kernel_v, _st); // mem_hidden_dim, hidden_dim + CHECK_INPUT(cross_bias_q, _st); // hidden_dim + CHECK_INPUT(cross_bias_k, _st); // hidden_dim + CHECK_INPUT(cross_bias_v, _st); // hidden_dim + CHECK_INPUT(cross_output_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(cross_output_bias, _st); // hidden_dim + CHECK_INPUT(ffn_layernorm_gamma, _st); // hidden_dim + CHECK_INPUT(ffn_layernorm_beta, _st); // hidden_dim + CHECK_INPUT(inter_kernel, _st); // hidden_dim, 4 * hidden_dim + CHECK_INPUT(inter_bias, _st); // 4 * hidden_dim + CHECK_INPUT(output_kernel, _st); // 4 * hidden_dim, hidden_dim + CHECK_INPUT(output_bias, _st); // hidden_dim + std::vector weights{self_layernorm_gamma, self_layernorm_beta, + self_kernel_q, self_kernel_k, self_kernel_v, self_bias_q, self_bias_k, self_bias_v, + self_output_kernel, self_output_bias, + cross_layernorm_gamma, cross_layernorm_beta, + cross_kernel_q, cross_kernel_k, cross_kernel_v, cross_bias_q, cross_bias_k, cross_bias_v, + cross_output_kernel, cross_output_bias, + ffn_layernorm_gamma, ffn_layernorm_beta, inter_kernel, inter_bias, output_kernel, output_bias}; + switch (_st) { + case at::ScalarType::Float: + ftdecoder = new FTDecoder(head_num, head_size, weights); + break; + case at::ScalarType::Half: + ftdecoder = new FTDecoder(head_num, head_size, weights); + break; + default: + throw std::runtime_error("Wrong Tensor type."); + } +} + +FasterTransformerDecoder::~FasterTransformerDecoder() { + delete ftdecoder; +} + +Tensor FasterTransformerDecoder::forward(Tensor input, Tensor memory, Tensor memory_seq_lens, Tensor self_cache, Tensor mem_cache) { + CHECK_INPUT(input, _st); + CHECK_INPUT(memory, _st); + CHECK_INPUT(self_cache, _st); + CHECK_INPUT(mem_cache, _st); + CHECK_CUDA(memory_seq_lens); CHECK_CONTIGUOUS(memory_seq_lens); 
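+  // memory_seq_lens must be a contiguous int32 CUDA tensor (checked here).
+  // self_cache and mem_cache hold the decoder's key/value caches: the current
+  // decoding step is taken from self_cache.size(1) below, and both caches are
+  // updated in place by the underlying decoder (see FTDecoder::forward).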
TORCH_CHECK(memory_seq_lens.dtype()==torch::kInt32, "mem_seq_lens dtype should be int32"); + auto mem_size = memory.sizes(); + int batch_size = mem_size[0]; + int seq_len = mem_size[1]; + int mem_hidden_dim = mem_size[2]; + int step = self_cache.size(1); + auto output = torch::empty_like(input); + ftdecoder->forward(batch_size, seq_len, mem_hidden_dim, step, input, memory, memory_seq_lens, self_cache, mem_cache, output); + return output; +} +} // namespace torch_ext diff --git a/fastertransformer/th_op/decoder_ext.h b/fastertransformer/th_op/decoder_ext.h new file mode 100644 index 000000000..457bf41a0 --- /dev/null +++ b/fastertransformer/th_op/decoder_ext.h @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "torch/extension.h" +#include "torch/csrc/cuda/Stream.h" + +#include "fastertransformer/open_decoder.h" +#include "fastertransformer/th_op/th_traits.h" +#include "fastertransformer/th_op/utils.h" + +namespace torch_ext { +using namespace fastertransformer; +using torch::Tensor; + +class IFTDecoder { +public: + virtual ~IFTDecoder() {} + virtual void forward(int batch_size, int seq_len, int mem_hidden_dim, int step, + Tensor& input, Tensor& memory, Tensor& memory_seq_lens, Tensor& self_cache, Tensor& mem_cache, Tensor& output) = 0; +}; + +template +class FTDecoder : public IFTDecoder { +public: + FTDecoder(int head_num, int head_size, const std::vector& w) : _head_num(head_num), _head_size(head_size), _weights(w) { + int hidden_dim = _head_num * _head_size; + check_cuda_error(cublasCreate(&_cublasHandle)); + decoder_params.self_layernorm.gamma = get_ptr(_weights[0]); + decoder_params.self_layernorm.beta = get_ptr(_weights[1]); + decoder_params.self_attention.query_weight.kernel = get_ptr(_weights[2]); + decoder_params.self_attention.key_weight.kernel = get_ptr(_weights[3]); + decoder_params.self_attention.value_weight.kernel = get_ptr(_weights[4]); + decoder_params.self_attention.query_weight.bias = get_ptr(_weights[5]); + decoder_params.self_attention.key_weight.bias = get_ptr(_weights[6]); + decoder_params.self_attention.value_weight.bias = get_ptr(_weights[7]); + decoder_params.self_attention.attention_output_weight.kernel = get_ptr(_weights[8]); + decoder_params.self_attention.attention_output_weight.bias = get_ptr(_weights[9]); + decoder_params.cross_layernorm.gamma = get_ptr(_weights[10]); + decoder_params.cross_layernorm.beta = get_ptr(_weights[11]); + decoder_params.cross_attention.query_weight.kernel = get_ptr(_weights[12]); + decoder_params.cross_attention.key_weight.kernel = get_ptr(_weights[13]); + decoder_params.cross_attention.value_weight.kernel = get_ptr(_weights[14]); + decoder_params.cross_attention.query_weight.bias = get_ptr(_weights[15]); + decoder_params.cross_attention.key_weight.bias = get_ptr(_weights[16]); + decoder_params.cross_attention.value_weight.bias = get_ptr(_weights[17]); + 
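// Note: get_ptr<T>() and check_cuda_error() come from th_op/utils.h / utils.cu, which are
// listed in the CMakeLists but not shown in this patch. get_ptr is assumed to be a thin
// wrapper along the lines of
//   template <typename T> T* get_ptr(torch::Tensor& t) { return reinterpret_cast<T*>(t.data_ptr()); }
// i.e. the wrapper keeps the weight tensors alive in _weights and only hands raw device
// pointers to the DecoderInitParam fields assigned here.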
decoder_params.cross_attention.attention_output_weight.kernel = get_ptr(_weights[18]); + decoder_params.cross_attention.attention_output_weight.bias = get_ptr(_weights[19]); + decoder_params.ffn_layernorm.gamma = get_ptr(_weights[20]); + decoder_params.ffn_layernorm.beta = get_ptr(_weights[21]); + decoder_params.ffn.intermediate_weight.kernel = get_ptr(_weights[22]); + decoder_params.ffn.intermediate_weight.bias = get_ptr(_weights[23]); + decoder_params.ffn.output_weight.kernel = get_ptr(_weights[24]); + decoder_params.ffn.output_weight.bias = get_ptr(_weights[25]); + decoder_params.cublas_handle = _cublasHandle; + } + + ~FTDecoder() override { + cublasDestroy(_cublasHandle); + } + + void forward(int batch_size, int seq_len, int mem_hidden_dim, int step, + Tensor& input, Tensor& memory, Tensor& memory_seq_lens, Tensor& self_cache, Tensor& mem_cache, Tensor& output) override + { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + check_cuda_error(cublasSetStream(decoder_params.cublas_handle, stream)); + decoder_params.stream = stream; + fastertransformer::Allocator allocator; + OpenDecoder::OpType>* decoder = + new OpenDecoder::OpType>(batch_size, seq_len, _head_num, _head_size, mem_hidden_dim); + + T* output_ptr = get_ptr(output); + T* self_cache_ptr = get_ptr(self_cache); + T* mem_cache_ptr = get_ptr(mem_cache); + const T* input_ptr = get_ptr(input); + const T* memory_ptr = get_ptr(memory); + const int* memory_seq_lens_ptr = get_ptr(memory_seq_lens); + + T* K_cache = self_cache_ptr; + T* V_cache = self_cache_ptr + batch_size * step * _head_num * _head_size; + T* K_mem_cache = mem_cache_ptr; + T* V_mem_cache = mem_cache_ptr + batch_size * seq_len * _head_num * _head_size; + const int decoder_buffer_size = decoder->getWorkspaceSize() * sizeof(T); + T* decoder_buffer = (T*)allocator.malloc(decoder_buffer_size); + + decoder->initialize(decoder_params, decoder_buffer); + decoder->forward(input_ptr, memory_ptr, K_cache, V_cache, K_mem_cache, V_mem_cache, memory_seq_lens_ptr, output_ptr, step); + allocator.free(decoder_buffer); + delete decoder; + } + +private: + const int _head_num; + const int _head_size; + std::vector _weights; + cublasHandle_t _cublasHandle; + DecoderInitParam decoder_params; +}; + +class FasterTransformerDecoder { +public: + FasterTransformerDecoder( + int head_num, + int head_size, + Tensor self_layernorm_gamma, + Tensor self_layernorm_beta, + Tensor self_kernel_q, + Tensor self_kernel_k, + Tensor self_kernel_v, + Tensor self_bias_q, + Tensor self_bias_k, + Tensor self_bias_v, + Tensor self_output_kernel, + Tensor self_output_bias, + Tensor cross_layernorm_gamma, + Tensor cross_layernorm_beta, + Tensor cross_kernel_q, + Tensor cross_kernel_k, + Tensor cross_kernel_v, + Tensor cross_bias_q, + Tensor cross_bias_k, + Tensor cross_bias_v, + Tensor cross_output_kernel, + Tensor cross_output_bias, + Tensor ffn_layernorm_gamma, + Tensor ffn_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias); + + ~FasterTransformerDecoder(); + + Tensor forward(Tensor input, Tensor memory, Tensor memory_seq_lens, Tensor self_cache, Tensor mem_cache); + +private: + const at::ScalarType _st; + IFTDecoder* ftdecoder; +}; + +} // namespace torch_ext \ No newline at end of file diff --git a/fastertransformer/th_op/decoder_ths_op.cc b/fastertransformer/th_op/decoder_ths_op.cc new file mode 100644 index 000000000..64cca3798 --- /dev/null +++ b/fastertransformer/th_op/decoder_ths_op.cc @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2020, NVIDIA 
CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fastertransformer/th_op/decoder_ths_op.h" + +namespace torch_ths { +using torch::Tensor; + +FasterTransformerDecoder::FasterTransformerDecoder( + int64_t head_num, + int64_t head_size, + Tensor self_layernorm_gamma, + Tensor self_layernorm_beta, + Tensor self_kernel_q, + Tensor self_kernel_k, + Tensor self_kernel_v, + Tensor self_bias_q, + Tensor self_bias_k, + Tensor self_bias_v, + Tensor self_output_kernel, + Tensor self_output_bias, + Tensor cross_layernorm_gamma, + Tensor cross_layernorm_beta, + Tensor cross_kernel_q, + Tensor cross_kernel_k, + Tensor cross_kernel_v, + Tensor cross_bias_q, + Tensor cross_bias_k, + Tensor cross_bias_v, + Tensor cross_output_kernel, + Tensor cross_output_bias, + Tensor ffn_layernorm_gamma, + Tensor ffn_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias) +: _st(self_layernorm_gamma.scalar_type()), + weights{self_layernorm_gamma, self_layernorm_beta, + self_kernel_q, self_kernel_k, self_kernel_v, self_bias_q, self_bias_k, self_bias_v, self_output_kernel, self_output_bias, + cross_layernorm_gamma, cross_layernorm_beta, + cross_kernel_q, cross_kernel_k, cross_kernel_v, cross_bias_q, cross_bias_k, cross_bias_v, + cross_output_kernel, cross_output_bias, + ffn_layernorm_gamma, ffn_layernorm_beta, inter_kernel, inter_bias, output_kernel, output_bias} +{ + CHECK_INPUT(self_layernorm_gamma, _st); // hidden_dim + CHECK_INPUT(self_layernorm_beta, _st); // hidden_dim + CHECK_INPUT(self_kernel_q, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_kernel_k, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_kernel_v, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_bias_q, _st); // hidden_dim + CHECK_INPUT(self_bias_k, _st); // hidden_dim + CHECK_INPUT(self_bias_v, _st); // hidden_dim + CHECK_INPUT(self_output_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_output_bias, _st); // hidden_dim + CHECK_INPUT(cross_layernorm_gamma, _st); // hidden_dim + CHECK_INPUT(cross_layernorm_beta, _st); // hidden_dim + CHECK_INPUT(cross_kernel_q, _st); // hidden_dim, hidden_dim + CHECK_INPUT(cross_kernel_k, _st); // mem_hidden_dim, hidden_dim + CHECK_INPUT(cross_kernel_v, _st); // mem_hidden_dim, hidden_dim + CHECK_INPUT(cross_bias_q, _st); // hidden_dim + CHECK_INPUT(cross_bias_k, _st); // hidden_dim + CHECK_INPUT(cross_bias_v, _st); // hidden_dim + CHECK_INPUT(cross_output_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(cross_output_bias, _st); // hidden_dim + CHECK_INPUT(ffn_layernorm_gamma, _st); // hidden_dim + CHECK_INPUT(ffn_layernorm_beta, _st); // hidden_dim + CHECK_INPUT(inter_kernel, _st); // hidden_dim, 4 * hidden_dim + CHECK_INPUT(inter_bias, _st); // 4 * hidden_dim + CHECK_INPUT(output_kernel, _st); // 4 * hidden_dim, hidden_dim + CHECK_INPUT(output_bias, _st); // hidden_dim + switch (_st) { + case at::ScalarType::Float: + ftdecoder = new torch_ext::FTDecoder(head_num, head_size, weights); + 
break; + case at::ScalarType::Half: + ftdecoder = new torch_ext::FTDecoder(head_num, head_size, weights); + break; + default: + throw std::runtime_error("Wrong Tensor type."); + } + head_info = torch::empty({2}, torch::dtype(torch::kInt64)); + head_info[0] = head_num; + head_info[1] = head_size; +} + +FasterTransformerDecoder::~FasterTransformerDecoder() { + delete ftdecoder; +} + +Tensor FasterTransformerDecoder::forward(Tensor input, Tensor memory, Tensor memory_seq_lens, Tensor self_cache, Tensor mem_cache) { + CHECK_INPUT(input, _st); + CHECK_INPUT(memory, _st); + CHECK_INPUT(self_cache, _st); + CHECK_INPUT(mem_cache, _st); + CHECK_CUDA(memory_seq_lens); CHECK_CONTIGUOUS(memory_seq_lens); TORCH_CHECK(memory_seq_lens.dtype()==torch::kInt32, "mem_seq_lens dtype should be int32"); + auto mem_size = memory.sizes(); + int batch_size = mem_size[0]; + int seq_len = mem_size[1]; + int mem_hidden_dim = mem_size[2]; + int step = self_cache.size(1); + auto output = torch::empty_like(input); + ftdecoder->forward(batch_size, seq_len, mem_hidden_dim, step, input, memory, memory_seq_lens, self_cache, mem_cache, output); + return output; +} + +std::vector FasterTransformerDecoder::get_pickle_info() const { + std::vector tmp(weights); + tmp.push_back(head_info); + return tmp; +} +} // namespace torch_ths diff --git a/fastertransformer/th_op/decoder_ths_op.h b/fastertransformer/th_op/decoder_ths_op.h new file mode 100644 index 000000000..c120abea3 --- /dev/null +++ b/fastertransformer/th_op/decoder_ths_op.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
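FasterTransformerDecoder above derives from torch::jit::CustomClassHolder, so it is meant to be registered as a TorchScript custom class. The registration itself presumably sits in ft_ths_op.cc (listed in the CMakeLists, not shown in this patch), so the sketch below is an assumption, and the exact API differs across PyTorch releases (roughly torch::jit::class_<T>("Name") in 1.4/1.5, torch::class_<T>("namespace", "Name") later).

#include <torch/custom_class.h>
#include "fastertransformer/th_op/decoder_ths_op.h"

using torch::Tensor;

// Hypothetical registration sketch only. The real one presumably also wires
// get_pickle_info() (the weights plus the trailing head_info tensor) through def_pickle
// so a scripted module holding this class can be saved and reloaded; that part is omitted here.
static auto ft_decoder_class =
    torch::jit::class_<torch_ths::FasterTransformerDecoder>("FasterTransformerDecoder")
        .def(torch::jit::init<int64_t, int64_t,
                              Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
                              Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
                              Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
                              Tensor, Tensor>())
        .def("forward", &torch_ths::FasterTransformerDecoder::forward);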
+ */ + +#include +#include + +#include "fastertransformer/open_decoder.h" +#include "fastertransformer/th_op/decoder_ext.h" + +namespace torch_ths { +using namespace fastertransformer; +using torch::Tensor; + +class FasterTransformerDecoder : public torch::jit::CustomClassHolder { +public: + FasterTransformerDecoder( + int64_t head_num, + int64_t head_size, + Tensor self_layernorm_gamma, + Tensor self_layernorm_beta, + Tensor self_kernel_q, + Tensor self_kernel_k, + Tensor self_kernel_v, + Tensor self_bias_q, + Tensor self_bias_k, + Tensor self_bias_v, + Tensor self_output_kernel, + Tensor self_output_bias, + Tensor cross_layernorm_gamma, + Tensor cross_layernorm_beta, + Tensor cross_kernel_q, + Tensor cross_kernel_k, + Tensor cross_kernel_v, + Tensor cross_bias_q, + Tensor cross_bias_k, + Tensor cross_bias_v, + Tensor cross_output_kernel, + Tensor cross_output_bias, + Tensor ffn_layernorm_gamma, + Tensor ffn_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias); + + ~FasterTransformerDecoder(); + + Tensor forward(Tensor input, Tensor memory, Tensor memory_seq_lens, Tensor self_cache, Tensor mem_cache); + + std::vector get_pickle_info() const; + +private: + const at::ScalarType _st; + torch_ext::IFTDecoder* ftdecoder; + Tensor head_info; + std::vector weights; +}; + +} // namespace torch_ths \ No newline at end of file diff --git a/fastertransformer/th_op/decoding_ext.cc b/fastertransformer/th_op/decoding_ext.cc new file mode 100644 index 000000000..0ffcd147a --- /dev/null +++ b/fastertransformer/th_op/decoding_ext.cc @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "fastertransformer/th_op/decoding_ext.h" + +namespace torch_ext { +using torch::Tensor; + +FasterTransformerDecoding::FasterTransformerDecoding( + int head_num, + int head_size, + int mem_hidden_dim, + int layer_num, + int vocab_size, + int start_id, + int end_id, + float beam_search_diversity_rate, + Tensor self_layernorm_gamma, + Tensor self_layernorm_beta, + Tensor self_kernel_q, + Tensor self_kernel_k, + Tensor self_kernel_v, + Tensor self_bias_q, + Tensor self_bias_k, + Tensor self_bias_v, + Tensor self_output_kernel, + Tensor self_output_bias, + Tensor cross_layernorm_gamma, + Tensor cross_layernorm_beta, + Tensor cross_kernel_q, + Tensor cross_kernel_k, + Tensor cross_kernel_v, + Tensor cross_bias_q, + Tensor cross_bias_k, + Tensor cross_bias_v, + Tensor cross_output_kernel, + Tensor cross_output_bias, + Tensor ffn_layernorm_gamma, + Tensor ffn_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias, + Tensor decoding_gamma, + Tensor decoding_beta, + Tensor embedding_table, + Tensor position_encoding_table, + Tensor embedding_kernel, + Tensor embedding_bias) +: _st(self_layernorm_gamma.scalar_type()) +{ + CHECK_INPUT(self_layernorm_gamma, _st); // layer_num, hidden_dim + CHECK_INPUT(self_layernorm_beta, _st); // layer_num, hidden_dim + CHECK_INPUT(self_kernel_q, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_kernel_k, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_kernel_v, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_bias_q, _st); // layer_num, hidden_dim + CHECK_INPUT(self_bias_k, _st); // layer_num, hidden_dim + CHECK_INPUT(self_bias_v, _st); // layer_num, hidden_dim + CHECK_INPUT(self_output_kernel, _st); // layer_num, hidden_dim, hidden_dim + CHECK_INPUT(self_output_bias, _st); // layer_num, hidden_dim + CHECK_INPUT(cross_layernorm_gamma, _st); // layer_num, hidden_dim + CHECK_INPUT(cross_layernorm_beta, _st); // layer_num, hidden_dim + CHECK_INPUT(cross_kernel_q, _st); // layer_num, hidden_dim, hidden_dim + CHECK_INPUT(cross_kernel_k, _st); // layer_num, mem_hidden_dim, hidden_dim + CHECK_INPUT(cross_kernel_v, _st); // layer_num, mem_hidden_dim, hidden_dim + CHECK_INPUT(cross_bias_q, _st); // layer_num, hidden_dim + CHECK_INPUT(cross_bias_k, _st); // layer_num, hidden_dim + CHECK_INPUT(cross_bias_v, _st); // layer_num, hidden_dim + CHECK_INPUT(cross_output_kernel, _st); // layer_num, hidden_dim, hidden_dim + CHECK_INPUT(cross_output_bias, _st); // layer_num, hidden_dim + CHECK_INPUT(ffn_layernorm_gamma, _st); // layer_num, hidden_dim + CHECK_INPUT(ffn_layernorm_beta, _st); // layer_num, hidden_dim + CHECK_INPUT(inter_kernel, _st); // layer_num, hidden_dim, 4 * hidden_dim + CHECK_INPUT(inter_bias, _st); // layer_num, 4 * hidden_dim + CHECK_INPUT(output_kernel, _st); // layer_num, 4 * hidden_dim, hidden_dim + CHECK_INPUT(output_bias, _st); // layer_num, hidden_dim + CHECK_INPUT(decoding_gamma, _st); // hidden_dim + CHECK_INPUT(decoding_beta, _st); // hidden_dim + CHECK_INPUT(embedding_table, _st); // vocab_size, hidden_dim + CHECK_INPUT(position_encoding_table, _st); // max_step, hidden_dim + CHECK_INPUT(embedding_kernel, _st); // hidden_dim, vocab_size + CHECK_INPUT(embedding_bias, at::ScalarType::Float); // vocab_size + std::vector weights{self_layernorm_gamma, self_layernorm_beta, + self_kernel_q, self_kernel_k, self_kernel_v, self_bias_q, self_bias_k, self_bias_v, + self_output_kernel, self_output_bias, + cross_layernorm_gamma, cross_layernorm_beta, + cross_kernel_q, cross_kernel_k, 
cross_kernel_v, cross_bias_q, cross_bias_k, cross_bias_v, + cross_output_kernel, cross_output_bias, + ffn_layernorm_gamma, ffn_layernorm_beta, inter_kernel, inter_bias, output_kernel, output_bias, + decoding_gamma, decoding_beta, embedding_table, position_encoding_table, + embedding_kernel, embedding_bias}; + switch (_st) { + case at::ScalarType::Float: + ftdecoding = new FTDecoding(head_num, head_size, mem_hidden_dim, layer_num, vocab_size, + start_id, end_id, beam_search_diversity_rate, weights); + break; + case at::ScalarType::Half: + ftdecoding = new FTDecoding(head_num, head_size, mem_hidden_dim, layer_num, vocab_size, + start_id, end_id, beam_search_diversity_rate, weights); + break; + default: + throw std::runtime_error("Wrong Tensor type."); + } +} + +FasterTransformerDecoding::~FasterTransformerDecoding() { + delete ftdecoding; +} + +std::vector FasterTransformerDecoding::forward(int batch_size, int beam_size, int max_seq_len, Tensor memory, Tensor memory_seq_lens) { + CHECK_INPUT(memory, _st); + CHECK_CUDA(memory_seq_lens); CHECK_CONTIGUOUS(memory_seq_lens); TORCH_CHECK(memory_seq_lens.dtype()==torch::kInt32, "mem_seq_lens dtype should be int32"); + int mem_max_seq_len = memory.size(1); + auto output_ids = torch::empty({batch_size * beam_size * max_seq_len}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + auto parent_ids = torch::empty({batch_size * beam_size * max_seq_len}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + auto out_seq_lens = torch::empty({batch_size * beam_size}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + ftdecoding->forward(batch_size, beam_size, max_seq_len, mem_max_seq_len, + memory, memory_seq_lens, output_ids, parent_ids, out_seq_lens); + return std::vector{output_ids, parent_ids, out_seq_lens}; +} + +Tensor gather_tree(Tensor step_ids, Tensor parent_ids, Tensor max_sequence_lengths, int end_token) { + CHECK_CUDA(step_ids); CHECK_CONTIGUOUS(step_ids); TORCH_CHECK(step_ids.dtype()==torch::kInt32, "step_ids dtype should be int32"); + CHECK_CUDA(parent_ids); CHECK_CONTIGUOUS(parent_ids); TORCH_CHECK(parent_ids.dtype()==torch::kInt32, "parent_ids dtype should be int32"); + CHECK_CUDA(max_sequence_lengths); CHECK_CONTIGUOUS(max_sequence_lengths); TORCH_CHECK(max_sequence_lengths.dtype()==torch::kInt32, "max_sequence_lengths dtype should be int32"); + int max_step = step_ids.size(0); + int batch_size = step_ids.size(1); + int beam_width = step_ids.size(2); + auto beams = torch::empty_like(step_ids); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + gather_tree_kernel_launcher(max_step, batch_size, beam_width, + get_ptr(step_ids), + get_ptr(parent_ids), + get_ptr(max_sequence_lengths), + end_token, + get_ptr(beams), + stream); + return beams; +} + +} // namespace torch_ext \ No newline at end of file diff --git a/fastertransformer/th_op/decoding_ext.h b/fastertransformer/th_op/decoding_ext.h new file mode 100644 index 000000000..3e5a24a89 --- /dev/null +++ b/fastertransformer/th_op/decoding_ext.h @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
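FasterTransformerDecoding::forward above returns flat int32 buffers (output_ids, parent_ids, per-beam sequence lengths) that the caller is expected to reshape and finalize with gather_tree, mirroring TensorFlow's gather_tree semantics. The driver scripts are not part of this patch, so the reshape convention in the sketch below is an assumption for illustration only.

#include <vector>
#include "fastertransformer/th_op/decoding_ext.h"

// Hypothetical post-processing sketch (not part of the patch).
std::vector<torch::Tensor> run_beam_search(torch_ext::FasterTransformerDecoding& decoding,
                                           int batch_size, int beam_size, int max_seq_len,
                                           torch::Tensor memory, torch::Tensor memory_seq_lens,
                                           int end_id) {
  auto outs = decoding.forward(batch_size, beam_size, max_seq_len, memory, memory_seq_lens);
  auto step_ids   = outs[0].view({max_seq_len, batch_size, beam_size});
  auto parent_ids = outs[1].view({max_seq_len, batch_size, beam_size});
  auto seq_lens   = outs[2].view({batch_size, beam_size});
  auto max_lens   = std::get<0>(seq_lens.max(/*dim=*/1));   // per-batch max length, int32
  auto beams      = torch_ext::gather_tree(step_ids, parent_ids, max_lens, end_id);
  return {beams, seq_lens};
}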
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "torch/extension.h" +#include "torch/csrc/cuda/Stream.h" + +#include "fastertransformer/open_decoder.h" +#include "fastertransformer/decoding_beamsearch.h" +#include "fastertransformer/th_op/th_traits.h" +#include "fastertransformer/th_op/utils.h" + +namespace torch_ext { +using namespace fastertransformer; +using torch::Tensor; + +class IFTDecoding { +public: + virtual ~IFTDecoding() {} + virtual void forward(int batch_size, int beam_size, int max_seq_len, int mem_max_seq_len, + Tensor memory, Tensor memory_seq_lens, Tensor output_ids, Tensor parent_ids, Tensor out_seq_lens) = 0; +}; + +template +class FTDecoding : public IFTDecoding { +public: + FTDecoding(int head_num, int head_size, int mem_hidden_dim, int layer_num, int vocab_size, + int start_id, int end_id, float beam_search_diversity_rate, const std::vector& w) + : _head_num(head_num), _head_size(head_size), _mem_hidden_dim(mem_hidden_dim), _layer_num(layer_num), _vocab_size(vocab_size), + _start_id(start_id), _end_id(end_id), _beam_search_diversity_rate(beam_search_diversity_rate), _weights(w) + { + check_cuda_error(cublasCreate(&_cublasHandle)); + decoder_params = new DecoderInitParam[_layer_num]; + const int hidden_dim = _head_num * _head_size; + for (int i = 0; i < _layer_num; ++i) { + decoder_params[i].self_layernorm.gamma = get_ptr(_weights[0]) + i * hidden_dim; + decoder_params[i].self_layernorm.beta = get_ptr(_weights[1]) + i * hidden_dim; + decoder_params[i].self_attention.query_weight.kernel = get_ptr(_weights[2]) + i * hidden_dim * hidden_dim; + decoder_params[i].self_attention.key_weight.kernel = get_ptr(_weights[3]) + i * hidden_dim * hidden_dim; + decoder_params[i].self_attention.value_weight.kernel = get_ptr(_weights[4]) + i * hidden_dim * hidden_dim; + decoder_params[i].self_attention.query_weight.bias = get_ptr(_weights[5]) + i * hidden_dim; + decoder_params[i].self_attention.key_weight.bias = get_ptr(_weights[6]) + i * hidden_dim; + decoder_params[i].self_attention.value_weight.bias = get_ptr(_weights[7]) + i * hidden_dim; + decoder_params[i].self_attention.attention_output_weight.kernel = get_ptr(_weights[8]) + i * hidden_dim * hidden_dim; + decoder_params[i].self_attention.attention_output_weight.bias = get_ptr(_weights[9]) + i * hidden_dim; + decoder_params[i].cross_layernorm.gamma = get_ptr(_weights[10]) + i * hidden_dim; + decoder_params[i].cross_layernorm.beta = get_ptr(_weights[11]) + i * hidden_dim; + decoder_params[i].cross_attention.query_weight.kernel = get_ptr(_weights[12]) + i * hidden_dim * hidden_dim; + decoder_params[i].cross_attention.key_weight.kernel = get_ptr(_weights[13]) + i * mem_hidden_dim * hidden_dim; + decoder_params[i].cross_attention.value_weight.kernel = get_ptr(_weights[14]) + i * mem_hidden_dim * hidden_dim; + decoder_params[i].cross_attention.query_weight.bias = get_ptr(_weights[15]) + i * hidden_dim; + decoder_params[i].cross_attention.key_weight.bias = get_ptr(_weights[16]) + i * hidden_dim; + decoder_params[i].cross_attention.value_weight.bias = get_ptr(_weights[17]) + i * hidden_dim; + 
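// Weight layout note: unlike the single-layer FTDecoder wrapper, every weight tensor passed
// to this multi-layer decoding wrapper stacks all layers along its leading dimension, so
// layer i is addressed by offsetting the base pointer by i times the per-layer element count
// (hidden_dim, hidden_dim * hidden_dim, mem_hidden_dim * hidden_dim, or
// 4 * hidden_dim * hidden_dim, depending on the parameter), as in the assignments around here.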
decoder_params[i].cross_attention.attention_output_weight.kernel = get_ptr(_weights[18]) + i * hidden_dim * hidden_dim; + decoder_params[i].cross_attention.attention_output_weight.bias = get_ptr(_weights[19]) + i * hidden_dim; + decoder_params[i].ffn_layernorm.gamma = get_ptr(_weights[20]) + i * hidden_dim; + decoder_params[i].ffn_layernorm.beta = get_ptr(_weights[21]) + i * hidden_dim; + decoder_params[i].ffn.intermediate_weight.kernel = get_ptr(_weights[22]) + i * hidden_dim * hidden_dim * 4; + decoder_params[i].ffn.intermediate_weight.bias = get_ptr(_weights[23]) + i * hidden_dim * 4; + decoder_params[i].ffn.output_weight.kernel = get_ptr(_weights[24]) + i * hidden_dim * hidden_dim * 4; + decoder_params[i].ffn.output_weight.bias = get_ptr(_weights[25]) + i * hidden_dim; + decoder_params[i].cublas_handle = _cublasHandle; + } + decoding_params.layernorm.gamma = get_ptr(_weights[26]); + decoding_params.layernorm.beta = get_ptr(_weights[27]); + decoding_params.embedding_table = get_ptr(_weights[28]); + decoding_params.position_encoding_table = get_ptr(_weights[29]); + decoding_params.embedding_kernel = get_ptr(_weights[30]); + decoding_params.embedding_bias = get_ptr(_weights[31]); + decoding_params.cublas_handle = _cublasHandle; + } + + ~FTDecoding() override { + cublasDestroy(_cublasHandle); + delete [] decoder_params; + } + + void forward(int batch_size, int beam_size, int max_seq_len, int mem_max_seq_len, + Tensor memory, Tensor memory_seq_lens, Tensor output_ids, Tensor parent_ids, Tensor out_seq_lens) override + { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + check_cuda_error(cublasSetStream(_cublasHandle, stream)); + decoding_params.stream = stream; + for(int i = 0; i < _layer_num; ++i) + { + decoder_params[i].stream = stream; + check_cuda_error(cublasSetStream(decoder_params[i].cublas_handle, stream)); + } + check_cuda_error(cublasSetStream(decoding_params.cublas_handle, stream)); + + decoding_params.output_ids = get_ptr(output_ids); + decoding_params.parent_ids = get_ptr(parent_ids); + decoding_params.sequence_length = get_ptr(out_seq_lens); + check_cuda_error(cudaMemset(decoding_params.output_ids, 0, sizeof(int) * batch_size * beam_size * max_seq_len)); + check_cuda_error(cudaMemset(decoding_params.parent_ids, 0, sizeof(int) * batch_size * beam_size * max_seq_len)); + check_cuda_error(cudaMemset(decoding_params.sequence_length, 0, sizeof(int) * batch_size * beam_size)); + decoding_params.memory_tensor = get_ptr(memory); + decoding_params.memory_sequence_length = get_ptr(memory_seq_lens); + + fastertransformer::Allocator allocator; + DecodingBeamsearch::OpType>* decoding = + new DecodingBeamsearch::OpType>(allocator, batch_size, beam_size, max_seq_len, _head_num, _head_size, _vocab_size, + _layer_num, _mem_hidden_dim, mem_max_seq_len, _start_id, _end_id, _beam_search_diversity_rate); + decoding->forward(decoder_params, decoding_params); + delete decoding; + } + +private: + const int _head_num; + const int _head_size; + const int _mem_hidden_dim; + const int _layer_num; + const int _vocab_size; + const int _start_id; + const int _end_id; + const float _beam_search_diversity_rate; + std::vector _weights; + cublasHandle_t _cublasHandle; + DecodingInitParam decoding_params; + DecoderInitParam* decoder_params; +}; + +class FasterTransformerDecoding { +public: + FasterTransformerDecoding( + int head_num, + int head_size, + int mem_hidden_dim, + int layer_num, + int vocab_size, + int start_id, + int end_id, + float beam_search_diversity_rate, + Tensor self_layernorm_gamma, 
+ Tensor self_layernorm_beta, + Tensor self_kernel_q, + Tensor self_kernel_k, + Tensor self_kernel_v, + Tensor self_bias_q, + Tensor self_bias_k, + Tensor self_bias_v, + Tensor self_output_kernel, + Tensor self_output_bias, + Tensor cross_layernorm_gamma, + Tensor cross_layernorm_beta, + Tensor cross_kernel_q, + Tensor cross_kernel_k, + Tensor cross_kernel_v, + Tensor cross_bias_q, + Tensor cross_bias_k, + Tensor cross_bias_v, + Tensor cross_output_kernel, + Tensor cross_output_bias, + Tensor ffn_layernorm_gamma, + Tensor ffn_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias, + Tensor decoding_gamma, + Tensor decoding_beta, + Tensor embedding_table, + Tensor position_encoding_table, + Tensor embedding_kernel, + Tensor embedding_bias); + + ~FasterTransformerDecoding(); + + std::vector forward(int batch_size, int beam_size, int max_seq_len, Tensor memory, Tensor memory_seq_lens); + +private: + const at::ScalarType _st; + IFTDecoding* ftdecoding; +}; + +Tensor gather_tree(Tensor step_ids, Tensor parent_ids, Tensor max_sequence_lengths, int end_token); + +} // namespace torch_ext \ No newline at end of file diff --git a/fastertransformer/th_op/decoding_ths_op.cc b/fastertransformer/th_op/decoding_ths_op.cc new file mode 100644 index 000000000..13cad077a --- /dev/null +++ b/fastertransformer/th_op/decoding_ths_op.cc @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "fastertransformer/th_op/decoding_ths_op.h" + +namespace torch_ths { +using torch::Tensor; + +FasterTransformerDecoding::FasterTransformerDecoding( + int64_t head_num, + int64_t head_size, + int64_t mem_hidden_dim, + int64_t layer_num, + int64_t vocab_size, + int64_t start_id, + int64_t end_id, + double beam_search_diversity_rate, + Tensor self_layernorm_gamma, + Tensor self_layernorm_beta, + Tensor self_kernel_q, + Tensor self_kernel_k, + Tensor self_kernel_v, + Tensor self_bias_q, + Tensor self_bias_k, + Tensor self_bias_v, + Tensor self_output_kernel, + Tensor self_output_bias, + Tensor cross_layernorm_gamma, + Tensor cross_layernorm_beta, + Tensor cross_kernel_q, + Tensor cross_kernel_k, + Tensor cross_kernel_v, + Tensor cross_bias_q, + Tensor cross_bias_k, + Tensor cross_bias_v, + Tensor cross_output_kernel, + Tensor cross_output_bias, + Tensor ffn_layernorm_gamma, + Tensor ffn_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias, + Tensor decoding_gamma, + Tensor decoding_beta, + Tensor embedding_table, + Tensor position_encoding_table, + Tensor embedding_kernel, + Tensor embedding_bias) +: _st(self_layernorm_gamma.scalar_type()), + weights{self_layernorm_gamma, self_layernorm_beta, + self_kernel_q, self_kernel_k, self_kernel_v, self_bias_q, self_bias_k, self_bias_v, + self_output_kernel, self_output_bias, + cross_layernorm_gamma, cross_layernorm_beta, + cross_kernel_q, cross_kernel_k, cross_kernel_v, cross_bias_q, cross_bias_k, cross_bias_v, + cross_output_kernel, cross_output_bias, + ffn_layernorm_gamma, ffn_layernorm_beta, inter_kernel, inter_bias, output_kernel, output_bias, + decoding_gamma, decoding_beta, embedding_table, position_encoding_table, + embedding_kernel, embedding_bias} +{ + CHECK_INPUT(self_layernorm_gamma, _st); // layer_num, hidden_dim + CHECK_INPUT(self_layernorm_beta, _st); // layer_num, hidden_dim + CHECK_INPUT(self_kernel_q, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_kernel_k, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_kernel_v, _st); // hidden_dim, hidden_dim + CHECK_INPUT(self_bias_q, _st); // layer_num, hidden_dim + CHECK_INPUT(self_bias_k, _st); // layer_num, hidden_dim + CHECK_INPUT(self_bias_v, _st); // layer_num, hidden_dim + CHECK_INPUT(self_output_kernel, _st); // layer_num, hidden_dim, hidden_dim + CHECK_INPUT(self_output_bias, _st); // layer_num, hidden_dim + CHECK_INPUT(cross_layernorm_gamma, _st); // layer_num, hidden_dim + CHECK_INPUT(cross_layernorm_beta, _st); // layer_num, hidden_dim + CHECK_INPUT(cross_kernel_q, _st); // layer_num, hidden_dim, hidden_dim + CHECK_INPUT(cross_kernel_k, _st); // layer_num, mem_hidden_dim, hidden_dim + CHECK_INPUT(cross_kernel_v, _st); // layer_num, mem_hidden_dim, hidden_dim + CHECK_INPUT(cross_bias_q, _st); // layer_num, hidden_dim + CHECK_INPUT(cross_bias_k, _st); // layer_num, hidden_dim + CHECK_INPUT(cross_bias_v, _st); // layer_num, hidden_dim + CHECK_INPUT(cross_output_kernel, _st); // layer_num, hidden_dim, hidden_dim + CHECK_INPUT(cross_output_bias, _st); // layer_num, hidden_dim + CHECK_INPUT(ffn_layernorm_gamma, _st); // layer_num, hidden_dim + CHECK_INPUT(ffn_layernorm_beta, _st); // layer_num, hidden_dim + CHECK_INPUT(inter_kernel, _st); // layer_num, hidden_dim, 4 * hidden_dim + CHECK_INPUT(inter_bias, _st); // layer_num, 4 * hidden_dim + CHECK_INPUT(output_kernel, _st); // layer_num, 4 * hidden_dim, hidden_dim + CHECK_INPUT(output_bias, _st); // layer_num, hidden_dim + CHECK_INPUT(decoding_gamma, _st); // 
hidden_dim + CHECK_INPUT(decoding_beta, _st); // hidden_dim + CHECK_INPUT(embedding_table, _st); // vocab_size, hidden_dim + CHECK_INPUT(position_encoding_table, _st); // max_step, hidden_dim + CHECK_INPUT(embedding_kernel, _st); // hidden_dim, vocab_size + CHECK_INPUT(embedding_bias, at::ScalarType::Float); // vocab_size + switch (_st) { + case at::ScalarType::Float: + ftdecoding = new torch_ext::FTDecoding(head_num, head_size, mem_hidden_dim, layer_num, vocab_size, + start_id, end_id, (float)beam_search_diversity_rate, weights); + break; + case at::ScalarType::Half: + ftdecoding = new torch_ext::FTDecoding(head_num, head_size, mem_hidden_dim, layer_num, vocab_size, + start_id, end_id, (float)beam_search_diversity_rate, weights); + break; + default: + throw std::runtime_error("Wrong Tensor type."); + } + decoding_info = torch::empty({7}, torch::dtype(torch::kInt64)); + decoding_info[0] = head_num; + decoding_info[1] = head_size; + decoding_info[2] = mem_hidden_dim; + decoding_info[3] = layer_num; + decoding_info[4] = vocab_size; + decoding_info[5] = start_id; + decoding_info[6] = end_id; + beam_search_diversity_rate_info = torch::empty({1}, torch::dtype(torch::kFloat64)); + beam_search_diversity_rate_info[0] = beam_search_diversity_rate; +} + +FasterTransformerDecoding::~FasterTransformerDecoding() { + delete ftdecoding; +} + +std::vector FasterTransformerDecoding::forward(int64_t batch_size, int64_t beam_size, int64_t max_seq_len, Tensor memory, Tensor memory_seq_lens) { + CHECK_INPUT(memory, _st); + CHECK_CUDA(memory_seq_lens); CHECK_CONTIGUOUS(memory_seq_lens); TORCH_CHECK(memory_seq_lens.dtype()==torch::kInt32, "mem_seq_lens dtype should be int32"); + int mem_max_seq_len = memory.size(1); + auto output_ids = torch::empty({batch_size * beam_size * max_seq_len}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + auto parent_ids = torch::empty({batch_size * beam_size * max_seq_len}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + auto out_seq_lens = torch::empty({batch_size * beam_size}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + ftdecoding->forward(batch_size, beam_size, max_seq_len, mem_max_seq_len, + memory, memory_seq_lens, output_ids, parent_ids, out_seq_lens); + return std::vector{output_ids, parent_ids, out_seq_lens}; +} + +std::vector FasterTransformerDecoding::get_pickle_info() const { + std::vector tmp(weights); + tmp.push_back(decoding_info); + tmp.push_back(beam_search_diversity_rate_info); + return tmp; +} + +Tensor gather_tree(Tensor step_ids, Tensor parent_ids, Tensor max_sequence_lengths, int64_t end_token) { + CHECK_CUDA(step_ids); CHECK_CONTIGUOUS(step_ids); TORCH_CHECK(step_ids.dtype()==torch::kInt32, "step_ids dtype should be int32"); + CHECK_CUDA(parent_ids); CHECK_CONTIGUOUS(parent_ids); TORCH_CHECK(parent_ids.dtype()==torch::kInt32, "parent_ids dtype should be int32"); + CHECK_CUDA(max_sequence_lengths); CHECK_CONTIGUOUS(max_sequence_lengths); TORCH_CHECK(max_sequence_lengths.dtype()==torch::kInt32, "max_sequence_lengths dtype should be int32"); + int max_step = step_ids.size(0); + int batch_size = step_ids.size(1); + int beam_width = step_ids.size(2); + auto beams = torch::empty_like(step_ids); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + torch_ext::gather_tree_kernel_launcher(max_step, batch_size, beam_width, + torch_ext::get_ptr(step_ids), + torch_ext::get_ptr(parent_ids), + torch_ext::get_ptr(max_sequence_lengths), + end_token, + torch_ext::get_ptr(beams), + stream); 
+ return beams; +} + +} // namespace torch_ths \ No newline at end of file diff --git a/fastertransformer/th_op/decoding_ths_op.h b/fastertransformer/th_op/decoding_ths_op.h new file mode 100644 index 000000000..dae52e2bd --- /dev/null +++ b/fastertransformer/th_op/decoding_ths_op.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "fastertransformer/open_decoder.h" +#include "fastertransformer/decoding_beamsearch.h" +#include "fastertransformer/th_op/decoding_ext.h" + +namespace torch_ths { +using namespace fastertransformer; +using torch::Tensor; + +class FasterTransformerDecoding : public torch::jit::CustomClassHolder { +public: + FasterTransformerDecoding( + int64_t head_num, + int64_t head_size, + int64_t mem_hidden_dim, + int64_t layer_num, + int64_t vocab_size, + int64_t start_id, + int64_t end_id, + double beam_search_diversity_rate, + Tensor self_layernorm_gamma, + Tensor self_layernorm_beta, + Tensor self_kernel_q, + Tensor self_kernel_k, + Tensor self_kernel_v, + Tensor self_bias_q, + Tensor self_bias_k, + Tensor self_bias_v, + Tensor self_output_kernel, + Tensor self_output_bias, + Tensor cross_layernorm_gamma, + Tensor cross_layernorm_beta, + Tensor cross_kernel_q, + Tensor cross_kernel_k, + Tensor cross_kernel_v, + Tensor cross_bias_q, + Tensor cross_bias_k, + Tensor cross_bias_v, + Tensor cross_output_kernel, + Tensor cross_output_bias, + Tensor ffn_layernorm_gamma, + Tensor ffn_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias, + Tensor decoding_gamma, + Tensor decoding_beta, + Tensor embedding_table, + Tensor position_encoding_table, + Tensor embedding_kernel, + Tensor embedding_bias); + + ~FasterTransformerDecoding(); + + std::vector forward(int64_t batch_size, int64_t beam_size, int64_t max_seq_len, Tensor memory, Tensor memory_seq_lens); + + std::vector get_pickle_info() const; + +private: + const at::ScalarType _st; + torch_ext::IFTDecoding* ftdecoding; + Tensor decoding_info; + Tensor beam_search_diversity_rate_info; + std::vector weights; +}; + +Tensor gather_tree(Tensor step_ids, Tensor parent_ids, Tensor max_sequence_lengths, int64_t end_token); + +} // namespace torch_ths \ No newline at end of file diff --git a/fastertransformer/th_op/encoder_ext.cc b/fastertransformer/th_op/encoder_ext.cc new file mode 100644 index 000000000..564d1a40c --- /dev/null +++ b/fastertransformer/th_op/encoder_ext.cc @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
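The torch_ths::gather_tree overload above takes int64_t so it can be exposed directly as a TorchScript operator. The actual registration presumably lives in ft_ths_op.cc (listed in the CMakeLists, not shown in this patch); the operator name below is therefore an assumption.

#include <torch/script.h>
#include "fastertransformer/th_op/decoding_ths_op.h"

// Hypothetical sketch of the operator registration; the schema is inferred from the signature.
static auto gather_tree_registry = torch::RegisterOperators(
    "fastertransformer::gather_tree", &torch_ths::gather_tree);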
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fastertransformer/th_op/encoder_ext.h" + +namespace torch_ext { +using torch::Tensor; + +FasterTransformerEncoder::FasterTransformerEncoder( + int head_num, + int head_size, + bool remove_padding, + Tensor q_kernel, + Tensor q_bias, + Tensor k_kernel, + Tensor k_bias, + Tensor v_kernel, + Tensor v_bias, + Tensor attr_output_kernel, + Tensor attr_output_bias, + Tensor attr_output_layernorm_gamma, + Tensor attr_output_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias, + Tensor output_layernorm_gamma, + Tensor output_layernorm_beta) +: _st(q_kernel.scalar_type()), _remove_padding(remove_padding) +{ + CHECK_INPUT(q_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(q_bias, _st); // hidden_dim + CHECK_INPUT(k_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(k_bias, _st); // hidden_dim + CHECK_INPUT(v_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(v_bias, _st); // hidden_dim + CHECK_INPUT(attr_output_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(attr_output_bias, _st); // hidden_dim + CHECK_INPUT(attr_output_layernorm_gamma, _st); // hidden_dim + CHECK_INPUT(attr_output_layernorm_beta, _st); // hidden_dim + CHECK_INPUT(inter_kernel, _st); // hidden_dim, 4 * hidden_dim + CHECK_INPUT(inter_bias, _st); // 4 * hidden_dim + CHECK_INPUT(output_kernel, _st); // 4 * hidden_dim, hidden_dim + CHECK_INPUT(output_bias, _st); // hidden_dim + CHECK_INPUT(output_layernorm_gamma, _st); // hidden_dim + CHECK_INPUT(output_layernorm_beta, _st); // hidden_dim + std::vector weights{q_kernel, q_bias, k_kernel, k_bias, v_kernel, v_bias, + attr_output_kernel, attr_output_bias, attr_output_layernorm_gamma, attr_output_layernorm_beta, + inter_kernel, inter_bias, output_kernel, output_bias, output_layernorm_gamma, output_layernorm_beta}; + switch (_st) { + case at::ScalarType::Float: + ftencoder = new FTEncoder(head_num, head_size, weights); + break; + case at::ScalarType::Half: + ftencoder = new FTEncoder(head_num, head_size, weights); + break; + default: + throw std::runtime_error("Wrong Tensor type."); + } +} + +FasterTransformerEncoder::~FasterTransformerEncoder() { + delete ftencoder; +} + +Tensor FasterTransformerEncoder::forward(Tensor input, Tensor attr_mask, Tensor sequence_lengths) { + auto input_size = input.sizes(); + int batch_size = input_size[0]; + int seq_len = input_size[1]; + CHECK_INPUT(input, _st); + CHECK_INPUT(attr_mask, _st); + if (_remove_padding) { + CHECK_CUDA(sequence_lengths); CHECK_CONTIGUOUS(sequence_lengths); + TORCH_CHECK(sequence_lengths.dtype()==torch::kInt32, "sequence_length dtype should be int32"); + TORCH_CHECK(sequence_lengths.numel()!=0, "sequence_length should not be empty tensor"); + TORCH_CHECK(sequence_lengths.size(0)==batch_size, "wrong sequence_length shape"); + } + auto output = torch::empty_like(input); + ftencoder->forward(batch_size, seq_len, input, attr_mask, output, sequence_lengths, _remove_padding); + return output; +} +} // namespace torch_ext diff --git a/fastertransformer/th_op/encoder_ext.h b/fastertransformer/th_op/encoder_ext.h new file mode 100644 
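For orientation, a hypothetical smoke-test of the eager-mode encoder wrapper defined above is sketched below; it is not part of the patch. The layer sizes, the attention-mask layout, and the use of random FP32 weights are assumptions for illustration only.

#include <vector>
#include "torch/extension.h"
#include "fastertransformer/th_op/encoder_ext.h"

int main() {
  const int head_num = 8, head_size = 64, hidden = head_num * head_size;
  const int batch = 2, seq_len = 32;
  auto opt = torch::dtype(torch::kFloat32).device(torch::kCUDA);
  auto W = [&](std::vector<int64_t> shape) { return torch::randn(shape, opt); };

  torch_ext::FasterTransformerEncoder encoder(
      head_num, head_size, /*remove_padding=*/false,
      W({hidden, hidden}), W({hidden}),            // q kernel / bias
      W({hidden, hidden}), W({hidden}),            // k kernel / bias
      W({hidden, hidden}), W({hidden}),            // v kernel / bias
      W({hidden, hidden}), W({hidden}),            // attention output kernel / bias
      W({hidden}), W({hidden}),                    // attention layernorm gamma / beta
      W({hidden, 4 * hidden}), W({4 * hidden}),    // ffn intermediate kernel / bias
      W({4 * hidden, hidden}), W({hidden}),        // ffn output kernel / bias
      W({hidden}), W({hidden}));                   // output layernorm gamma / beta

  auto input     = W({batch, seq_len, hidden});
  auto attr_mask = torch::ones({batch, seq_len, seq_len}, opt);   // assumed mask layout
  auto seq_lens  = torch::full({batch}, seq_len,
                               torch::dtype(torch::kInt32).device(torch::kCUDA));
  auto out = encoder.forward(input, attr_mask, seq_lens);
  return out.defined() ? 0 : 1;
}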
index 000000000..e735f8312 --- /dev/null +++ b/fastertransformer/th_op/encoder_ext.h @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "torch/extension.h" +#include "torch/csrc/cuda/Stream.h" + +#include "fastertransformer/faster_transformer.h" +#include "fastertransformer/cuda/cuda_kernels.h" +#include "fastertransformer/th_op/th_traits.h" +#include "fastertransformer/th_op/utils.h" + + +namespace torch_ext { +using namespace fastertransformer; +using torch::Tensor; + +class IFTEncoder { +public: + virtual ~IFTEncoder() {} + virtual void forward(int batch_size, + int seq_len, + Tensor& input, + Tensor& attr_mask, + Tensor& output, + Tensor& sequence_lengths, + bool removing_padding) = 0; +}; + +template +class FTEncoder : public IFTEncoder { +public: + FTEncoder(int head_num, int head_size, const std::vector& w) : _head_num(head_num), _head_size(head_size), _weights(w) { + int hidden_dim = _head_num * _head_size; + check_cuda_error(cublasCreate(&_cublasHandle)); + encoder_param.self_attention.query_weight.kernel = get_ptr(_weights[0]); + encoder_param.self_attention.query_weight.bias = get_ptr(_weights[1]); + encoder_param.self_attention.key_weight.kernel = get_ptr(_weights[2]); + encoder_param.self_attention.key_weight.bias = get_ptr(_weights[3]); + encoder_param.self_attention.value_weight.kernel = get_ptr(_weights[4]); + encoder_param.self_attention.value_weight.bias = get_ptr(_weights[5]); + encoder_param.self_attention.attention_output_weight.kernel = get_ptr(_weights[6]); + encoder_param.self_attention.attention_output_weight.bias = get_ptr(_weights[7]); + encoder_param.self_layernorm.gamma = get_ptr(_weights[8]); + encoder_param.self_layernorm.beta = get_ptr(_weights[9]); + encoder_param.ffn.intermediate_weight.kernel = get_ptr(_weights[10]); + encoder_param.ffn.intermediate_weight.bias = get_ptr(_weights[11]); + encoder_param.ffn.output_weight.kernel = get_ptr(_weights[12]); + encoder_param.ffn.output_weight.bias = get_ptr(_weights[13]); + encoder_param.ffn_layernorm.gamma = get_ptr(_weights[14]); + encoder_param.ffn_layernorm.beta = get_ptr(_weights[15]); + encoder_param.cublas_handle = _cublasHandle; + } + + ~FTEncoder() override { + cublasDestroy(_cublasHandle); + } + + void forward(int batch_size, + int seq_len, + Tensor& input, + Tensor& attr_mask, + Tensor& output, + Tensor& sequence_lengths, + bool removing_padding) override { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + encoder_param.stream = stream; + int hidden_dim = _head_num * _head_size; + std::vector buf_vector; + + if (removing_padding) { + const T* input_ptr = get_ptr(input); + const int* sequence_lengths_ptr = get_ptr(sequence_lengths); + auto buf = torch::empty({batch_size * seq_len + 1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + int* tmp_sequence_id_offset = get_ptr(buf); + int* d_valid_word_num = tmp_sequence_id_offset + batch_size * 
seq_len; + build_sequence_length_padding_offset_kernelLauncher(sequence_lengths_ptr, batch_size, seq_len, + d_valid_word_num, tmp_sequence_id_offset, stream); + int* h_valid_word_num = new int[1]; + cudaMemcpyAsync(h_valid_word_num, d_valid_word_num, sizeof(int), cudaMemcpyDeviceToHost, stream); + const int valid_word_num = h_valid_word_num[0]; + delete h_valid_word_num; + auto intermediate_input = + torch::empty({valid_word_num, hidden_dim}, torch::dtype(input.dtype()).device(torch::kCUDA).requires_grad(false)); + buf_vector.push_back(intermediate_input); + T* intermediate_input_ptr = get_ptr(intermediate_input); + auto sequence_id_offset = + torch::empty({valid_word_num}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + buf_vector.push_back(sequence_id_offset); + int* sequence_id_offset_ptr = get_ptr(sequence_id_offset); + remove_sequence_length_padding_kernelLauncher(input_ptr, intermediate_input_ptr, + tmp_sequence_id_offset, sequence_id_offset_ptr, + valid_word_num, hidden_dim, stream); + auto intermediate_output = torch::empty_like(intermediate_input); + buf_vector.push_back(intermediate_output); + encoder_param.from_tensor = intermediate_input_ptr; + encoder_param.to_tensor = intermediate_input_ptr; + encoder_param.sequence_id_offset = sequence_id_offset_ptr; + encoder_param.valid_word_num = valid_word_num; + encoder_param.transformer_out = get_ptr(intermediate_output); + } else { + encoder_param.from_tensor = get_ptr(input); + encoder_param.to_tensor = get_ptr(input); + encoder_param.sequence_id_offset = nullptr; + encoder_param.valid_word_num = batch_size * seq_len; + encoder_param.transformer_out = get_ptr(output); + } + + encoder_param.attr_mask = get_ptr(attr_mask); + check_cuda_error(cublasSetStream(encoder_param.cublas_handle, encoder_param.stream)); + fastertransformer::Allocator allocator; + BertEncoderTransformer* encoder = + new BertEncoderTransformer(allocator, batch_size, seq_len, seq_len, _head_num, _head_size); + encoder->initialize(encoder_param); + encoder->forward(); + delete encoder; + + if (removing_padding) { + rebuild_sequence_length_padding_kernelLauncher(encoder_param.transformer_out, get_ptr(output), + encoder_param.sequence_id_offset, encoder_param.valid_word_num, + hidden_dim, stream); + } + } + +private: + typedef BertEncoderTransformerTraits::OpType, cuda::OpenMultiHeadAttention> EncoderTraits_; + const int _head_num; + const int _head_size; + std::vector _weights; + cublasHandle_t _cublasHandle; + EncoderInitParam encoder_param; +}; + +class FasterTransformerEncoder { +public: + FasterTransformerEncoder( + int head_num, + int head_size, + bool remove_padding, + Tensor q_kernel, + Tensor q_bias, + Tensor k_kernel, + Tensor k_bias, + Tensor v_kernel, + Tensor v_bias, + Tensor attr_output_kernel, + Tensor attr_output_bias, + Tensor attr_output_layernorm_gamma, + Tensor attr_output_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias, + Tensor output_layernorm_gamma, + Tensor output_layernorm_beta); + + ~FasterTransformerEncoder(); + + Tensor forward(Tensor input, Tensor attr_mask, Tensor sequence_lengths); + +private: + const at::ScalarType _st; + bool _remove_padding; + IFTEncoder* ftencoder; +}; +} // namespace torch_ext \ No newline at end of file diff --git a/fastertransformer/th_op/encoder_ths_op.cc b/fastertransformer/th_op/encoder_ths_op.cc new file mode 100644 index 000000000..ff54efbcc --- /dev/null +++ b/fastertransformer/th_op/encoder_ths_op.cc @@ -0,0 +1,105 @@ +/* + * 
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fastertransformer/th_op/encoder_ths_op.h" + +namespace torch_ths { +using torch::Tensor; + +FasterTransformerEncoder::FasterTransformerEncoder( + int64_t head_num, + int64_t head_size, + bool remove_padding, + Tensor q_kernel, + Tensor q_bias, + Tensor k_kernel, + Tensor k_bias, + Tensor v_kernel, + Tensor v_bias, + Tensor attr_output_kernel, + Tensor attr_output_bias, + Tensor attr_output_layernorm_gamma, + Tensor attr_output_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias, + Tensor output_layernorm_gamma, + Tensor output_layernorm_beta) +: _st(q_kernel.scalar_type()), _remove_padding(remove_padding), + weights{q_kernel, q_bias, k_kernel, k_bias, v_kernel, v_bias, + attr_output_kernel, attr_output_bias, attr_output_layernorm_gamma, attr_output_layernorm_beta, + inter_kernel, inter_bias, output_kernel, output_bias, output_layernorm_gamma, output_layernorm_beta} +{ + CHECK_INPUT(q_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(q_bias, _st); // hidden_dim + CHECK_INPUT(k_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(k_bias, _st); // hidden_dim + CHECK_INPUT(v_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(v_bias, _st); // hidden_dim + CHECK_INPUT(attr_output_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(attr_output_bias, _st); // hidden_dim + CHECK_INPUT(attr_output_layernorm_gamma, _st); // hidden_dim + CHECK_INPUT(attr_output_layernorm_beta, _st); // hidden_dim + CHECK_INPUT(inter_kernel, _st); // 4 * hidden_dim, hidden_dim + CHECK_INPUT(inter_bias, _st); // 4 * hidden_dim + CHECK_INPUT(output_kernel, _st); // hidden_dim, 4 * hidden_dim + CHECK_INPUT(output_bias, _st); // hidden_dim + CHECK_INPUT(output_layernorm_gamma, _st); // hidden_dim + CHECK_INPUT(output_layernorm_beta, _st); // hidden_dim + switch (_st) { + case at::ScalarType::Float: + ftencoder = new torch_ext::FTEncoder(head_num, head_size, weights); + break; + case at::ScalarType::Half: + ftencoder = new torch_ext::FTEncoder(head_num, head_size, weights); + break; + default: + throw std::runtime_error("Wrong Tensor type."); + } + head_info = torch::empty({3}, torch::dtype(torch::kInt64)); + head_info[0] = head_num; + head_info[1] = head_size; + head_info[2] = (int64_t)remove_padding; +} + +FasterTransformerEncoder::~FasterTransformerEncoder() { + delete ftencoder; +} + +Tensor FasterTransformerEncoder::forward(Tensor input, Tensor attr_mask, Tensor sequence_lengths) { + auto input_size = input.sizes(); + int batch_size = input_size[0]; + int seq_len = input_size[1]; + CHECK_INPUT(input, _st); + CHECK_INPUT(attr_mask, _st); + if (_remove_padding) { + CHECK_CUDA(sequence_lengths); CHECK_CONTIGUOUS(sequence_lengths); + TORCH_CHECK(sequence_lengths.dtype()==torch::kInt32, "sequence_length dtype should be int32"); + TORCH_CHECK(sequence_lengths.numel()!=0, "sequence_length should not be empty tensor"); + 
TORCH_CHECK(sequence_lengths.size(0)==batch_size, "wrong sequence_length shape"); + } + auto output = torch::empty_like(input); + ftencoder->forward(batch_size, seq_len, input, attr_mask, output, sequence_lengths, _remove_padding); + return output; +} + +std::vector FasterTransformerEncoder::get_pickle_info() const { + std::vector tmp(weights); + tmp.push_back(head_info); + return tmp; +} +} // namespace torch_ths diff --git a/fastertransformer/th_op/encoder_ths_op.h b/fastertransformer/th_op/encoder_ths_op.h new file mode 100644 index 000000000..5c5b003cb --- /dev/null +++ b/fastertransformer/th_op/encoder_ths_op.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "fastertransformer/faster_transformer.h" +#include "fastertransformer/th_op/encoder_ext.h" + +namespace torch_ths { +using namespace fastertransformer; +using torch::Tensor; + +class FasterTransformerEncoder : public torch::jit::CustomClassHolder { +public: + FasterTransformerEncoder( + int64_t head_num, + int64_t head_size, + bool remove_padding, + Tensor q_kernel, + Tensor q_bias, + Tensor k_kernel, + Tensor k_bias, + Tensor v_kernel, + Tensor v_bias, + Tensor attr_output_kernel, + Tensor attr_output_bias, + Tensor attr_output_layernorm_gamma, + Tensor attr_output_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias, + Tensor output_layernorm_gamma, + Tensor output_layernorm_beta); + + ~FasterTransformerEncoder(); + + Tensor forward(Tensor input, Tensor attr_mask, Tensor sequence_lengths); + + std::vector get_pickle_info() const; + +private: + const at::ScalarType _st; + bool _remove_padding; + torch_ext::IFTEncoder* ftencoder; + Tensor head_info; + std::vector weights; +}; +} // namespace torch_ths \ No newline at end of file diff --git a/fastertransformer/th_op/encoder_ths_op_f.cc b/fastertransformer/th_op/encoder_ths_op_f.cc new file mode 100644 index 000000000..2a8be29e0 --- /dev/null +++ b/fastertransformer/th_op/encoder_ths_op_f.cc @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "fastertransformer/th_op/encoder_ths_op_f.h" + +namespace torch_ths { +using torch::Tensor; + +Tensor fastertransformerthsencoder( + int64_t head_num, + int64_t head_size, + bool remove_padding, + Tensor q_kernel, + Tensor q_bias, + Tensor k_kernel, + Tensor k_bias, + Tensor v_kernel, + Tensor v_bias, + Tensor attr_output_kernel, + Tensor attr_output_bias, + Tensor attr_output_layernorm_gamma, + Tensor attr_output_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias, + Tensor output_layernorm_gamma, + Tensor output_layernorm_beta, + Tensor input, + Tensor attr_mask, + Tensor sequence_lengths) +{ + const at::ScalarType _st = q_kernel.scalar_type(); + CHECK_INPUT(q_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(q_bias, _st); // hidden_dim + CHECK_INPUT(k_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(k_bias, _st); // hidden_dim + CHECK_INPUT(v_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(v_bias, _st); // hidden_dim + CHECK_INPUT(attr_output_kernel, _st); // hidden_dim, hidden_dim + CHECK_INPUT(attr_output_bias, _st); // hidden_dim + CHECK_INPUT(attr_output_layernorm_gamma, _st); // hidden_dim + CHECK_INPUT(attr_output_layernorm_beta, _st); // hidden_dim + CHECK_INPUT(inter_kernel, _st); // 4 * hidden_dim, hidden_dim + CHECK_INPUT(inter_bias, _st); // 4 * hidden_dim + CHECK_INPUT(output_kernel, _st); // hidden_dim, 4 * hidden_dim + CHECK_INPUT(output_bias, _st); // hidden_dim + CHECK_INPUT(output_layernorm_gamma, _st); // hidden_dim + CHECK_INPUT(output_layernorm_beta, _st); // hidden_dim + CHECK_INPUT(input, _st); + CHECK_INPUT(attr_mask, _st); + auto input_size = input.sizes(); + int batch_size = input_size[0]; + int seq_len = input_size[1]; + if (remove_padding) { + CHECK_CUDA(sequence_lengths); CHECK_CONTIGUOUS(sequence_lengths); + TORCH_CHECK(sequence_lengths.dtype()==torch::kInt32, "sequence_length dtype should be int32"); + TORCH_CHECK(sequence_lengths.numel()!=0, "sequence_length should not be empty tensor"); + TORCH_CHECK(sequence_lengths.size(0)==batch_size, "wrong sequence_length shape"); + } + std::vector weights{q_kernel, q_bias, k_kernel, k_bias, v_kernel, v_bias, + attr_output_kernel, attr_output_bias, attr_output_layernorm_gamma, attr_output_layernorm_beta, + inter_kernel, inter_bias, output_kernel, output_bias, output_layernorm_gamma, output_layernorm_beta}; + auto output = torch::empty_like(input); + switch (_st) { + case at::ScalarType::Float: + ftencoder(head_num, head_size, weights, batch_size, seq_len, remove_padding, input, attr_mask, sequence_lengths, output); + break; + case at::ScalarType::Half: + ftencoder(head_num, head_size, weights, batch_size, seq_len, remove_padding, input, attr_mask, sequence_lengths, output); + break; + default: + throw std::runtime_error("Wrong Tensor type."); + } + return output; +} +} // namespace torch_ths diff --git a/fastertransformer/th_op/encoder_ths_op_f.h b/fastertransformer/th_op/encoder_ths_op_f.h new file mode 100644 index 000000000..b5e6e4386 --- /dev/null +++ b/fastertransformer/th_op/encoder_ths_op_f.h @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include +#include "torch/csrc/cuda/Stream.h" +#include "ATen/cuda/CUDAContext.h" + +#include "fastertransformer/faster_transformer.h" +#include "fastertransformer/cuda/cuda_kernels.h" +#include "fastertransformer/th_op/th_traits.h" +#include "fastertransformer/th_op/utils.h" + +namespace torch_ths { +using namespace fastertransformer; +using torch::Tensor; +using torch_ext::THTraits; +using torch_ext::get_ptr; + +template +void ftencoder(int head_num, int head_size, std::vector& weights, + int batch_size, int seq_len, bool removing_padding, + Tensor& input, Tensor& attr_mask, Tensor& sequence_lengths, Tensor& output) { + typedef BertEncoderTransformerTraits::OpType, cuda::OpenMultiHeadAttention> EncoderTraits_; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + cublasHandle_t cublasHandle = at::cuda::getCurrentCUDABlasHandle(); + EncoderInitParam encoder_param; + int hidden_dim = head_num * head_size; + encoder_param.self_attention.query_weight.kernel = get_ptr(weights[0]); + encoder_param.self_attention.query_weight.bias = get_ptr(weights[1]); + encoder_param.self_attention.key_weight.kernel = get_ptr(weights[2]); + encoder_param.self_attention.key_weight.bias = get_ptr(weights[3]); + encoder_param.self_attention.value_weight.kernel = get_ptr(weights[4]); + encoder_param.self_attention.value_weight.bias = get_ptr(weights[5]); + encoder_param.self_attention.attention_output_weight.kernel = get_ptr(weights[6]); + encoder_param.self_attention.attention_output_weight.bias = get_ptr(weights[7]); + encoder_param.self_layernorm.gamma = get_ptr(weights[8]); + encoder_param.self_layernorm.beta = get_ptr(weights[9]); + encoder_param.ffn.intermediate_weight.kernel = get_ptr(weights[10]); + encoder_param.ffn.intermediate_weight.bias = get_ptr(weights[11]); + encoder_param.ffn.output_weight.kernel = get_ptr(weights[12]); + encoder_param.ffn.output_weight.bias = get_ptr(weights[13]); + encoder_param.ffn_layernorm.gamma = get_ptr(weights[14]); + encoder_param.ffn_layernorm.beta = get_ptr(weights[15]); + encoder_param.attr_mask = get_ptr(attr_mask); + encoder_param.stream = stream; + encoder_param.cublas_handle = cublasHandle; + + std::vector buf_vector; + if (removing_padding) { + const T* input_ptr = get_ptr(input); + const int* sequence_lengths_ptr = get_ptr(sequence_lengths); + auto buf = torch::empty({batch_size * seq_len + 1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + int* tmp_sequence_id_offset = get_ptr(buf); + int* d_valid_word_num = tmp_sequence_id_offset + batch_size * seq_len; + build_sequence_length_padding_offset_kernelLauncher(sequence_lengths_ptr, batch_size, seq_len, + d_valid_word_num, tmp_sequence_id_offset, stream); + int* h_valid_word_num = new int[1]; + cudaMemcpyAsync(h_valid_word_num, d_valid_word_num, sizeof(int), cudaMemcpyDeviceToHost, stream); + const int valid_word_num = h_valid_word_num[0]; + delete h_valid_word_num; + auto intermediate_input = + torch::empty({valid_word_num, hidden_dim}, torch::dtype(input.dtype()).device(torch::kCUDA).requires_grad(false)); + 
buf_vector.push_back(intermediate_input); + T* intermediate_input_ptr = get_ptr(intermediate_input); + auto sequence_id_offset = + torch::empty({valid_word_num}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + buf_vector.push_back(sequence_id_offset); + int* sequence_id_offset_ptr = get_ptr(sequence_id_offset); + remove_sequence_length_padding_kernelLauncher(input_ptr, intermediate_input_ptr, + tmp_sequence_id_offset, sequence_id_offset_ptr, + valid_word_num, hidden_dim, stream); + auto intermediate_output = torch::empty_like(intermediate_input); + buf_vector.push_back(intermediate_output); + encoder_param.from_tensor = intermediate_input_ptr; + encoder_param.to_tensor = intermediate_input_ptr; + encoder_param.sequence_id_offset = sequence_id_offset_ptr; + encoder_param.valid_word_num = valid_word_num; + encoder_param.transformer_out = get_ptr(intermediate_output); + } else { + encoder_param.from_tensor = get_ptr(input); + encoder_param.to_tensor = get_ptr(input); + encoder_param.sequence_id_offset = nullptr; + encoder_param.valid_word_num = batch_size * seq_len; + encoder_param.transformer_out = get_ptr(output); + } + + fastertransformer::Allocator allocator; + BertEncoderTransformer* encoder = + new BertEncoderTransformer(allocator, batch_size, seq_len, seq_len, head_num, head_size); + encoder->initialize(encoder_param); + encoder->forward(); + delete encoder; + + if (removing_padding) { + rebuild_sequence_length_padding_kernelLauncher(encoder_param.transformer_out, get_ptr(output), + encoder_param.sequence_id_offset, encoder_param.valid_word_num, + hidden_dim, stream); + } +} + +Tensor fastertransformerthsencoder( + int64_t head_num, + int64_t head_size, + bool remove_padding, + Tensor q_kernel, + Tensor q_bias, + Tensor k_kernel, + Tensor k_bias, + Tensor v_kernel, + Tensor v_bias, + Tensor attr_output_kernel, + Tensor attr_output_bias, + Tensor attr_output_layernorm_gamma, + Tensor attr_output_layernorm_beta, + Tensor inter_kernel, + Tensor inter_bias, + Tensor output_kernel, + Tensor output_bias, + Tensor output_layernorm_gamma, + Tensor output_layernorm_beta, + Tensor input, + Tensor attr_mask, + Tensor sequence_lengths); + +} // namespace torch_ths \ No newline at end of file diff --git a/fastertransformer/th_op/ft_ext.cc b/fastertransformer/th_op/ft_ext.cc new file mode 100644 index 000000000..8d2fee094 --- /dev/null +++ b/fastertransformer/th_op/ft_ext.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
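The remove-padding branch above is the Effective Transformer idea: only the valid (non-padded) tokens are packed into a compact [valid_word_num, hidden_dim] buffer before the encoder runs, and the result is scattered back to the padded layout afterwards. The patch does this with the build_sequence_length_padding_offset, remove_sequence_length_padding and rebuild_sequence_length_padding kernels; the sketch below is only a conceptual pure-PyTorch equivalent (the kernels track padding offsets rather than absolute indices).

import torch

def remove_padding(x, seq_lens):
    # x: [batch, seq_len, hidden]; keep only the first seq_lens[b] tokens per sequence
    batch, seq_len, hidden = x.shape
    pos = torch.arange(seq_len, device=x.device).unsqueeze(0)      # [1, seq_len]
    keep = pos < seq_lens.unsqueeze(1)                             # [batch, seq_len]
    word_idx = keep.flatten().nonzero(as_tuple=False).squeeze(1)   # valid token positions
    compact = x.reshape(-1, hidden).index_select(0, word_idx)      # [valid_word_num, hidden]
    return compact, word_idx

def rebuild_padding(compact, word_idx, batch, seq_len):
    hidden = compact.shape[1]
    out = torch.zeros(batch * seq_len, hidden, dtype=compact.dtype, device=compact.device)
    out.index_copy_(0, word_idx, compact)                          # scatter back, pads stay zero
    return out.reshape(batch, seq_len, hidden)

Here valid_word_num corresponds to word_idx.numel(), the value the CUDA path copies back from d_valid_word_num.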
+ */ + +#include "torch/extension.h" + +#include "fastertransformer/th_op/encoder_ext.h" +#include "fastertransformer/th_op/decoder_ext.h" +#include "fastertransformer/th_op/decoding_ext.h" + +using torch::Tensor; +namespace py = pybind11; + +PYBIND11_MODULE(th_fastertransformer, m) { + py::class_(m, "FasterTransformerEncoder") + .def(py::init()) + .def("forward", &torch_ext::FasterTransformerEncoder::forward); + + py::class_(m, "FasterTransformerDecoder") + .def(py::init()) + .def("forward", &torch_ext::FasterTransformerDecoder::forward); + + py::class_(m, "FasterTransformerDecoding") + .def(py::init()) + .def("forward", &torch_ext::FasterTransformerDecoding::forward); + + m.def("gather_tree", &torch_ext::gather_tree); +} diff --git a/fastertransformer/th_op/ft_ths_op.cc b/fastertransformer/th_op/ft_ths_op.cc new file mode 100644 index 000000000..afeb998e6 --- /dev/null +++ b/fastertransformer/th_op/ft_ths_op.cc @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "fastertransformer/th_op/encoder_ths_op.h" +#include "fastertransformer/th_op/decoder_ths_op.h" +#include "fastertransformer/th_op/decoding_ths_op.h" + +using torch::Tensor; + +static auto fasterTransformerEncoderTHS = + torch::jit::class_("FasterTransformerEncoder") + .def(torch::jit::init()) + .def("forward", &torch_ths::FasterTransformerEncoder::forward) + .def_pickle( + [](const c10::intrusive_ptr& self) -> std::vector { + return self->get_pickle_info(); + }, + [](std::vector state) -> c10::intrusive_ptr { + int head_num = state[16][0].item().to(); + int head_size = state[16][1].item().to(); + bool remove_padding = (bool)(state[16][2].item().to()); + return c10::make_intrusive(head_num, head_size, remove_padding, + state[0], state[1], state[2], state[3], state[4], state[5], + state[6], state[7], state[8], state[9], state[10], state[11], + state[12], state[13], state[14], state[15]); + } + ); + +static auto fasterTransformerDecoderTHS = + torch::jit::class_("FasterTransformerDecoder") + .def(torch::jit::init()) + .def("forward", &torch_ths::FasterTransformerDecoder::forward) + .def_pickle( + [](const c10::intrusive_ptr& self) -> std::vector { + return self->get_pickle_info(); + }, + [](std::vector state) -> c10::intrusive_ptr { + int head_num = state[26][0].item().to(); + int head_size = state[26][1].item().to(); + return c10::make_intrusive(head_num, head_size, + state[0], state[1], state[2], state[3], state[4], state[5], state[6], state[7], + state[8], state[9], state[10], state[11], state[12], state[13], state[14], state[15], + state[16], state[17], state[18], state[19], state[20], state[21], state[22], state[23], + state[24], state[25]); + } + ); + +static auto fasterTransformerDecodingTHS = + torch::jit::class_("FasterTransformerDecoding") + .def(torch::jit::init()) + .def("forward", &torch_ths::FasterTransformerDecoding::forward) + .def_pickle( + [](const c10::intrusive_ptr& self) -> 
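For the eager (non-TorchScript) build, the module name th_fastertransformer comes from the PYBIND11_MODULE above. The sketch below assumes the bound FasterTransformerEncoder constructor mirrors the TorchScript class (head_num, head_size, remove_padding plus the sixteen weights), since the bound init signature is not reproduced in this hunk, and it reuses the tensors from the earlier sketch; treat it as illustrative only.

import torch                        # torch must be imported before the extension
import th_fastertransformer as ft   # assumes the built .so is importable from the working directory

encoder = ft.FasterTransformerEncoder(head_num, head_size, False, *weights)
out = encoder.forward(inp, attr_mask, seq_lens)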
std::vector { + return self->get_pickle_info(); + }, + [](std::vector state) -> c10::intrusive_ptr { + int head_num = state[32][0].item().to(); + int head_size = state[32][1].item().to(); + int mem_hidden_dim = state[32][2].item().to(); + int layer_num = state[32][3].item().to(); + int vocab_size = state[32][4].item().to(); + int start_id = state[32][5].item().to(); + int end_id = state[32][6].item().to(); + double beam_search_diversity_rate = state[31][0].item().to(); + return c10::make_intrusive(head_num, head_size, + mem_hidden_dim, layer_num, vocab_size, start_id, end_id, beam_search_diversity_rate, + state[0], state[1], state[2], state[3], state[4], state[5], state[6], state[7], + state[8], state[9], state[10], state[11], state[12], state[13], state[14], state[15], + state[16], state[17], state[18], state[19], state[20], state[21], state[22], state[23], + state[24], state[25], state[26], state[27], state[28], state[29], state[30], state[31]); + } + ); + +static auto gather_tree = + torch::RegisterOperators("fastertransformer::gather_tree", &torch_ths::gather_tree); diff --git a/fastertransformer/th_op/ft_ths_op_f.cc b/fastertransformer/th_op/ft_ths_op_f.cc new file mode 100644 index 000000000..deb28c62f --- /dev/null +++ b/fastertransformer/th_op/ft_ths_op_f.cc @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "fastertransformer/th_op/encoder_ths_op_f.h" + +using torch::Tensor; + +static auto registry = + torch::RegisterOperators("fastertransformer::encoder", &torch_ths::fastertransformerthsencoder); diff --git a/fastertransformer/th_op/th_traits.h b/fastertransformer/th_op/th_traits.h new file mode 100644 index 000000000..331bfba08 --- /dev/null +++ b/fastertransformer/th_op/th_traits.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
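Besides the custom classes, the patch also registers a stateless TorchScript operator named fastertransformer::encoder (in ft_ths_op_f.cc above), and the def_pickle hooks let torch.jit.save and torch.jit.load round-trip scripted modules that hold the classes. A minimal sketch of calling the operator, assuming the same library path and tensors as the earlier sketches; the argument order follows the fastertransformerthsencoder signature.

import torch

torch.ops.load_library('./lib/libths_fastertransformer.so')   # assumed build output

out = torch.ops.fastertransformer.encoder(
    head_num, head_size, False,    # remove_padding
    *weights,                      # the sixteen weight tensors, in the order listed above
    inp, attr_mask, seq_lens)

The beam reordering helper registered as fastertransformer::gather_tree is reachable the same way through torch.ops once the library is loaded.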
+ */ +#pragma once + +#ifndef TORCH_TRAITS_H_ +#define TORCH_TRAITS_H_ + +#include + +using namespace fastertransformer; +namespace torch_ext +{ + template class THTraits; + + template <> + class THTraits + { + public: + static const OperationType OpType = OperationType::FP32; + }; + + template <> + class THTraits + { + public: + static const OperationType OpType = OperationType::FP16; + }; + +} //namespace torch_ext +#endif diff --git a/fastertransformer/th_op/utils.cu b/fastertransformer/th_op/utils.cu new file mode 100644 index 000000000..5fe0ac949 --- /dev/null +++ b/fastertransformer/th_op/utils.cu @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fastertransformer/th_op/utils.h" + +namespace torch_ext { + +// modified from TensorFlow's implementation of tf.contrib.seq2seq.gather_tree +__global__ void gather_tree_kernel(const int batch_size, const int max_time, const int beam_width, const int end_token, + const int* step_ids, const int* parent_ids, const int* max_sequence_lengths, int* beams) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size * beam_width; i += gridDim.x * blockDim.x) { + const int batch = i / beam_width; + const int beam = i % beam_width; + + const int max_seq_len_b = min(max_time, __ldg(max_sequence_lengths + batch)); + if (max_seq_len_b <= 0) { + continue; + } + +#define GET_IX(time_ix, beam_ix) (batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix)) + + const int initial_beam_ix = GET_IX(max_seq_len_b - 1, beam); + beams[initial_beam_ix] = __ldg(step_ids + initial_beam_ix); + int parent = __ldg(parent_ids + initial_beam_ix); + bool found_bad = false; + for (int level = max_seq_len_b - 2; level >= 0; --level) { + const int level_beam_ix = GET_IX(level, beam); + const int level_parent_ix = GET_IX(level, parent); + if (parent < 0 || parent > beam_width) { + beams[level_beam_ix] = -1; + parent = -1; + found_bad = true; + } else { + beams[level_beam_ix] = __ldg(step_ids + level_parent_ix); + parent = __ldg(parent_ids + level_parent_ix); + } + } +// Not necessary when using a BeamSearchDecoder, but necessary +// when a user feeds in possibly broken trajectory (i.e., non-eos +// entries in a beam following eos entries). 
+ if (!found_bad) { + bool finished = false; + for (int time = 0; time < max_seq_len_b; ++time) { + const int level_beam_ix = GET_IX(time, beam); + if (finished) { + beams[level_beam_ix] = end_token; + } else if (beams[level_beam_ix] == end_token) { + finished = true; + } + } + } +#undef GET_IX + } +} + + +void gather_tree_kernel_launcher(int max_time, int batch_size, int beam_width, + int* step_ids, int* parent_ids, int* max_sequence_lengths, + int end_token, int* beams, cudaStream_t stream) { + int batchbeam = batch_size * beam_width; + dim3 grid(1), block(batchbeam); + // though decoder do not support > 1024 for now + if (batchbeam > 1024) { + grid.x = ceil(batch_size * beam_width / 1024.); + block.x = 1024; + } + gather_tree_kernel<<>>(batch_size, max_time, beam_width, end_token, + step_ids, parent_ids, max_sequence_lengths, beams); +} +} // namespace torch_ext diff --git a/fastertransformer/th_op/utils.h b/fastertransformer/th_op/utils.h new file mode 100644 index 000000000..16d4ca02f --- /dev/null +++ b/fastertransformer/th_op/utils.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include "torch/extension.h" + +#define CHECK_TYPE(x, st) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type: " #x) +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x, st) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_TYPE(x, st) +#define PRINT_TENSOR(x) std::cout << #x << ":\n" << x << std::endl +#define PRINT_TENSOR_SIZE(x) std::cout << "size of " << #x << ": " << x.sizes() << std::endl + +namespace torch_ext { + +template +void print_ptr(T* p, int r, int c); + +template +inline T* get_ptr(torch::Tensor& t) { + return reinterpret_cast(t.data_ptr()); +} + +void gather_tree_kernel_launcher(int max_time, int batch_size, int beam_width, + int* step_ids, int* parent_ids, int* max_sequence_lengths, + int end_token, int* beams, cudaStream_t stream); + +} // namespace torch_ext diff --git a/fastertransformer/trt_plugin/bert_transformer_plugin.h b/fastertransformer/trt_plugin/bert_transformer_plugin.h index 16c30c42a..c04b24aef 100644 --- a/fastertransformer/trt_plugin/bert_transformer_plugin.h +++ b/fastertransformer/trt_plugin/bert_transformer_plugin.h @@ -145,7 +145,7 @@ class TransformerPlugin: public IPluginV2 encoder_param.ffn_layernorm.beta = d_output_layernorm_beta_; encoder_param.ffn_layernorm.gamma = d_output_layernorm_gamma_; encoder_param.cublas_handle = cublas_handle_; - + encoder_transformer_->initialize(encoder_param); } catch(std::runtime_error& error) @@ -201,7 +201,7 @@ class TransformerPlugin: public IPluginV2 BertEncoderTransformer(*allocator_, max_batch_size, seq_len, seq_len, head_num, hidden_dim / head_num); EncoderInitParam encoder_param; //init param here - + 
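For intuition, gather_tree walks each beam backwards from its last valid step, following parent_ids to recover the token actually chosen at every level, then pads everything after the first end_token. Below is a rough reference in plain PyTorch using the same [max_time, batch, beam] layout; it is a sketch for understanding only, the CUDA kernel above (which follows TensorFlow's bounds check) is the real implementation.

import torch

def gather_tree_ref(step_ids, parent_ids, max_seq_lens, end_token):
    # step_ids, parent_ids: int tensors of shape [max_time, batch, beam]; max_seq_lens: [batch]
    max_time, batch, beam = step_ids.shape
    beams = torch.full_like(step_ids, end_token)
    for b in range(batch):
        valid = min(max_time, int(max_seq_lens[b]))
        if valid <= 0:
            continue
        for k in range(beam):
            beams[valid - 1, b, k] = step_ids[valid - 1, b, k]
            parent = int(parent_ids[valid - 1, b, k])
            found_bad = False
            for t in range(valid - 2, -1, -1):
                if parent < 0 or parent >= beam:      # broken trajectory
                    beams[t, b, k] = -1
                    parent = -1
                    found_bad = True
                else:
                    beams[t, b, k] = step_ids[t, b, parent]
                    parent = int(parent_ids[t, b, parent])
            if not found_bad:
                finished = False
                for t in range(valid):
                    if finished:
                        beams[t, b, k] = end_token
                    elif int(beams[t, b, k]) == end_token:
                        finished = True
    return beams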
encoder_param.self_attention.query_weight.kernel = d_attr_kernel_Q_; encoder_param.self_attention.key_weight.kernel = d_attr_kernel_K_; encoder_param.self_attention.value_weight.kernel = d_attr_kernel_V_; @@ -269,13 +269,13 @@ class TransformerPlugin: public IPluginV2 bool supportsFormat(nvinfer1::DataType type, PluginFormat format) const override { - return type == nvinfer1::DataType::kFLOAT && format == PluginFormat::kNCHW; + return type == TransformerTrtTraits::DataType && format == PluginFormat::kNCHW; } void configureWithFormat(const Dims* pInputDim, int nInputDim, const Dims* pOutputDim, int nOutputDim, nvinfer1::DataType dataType, nvinfer1::PluginFormat pluginFormat, int maxBatchSize) override { - assert(dataType == nvinfer1::DataType::kFLOAT && pluginFormat == nvinfer1::PluginFormat::kNCHW); + assert(dataType == TransformerTrtTraits::DataType && pluginFormat == nvinfer1::PluginFormat::kNCHW); assert(nInputDim == 2); assert(pInputDim[0].nbDims == 2 && pInputDim[0].d[0] == seq_len_ && pInputDim[0].d[1] == hidden_dim_); assert(pInputDim[1].nbDims == 2 && pInputDim[1].d[0] == seq_len_ && pInputDim[1].d[1] == seq_len_); diff --git a/fastertransformer/trt_plugin/trt_model.h b/fastertransformer/trt_plugin/trt_model.h index ea28972d9..a2cfdfcf9 100644 --- a/fastertransformer/trt_plugin/trt_model.h +++ b/fastertransformer/trt_plugin/trt_model.h @@ -104,7 +104,7 @@ class TRT_Transformer builder->setMaxBatchSize(batch_size_); builder->setMaxWorkspaceSize(1 << 20); - builder->setFp16Mode(false); + builder->setFp16Mode(sizeof(T) == 2); engine_ = builder->buildCudaEngine(*network); assert(engine_); diff --git a/images/effective_transformer.png b/images/effective_transformer.png new file mode 100644 index 000000000..0b0920534 Binary files /dev/null and b/images/effective_transformer.png differ diff --git a/images/encoder-decoding-2.png b/images/encoder-decoding-2.png new file mode 100644 index 000000000..eb10ca66c Binary files /dev/null and b/images/encoder-decoding-2.png differ diff --git a/sample/cpp/CMakeLists.txt b/sample/cpp/CMakeLists.txt index 8ae136431..80cdf3ba0 100644 --- a/sample/cpp/CMakeLists.txt +++ b/sample/cpp/CMakeLists.txt @@ -17,12 +17,19 @@ set(encoder_sample_files encoder_sample.cc ) -set(decoding_sample_files - decoding_sample.cc +set(decoding_beamsearch_sample_files + decoding_beamsearch_sample.cc +) + +set(decoding_sampling_sample_files + decoding_sampling_sample.cc ) add_executable(encoder_sample ${encoder_sample_files}) -target_link_libraries(encoder_sample PUBLIC -lcublas -lcudart fastertransformer) +target_link_libraries(encoder_sample PUBLIC -lcublas -lcudart encoder) + +add_executable(decoding_beamsearch_sample ${decoding_beamsearch_sample_files}) +target_link_libraries(decoding_beamsearch_sample PUBLIC -lcublas -lcudart decoder decoding) -add_executable(decoding_sample ${decoding_sample_files}) -target_link_libraries(decoding_sample PUBLIC -lcublas -lcudart fastertransformer) \ No newline at end of file +add_executable(decoding_sampling_sample ${decoding_sampling_sample_files}) +target_link_libraries(decoding_sampling_sample PUBLIC -lcublas -lcudart -lcurand decoder decoding) \ No newline at end of file diff --git a/sample/cpp/decoding_sample.cc b/sample/cpp/decoding_beamsearch_sample.cc similarity index 72% rename from sample/cpp/decoding_sample.cc rename to sample/cpp/decoding_beamsearch_sample.cc index d91ac3ee4..52d740430 100644 --- a/sample/cpp/decoding_sample.cc +++ b/sample/cpp/decoding_beamsearch_sample.cc @@ -15,7 +15,7 @@ */ #include 
"fastertransformer/open_decoder.h" -#include "fastertransformer/decoding_opennmt.h" +#include "fastertransformer/decoding_beamsearch.h" #include #include #include @@ -24,6 +24,8 @@ #include #include +#include + using namespace fastertransformer; template @@ -125,35 +127,35 @@ void decoding_sample(int batch_size, T *d_cross_gamma, *d_cross_beta; T *d_ffn_gamma, *d_ffn_beta; - device_malloc(&d_self_Q_kernel, sizeof(T) * hidden_units * hidden_units); - device_malloc(&d_self_K_kernel, sizeof(T) * hidden_units * hidden_units); - device_malloc(&d_self_V_kernel, sizeof(T) * hidden_units * hidden_units); - device_malloc(&d_self_output_kernel, sizeof(T) * hidden_units * hidden_units); - device_malloc(&d_self_Q_bias, sizeof(T) * hidden_units); - device_malloc(&d_self_K_bias, sizeof(T) * hidden_units); - device_malloc(&d_self_V_bias, sizeof(T) * hidden_units); - device_malloc(&d_self_output_bias, sizeof(T) * hidden_units); + device_malloc(&d_self_Q_kernel, hidden_units * hidden_units); + device_malloc(&d_self_K_kernel, hidden_units * hidden_units); + device_malloc(&d_self_V_kernel, hidden_units * hidden_units); + device_malloc(&d_self_output_kernel, hidden_units * hidden_units); + device_malloc(&d_self_Q_bias, hidden_units); + device_malloc(&d_self_K_bias, hidden_units); + device_malloc(&d_self_V_bias, hidden_units); + device_malloc(&d_self_output_bias, hidden_units); - device_malloc(&d_cross_Q_kernel, sizeof(T) * hidden_units * hidden_units); - device_malloc(&d_cross_K_kernel, sizeof(T) * memory_hidden_units * hidden_units); - device_malloc(&d_cross_V_kernel, sizeof(T) * memory_hidden_units * hidden_units); - device_malloc(&d_cross_output_kernel, sizeof(T) * hidden_units * hidden_units); - device_malloc(&d_cross_Q_bias, sizeof(T) * hidden_units); - device_malloc(&d_cross_K_bias, sizeof(T) * hidden_units); - device_malloc(&d_cross_V_bias, sizeof(T) * hidden_units); - device_malloc(&d_cross_output_bias, sizeof(T) * hidden_units); + device_malloc(&d_cross_Q_kernel, hidden_units * hidden_units); + device_malloc(&d_cross_K_kernel, memory_hidden_units * hidden_units); + device_malloc(&d_cross_V_kernel, memory_hidden_units * hidden_units); + device_malloc(&d_cross_output_kernel, hidden_units * hidden_units); + device_malloc(&d_cross_Q_bias, hidden_units); + device_malloc(&d_cross_K_bias, hidden_units); + device_malloc(&d_cross_V_bias, hidden_units); + device_malloc(&d_cross_output_bias, hidden_units); - device_malloc(&d_ffn_bias1, sizeof(T) * inner_size); - device_malloc(&d_ffn_kernel1, sizeof(T) * inner_size * hidden_units); - device_malloc(&d_ffn_bias2, sizeof(T) * hidden_units); - device_malloc(&d_ffn_kernel2, sizeof(T) * inner_size * hidden_units); + device_malloc(&d_ffn_bias1, inner_size); + device_malloc(&d_ffn_kernel1, inner_size * hidden_units); + device_malloc(&d_ffn_bias2, hidden_units); + device_malloc(&d_ffn_kernel2, inner_size * hidden_units); - device_malloc(&d_self_gamma, sizeof(T) * hidden_units); - device_malloc(&d_self_beta, sizeof(T) * hidden_units); - device_malloc(&d_cross_gamma, sizeof(T) * hidden_units); - device_malloc(&d_cross_beta, sizeof(T) * hidden_units); - device_malloc(&d_ffn_gamma, sizeof(T) * hidden_units); - device_malloc(&d_ffn_beta, sizeof(T) * hidden_units); + device_malloc(&d_self_gamma, hidden_units); + device_malloc(&d_self_beta, hidden_units); + device_malloc(&d_cross_gamma, hidden_units); + device_malloc(&d_cross_beta, hidden_units); + device_malloc(&d_ffn_gamma, hidden_units); + device_malloc(&d_ffn_beta, hidden_units); param[i].self_attention.query_weight.kernel 
= d_self_Q_kernel; param[i].self_attention.key_weight.kernel = d_self_K_kernel; @@ -191,25 +193,27 @@ void decoding_sample(int batch_size, T *d_embedding_table; T* d_embedding_kernel; float* d_embedding_bias; + T* d_position_encoding_table; int* d_output_ids; int* d_parent_ids; int* d_sequence_lengths; int* d_memory_sequence_lengths; T *d_gamma, *d_beta; - device_malloc(&d_memory_tensor, sizeof(T) * hidden_units * seq_len * batch_size * beam_width); - device_malloc(&d_embedding_table, sizeof(T) * hidden_units * vocab_size); - device_malloc(&d_embedding_kernel, sizeof(T) * vocab_size * hidden_units); - check_cuda_error(cudaMalloc((void**)&d_embedding_bias, sizeof(float) * vocab_size)); + device_malloc(&d_memory_tensor, memory_hidden_units * memory_seq_len * batch_size * beam_width); + device_malloc(&d_embedding_table, hidden_units * vocab_size); + device_malloc(&d_embedding_kernel, vocab_size * hidden_units); + device_malloc(&d_embedding_bias, vocab_size); + device_malloc(&d_position_encoding_table, max_seq_len * hidden_units); check_cuda_error(cudaMalloc((void**)&d_output_ids, sizeof(int) * (max_seq_len) * batch_size * beam_width)); check_cuda_error(cudaMalloc((void**)&d_parent_ids, sizeof(int) * (max_seq_len) * batch_size * beam_width)); check_cuda_error(cudaMalloc((void**)&d_sequence_lengths, sizeof(int) * batch_size * beam_width)); check_cuda_error(cudaMalloc((void**)&d_memory_sequence_lengths, sizeof(int) * batch_size * beam_width)); - device_malloc(&d_gamma, sizeof(T) * hidden_units); - device_malloc(&d_beta, sizeof(T) * hidden_units); + device_malloc(&d_gamma, hidden_units); + device_malloc(&d_beta, hidden_units); int *h_memory_sequence_lengths = new int[batch_size * beam_width]; - for(int i = 0; i < batch_size * beam_width; i++) h_memory_sequence_lengths[i] = seq_len; + for(int i = 0; i < batch_size * beam_width; i++) h_memory_sequence_lengths[i] = memory_seq_len; check_cuda_error(cudaMemcpy(d_memory_sequence_lengths, h_memory_sequence_lengths, sizeof(int) * batch_size * beam_width, cudaMemcpyHostToDevice)); decoding_params.cublas_handle = cublasHandle; @@ -218,6 +222,7 @@ void decoding_sample(int batch_size, decoding_params.embedding_table = d_embedding_table; decoding_params.embedding_kernel = d_embedding_kernel; decoding_params.embedding_bias = d_embedding_bias; + decoding_params.position_encoding_table = d_position_encoding_table; decoding_params.output_ids = d_output_ids; decoding_params.parent_ids = d_parent_ids; decoding_params.sequence_length = d_sequence_lengths; @@ -227,15 +232,15 @@ void decoding_sample(int batch_size, const fastertransformer::OperationType type = sizeof(T) == sizeof(float) ? 
OperationType::FP32 : OperationType::FP16; - DecodingOpenNMT *decoding = new - DecodingOpenNMT(allocator, batch_size, beam_width, + DecodingBeamsearch *decoding = new + DecodingBeamsearch(allocator, batch_size, beam_width, max_seq_len, head_num, size_per_head, vocab_size, decoder_layers, memory_hidden_units, memory_seq_len, start_id, end_id); //warm up - int ite = 100; + int ite = 50; for(int i = 0; i < ite; ++i) decoding->forward(param, decoding_params); @@ -248,11 +253,36 @@ void decoding_sample(int batch_size, cudaDeviceSynchronize(); gettimeofday(&end, NULL); - printf("[batch_size %d beam_width %d head_num %d size_per_head %d seq_len %d decoder_layers %d vocab_size %d] costs %.2f ms\n", + + printf("[INFO] batch_size %d beam_width %d head_num %d size_per_head %d seq_len %d" \ + " decoder_layers %d vocab_size %d FT-CPP-decoding-beamsearch-time %.2f ms\n", batch_size, beam_width, head_num, size_per_head, seq_len, decoder_layers, vocab_size, ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); - printf("done\n"); + // std::string fName = "out"; + // auto outFile = std::ofstream(fName, std::ios::out); + + // size_t outCount = max_seq_len * batch_size * beam_width; + // int *hBuf = new int[outCount]; + // cudaDeviceSynchronize(); + // cudaMemcpy(hBuf, d_output_ids, outCount * sizeof(int), cudaMemcpyDeviceToHost); + + // { + // std::cout << "Writing " << outCount << " elements\n"; + // int zerroCount = 0; + // //outFile.precision(5); + // //outFile << std::fixed << std::scientific; + // for (size_t i = 0; i < outCount; i++) + // { + // if (hBuf[i] == int(0)) zerroCount++; + // outFile << hBuf[i] << std::endl; + // } + // std::cout << "zerroCount = " << zerroCount << std::endl; + // } + // delete [] hBuf; + + delete [] param; + delete [] h_memory_sequence_lengths; delete decoding; return ; } \ No newline at end of file diff --git a/sample/cpp/decoding_sampling_sample.cc b/sample/cpp/decoding_sampling_sample.cc new file mode 100644 index 000000000..7b36c6741 --- /dev/null +++ b/sample/cpp/decoding_sampling_sample.cc @@ -0,0 +1,267 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "fastertransformer/open_decoder.h" +#include "fastertransformer/decoding_sampling.h" +#include +#include +#include +#include +#include +#include +#include + +using namespace fastertransformer; + +template +void device_malloc(T** ptr, int size); + +template +void decoding_sample(int batch_size, + int candidate_num, + float probability_threshold, + int head_num, + int size_per_head, + int vocab_size, + int seq_len, + int decoder_layers, + int memory_hidden_units); + +int main(int argc, char* argv[]) +{ + srand(0); + struct cudaDeviceProp prop; + check_cuda_error(cudaGetDeviceProperties(&prop, 0)); + printf("Device %s\n", prop.name); + + if(argc != 11) + { + printf("[ERROR] decoding_sample batch_size candidate_num probability_threshold head_num size_per_head vocab_size seq_len num_layer memory_hidden_units is_fp16\n"); + printf("e.g. ./bin/decoding_sample 32 1 0.0 8 64 30000 32 6 768 0\n"); + return 0; + } + + const int batch_size = atoi(argv[1]); + const int candidate_num = atoi(argv[2]); + const float probability_threshold = atof(argv[3]); + const int head_num = atoi(argv[4]); + const int size_per_head = atoi(argv[5]); + const int vocab_size = atoi(argv[6]); + const int seq_len = atoi(argv[7]); + const int decoder_layers = atoi(argv[8]); + const int memory_hidden_units = atoi(argv[9]); + + if(atoi(argv[10]) == 0) + decoding_sample(batch_size, candidate_num, probability_threshold, head_num, size_per_head, vocab_size, seq_len, decoder_layers, memory_hidden_units); + else if(atoi(argv[10]) == 1) + decoding_sample(batch_size, candidate_num, probability_threshold, head_num, size_per_head, vocab_size, seq_len, decoder_layers, memory_hidden_units); + else + { + printf("[ERROR] is_fp16 should be 0 (use float) or 1 (use half). \n"); + return -1; + } + + return 0; +} + +template +void device_malloc(T** ptr, int size) +{ + check_cuda_error(cudaMalloc((void**)ptr, sizeof(T) * size)); + T* tmp = new T[size]; + for(int i = 0; i < size; i++) tmp[i] = (T)((float) rand() / (RAND_MAX + 1.0) * 0.02); + check_cuda_error(cudaMemcpy(*ptr, tmp, sizeof(T) * size, cudaMemcpyHostToDevice)); + delete tmp; +} + +template +void decoding_sample(int batch_size, + int candidate_num, + float probability_threshold, + int head_num, + int size_per_head, + int vocab_size, + int seq_len, + int decoder_layers, + int memory_hidden_units) +{ + const int max_seq_len = seq_len; + const int memory_seq_len = seq_len; + const int start_id = 1; + const int end_id = 2; + const int hidden_units = head_num * size_per_head; + const int inner_size = 4 * hidden_units; + + cublasHandle_t cublasHandle; + check_cuda_error(cublasCreate(&cublasHandle)); + + cudaStream_t stream; + check_cuda_error(cudaStreamCreate(&stream)); + check_cuda_error(cublasSetStream(cublasHandle, stream)); + + fastertransformer::Allocator allocator(0); + DecoderInitParam *param = new DecoderInitParam[decoder_layers]; + + for(int i = 0; i < decoder_layers; i++){ + param[i].stream = stream; + param[i].cublas_handle = cublasHandle; + + T *d_self_Q_kernel, *d_self_K_kernel, *d_self_V_kernel, *d_self_output_kernel; + T *d_self_Q_bias, *d_self_K_bias, *d_self_V_bias, *d_self_output_bias; + T *d_cross_Q_kernel, *d_cross_K_kernel, *d_cross_V_kernel, *d_cross_output_kernel; + T *d_cross_Q_bias, *d_cross_K_bias, *d_cross_V_bias, *d_cross_output_bias; + T *d_ffn_kernel1, *d_ffn_bias1, *d_ffn_kernel2, *d_ffn_bias2; + T *d_self_gamma, *d_self_beta; + T *d_cross_gamma, *d_cross_beta; + T *d_ffn_gamma, *d_ffn_beta; + + device_malloc(&d_self_Q_kernel, hidden_units * 
hidden_units); + device_malloc(&d_self_K_kernel, hidden_units * hidden_units); + device_malloc(&d_self_V_kernel, hidden_units * hidden_units); + device_malloc(&d_self_output_kernel, hidden_units * hidden_units); + device_malloc(&d_self_Q_bias, hidden_units); + device_malloc(&d_self_K_bias, hidden_units); + device_malloc(&d_self_V_bias, hidden_units); + device_malloc(&d_self_output_bias, hidden_units); + + device_malloc(&d_cross_Q_kernel, hidden_units * hidden_units); + device_malloc(&d_cross_K_kernel, memory_hidden_units * hidden_units); + device_malloc(&d_cross_V_kernel, memory_hidden_units * hidden_units); + device_malloc(&d_cross_output_kernel, hidden_units * hidden_units); + device_malloc(&d_cross_Q_bias, hidden_units); + device_malloc(&d_cross_K_bias, hidden_units); + device_malloc(&d_cross_V_bias, hidden_units); + device_malloc(&d_cross_output_bias, hidden_units); + + device_malloc(&d_ffn_bias1, inner_size); + device_malloc(&d_ffn_kernel1, inner_size * hidden_units); + device_malloc(&d_ffn_bias2, hidden_units); + device_malloc(&d_ffn_kernel2, inner_size * hidden_units); + + device_malloc(&d_self_gamma, hidden_units); + device_malloc(&d_self_beta, hidden_units); + device_malloc(&d_cross_gamma, hidden_units); + device_malloc(&d_cross_beta, hidden_units); + device_malloc(&d_ffn_gamma, hidden_units); + device_malloc(&d_ffn_beta, hidden_units); + + param[i].self_attention.query_weight.kernel = d_self_Q_kernel; + param[i].self_attention.key_weight.kernel = d_self_K_kernel; + param[i].self_attention.value_weight.kernel = d_self_V_kernel; + param[i].self_attention.attention_output_weight.kernel = d_self_output_kernel; + param[i].self_attention.query_weight.bias = d_self_Q_bias; + param[i].self_attention.key_weight.bias = d_self_K_bias; + param[i].self_attention.value_weight.bias = d_self_V_bias; + param[i].self_attention.attention_output_weight.bias = d_self_output_bias; + + param[i].cross_attention.query_weight.kernel = d_cross_Q_kernel; + param[i].cross_attention.key_weight.kernel = d_cross_K_kernel; + param[i].cross_attention.value_weight.kernel = d_cross_V_kernel; + param[i].cross_attention.attention_output_weight.kernel = d_cross_output_kernel; + param[i].cross_attention.query_weight.bias = d_cross_Q_bias; + param[i].cross_attention.key_weight.bias = d_cross_K_bias; + param[i].cross_attention.value_weight.bias = d_cross_V_bias; + param[i].cross_attention.attention_output_weight.bias = d_cross_output_bias; + + param[i].self_layernorm.gamma = d_self_gamma; + param[i].self_layernorm.beta = d_self_beta; + param[i].cross_layernorm.gamma = d_cross_gamma; + param[i].cross_layernorm.beta = d_cross_beta; + param[i].ffn_layernorm.gamma = d_ffn_gamma; + param[i].ffn_layernorm.beta = d_ffn_beta; + param[i].ffn.intermediate_weight.bias = d_ffn_bias1; + param[i].ffn.output_weight.bias = d_ffn_bias2; + param[i].ffn.intermediate_weight.kernel = d_ffn_kernel1; + param[i].ffn.output_weight.kernel = d_ffn_kernel2; + } + + DecodingInitParam decoding_params; + + T *d_memory_tensor; + T *d_embedding_table; + T* d_embedding_kernel; + float* d_embedding_bias; + T* d_position_encoding_table; + int* d_output_ids; + int* d_parent_ids; + int* d_sequence_lengths; + int* d_memory_sequence_lengths; + T *d_gamma, *d_beta; + + device_malloc(&d_memory_tensor, memory_hidden_units * memory_seq_len * batch_size); + device_malloc(&d_embedding_table, hidden_units * vocab_size); + device_malloc(&d_embedding_kernel, vocab_size * hidden_units); + device_malloc(&d_embedding_bias, vocab_size); + 
device_malloc(&d_position_encoding_table, max_seq_len * hidden_units); + check_cuda_error(cudaMalloc((void**)&d_output_ids, sizeof(int) * (max_seq_len) * batch_size)); + check_cuda_error(cudaMalloc((void**)&d_parent_ids, sizeof(int) * (max_seq_len) * batch_size)); + check_cuda_error(cudaMalloc((void**)&d_sequence_lengths, sizeof(int) * batch_size)); + check_cuda_error(cudaMalloc((void**)&d_memory_sequence_lengths, sizeof(int) * batch_size)); + device_malloc(&d_gamma, hidden_units); + device_malloc(&d_beta, hidden_units); + + int *h_memory_sequence_lengths = new int[batch_size]; + for(int i = 0; i < batch_size; i++) h_memory_sequence_lengths[i] = memory_seq_len; + check_cuda_error(cudaMemcpy(d_memory_sequence_lengths, h_memory_sequence_lengths, sizeof(int) * batch_size, cudaMemcpyHostToDevice)); + + decoding_params.cublas_handle = cublasHandle; + decoding_params.stream = stream; + decoding_params.memory_tensor = d_memory_tensor; + decoding_params.embedding_table = d_embedding_table; + decoding_params.embedding_kernel = d_embedding_kernel; + decoding_params.embedding_bias = d_embedding_bias; + decoding_params.position_encoding_table = d_position_encoding_table; + decoding_params.output_ids = d_output_ids; + decoding_params.parent_ids = d_parent_ids; + decoding_params.sequence_length = d_sequence_lengths; + decoding_params.memory_sequence_length = d_memory_sequence_lengths; + decoding_params.layernorm.gamma = d_gamma; + decoding_params.layernorm.beta = d_beta; + + const fastertransformer::OperationType type = sizeof(T) == sizeof(float) ? OperationType::FP32 : OperationType::FP16; + + DecodingSampling *decoding = new + DecodingSampling(allocator, batch_size, + max_seq_len, head_num, size_per_head, + vocab_size, decoder_layers, + memory_hidden_units, memory_seq_len, + start_id, end_id, + candidate_num, probability_threshold); + + //warm up + int ite = 50; + for(int i = 0; i < ite; ++i) + decoding->forward(param, decoding_params); + + struct timeval start, end; + cudaDeviceSynchronize(); + gettimeofday(&start, NULL); + + for(int i = 0; i < ite; ++i) + decoding->forward(param, decoding_params); + + cudaDeviceSynchronize(); + gettimeofday(&end, NULL); + printf("[INFO] batch_size %d topk %d topp %f head_num %d size_per_head %d seq_len %d decoder_layers" \ + " %d vocab_size %d FT-CPP-decoding-sampling-time %.2f ms\n", + batch_size, candidate_num, probability_threshold, head_num, size_per_head, seq_len, decoder_layers, vocab_size, + ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); + + delete [] param; + delete [] h_memory_sequence_lengths; + delete decoding; + return ; +} \ No newline at end of file diff --git a/sample/cpp/encoder_sample.cc b/sample/cpp/encoder_sample.cc index 5a3036731..85c0343f8 100644 --- a/sample/cpp/encoder_sample.cc +++ b/sample/cpp/encoder_sample.cc @@ -32,16 +32,17 @@ void encoder_sample(int batch_size, int num_layers, int seq_len, int head_num, - int size_per_head); + int size_per_head, + bool is_remove_padding); int main(int argc, char* argv[]) { struct cudaDeviceProp prop; check_cuda_error(cudaGetDeviceProperties(&prop, 0)); - if(argc != 7) + if(argc != 8) { - printf("[ERROR] encoder_sample batch_size num_layers seq_len head_num size_per_head is_fp16\n"); - printf("e.g., ./bin/encoder_sample 1 12 128 12 64 0\n"); + printf("[ERROR] encoder_sample batch_size num_layers seq_len head_num size_per_head is_fp16 is_remove_padding\n"); + printf("e.g., ./bin/encoder_sample 1 12 128 12 64 0 0\n"); return 0; } @@ -51,11 +52,12 @@ int main(int argc, 
char* argv[]) int seq_len = atoi(argv[3]); int head_num = atoi(argv[4]); int size_per_head = atoi(argv[5]); + bool is_remove_padding = (bool)atoi(argv[7]); if(atoi(argv[6]) == 0) - encoder_sample(batch_size, num_layers, seq_len, head_num, size_per_head); + encoder_sample(batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding); else if(atoi(argv[6]) == 1) - encoder_sample(batch_size, num_layers, seq_len, head_num, size_per_head); + encoder_sample(batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding); else { printf("[ERROR] is_fp16 should be 0 (use float) or 1 (use half). \n"); @@ -69,6 +71,14 @@ template void device_malloc(T **ptr, int size) { check_cuda_error(cudaMalloc((void **)ptr, sizeof(T) * size)); + T *tmp = new T[size]; + for(int i = 0; i < size; i++) + { + tmp[i] = (T)((rand() % 100) / 50.0f) - 1.0f; + } + cudaMemcpy(*ptr, tmp, sizeof(T) * size, cudaMemcpyHostToDevice); + delete tmp; + } template @@ -76,7 +86,8 @@ void encoder_sample(int batch_size, int num_layers, int seq_len, int head_num, - int size_per_head) + int size_per_head, + bool is_remove_padding) { int from_seq_len = seq_len; int to_seq_len = seq_len; @@ -90,6 +101,22 @@ void encoder_sample(int batch_size, T *d_attr_output_layernorm_gamma = NULL; T *d_inter_kernel = NULL, *d_inter_bias = NULL; T *d_output_kernel = NULL, *d_output_bias = NULL, *d_output_layernorm_beta = NULL, *d_output_layernorm_gamma = NULL; + + // pre_process buffer + T *d_from_tensor_with_padding = NULL; + T *d_transformer_out_with_padding = NULL; + + int* d_sequence_length; + int *d_sequence_id_offset; + int *d_tmp_sequence_id_offset; + int *d_valid_word_num; + + int* h_sequence_length = new int[batch_size]; + for(int i = 0; i < batch_size; i++) + { + h_sequence_length[i] = random() % from_seq_len; + } + size_t free_bytes, total_bytes; check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); @@ -97,9 +124,11 @@ void encoder_sample(int batch_size, float total = (float)(total_bytes) / 1024.0 / 1024.0 / 1024.0; printf("before allocate free %.2f GB total %.2f GB\n", free, total); + cudaMalloc((void**)&d_sequence_length, sizeof(int) * (ceil(batch_size/4.) * 4)); + device_malloc(&d_from_tensor, batch_size * seq_len * hidden_dim); device_malloc(&d_transformer_out, batch_size * seq_len * hidden_dim); - device_malloc(&d_attr_kernel_Q, hidden_dim * hidden_dim); + device_malloc(&d_attr_kernel_Q, hidden_dim * hidden_dim * 3); device_malloc(&d_attr_kernel_K, hidden_dim * hidden_dim); device_malloc(&d_attr_kernel_V, hidden_dim * hidden_dim); device_malloc(&d_attr_bias_Q, hidden_dim); @@ -117,6 +146,16 @@ void encoder_sample(int batch_size, device_malloc(&d_output_layernorm_beta, hidden_dim); device_malloc(&d_output_layernorm_gamma, hidden_dim); + if(is_remove_padding == true) + { + const int pre_process_buf_size = ceil((batch_size * from_seq_len + 1) * sizeof(int) / 4.) 
* 4; + cudaMalloc((void**)&d_sequence_id_offset, sizeof(int) * batch_size * from_seq_len); + cudaMalloc((void**)&d_tmp_sequence_id_offset, pre_process_buf_size); + d_valid_word_num = (int*)d_tmp_sequence_id_offset + batch_size * from_seq_len; + device_malloc(&d_from_tensor_with_padding, batch_size * from_seq_len * hidden_dim); + device_malloc(&d_transformer_out_with_padding, batch_size * from_seq_len * hidden_dim); + } + check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); free = (float)(free_bytes) / 1024.0 / 1024.0 / 1024.0; total = (float)(total_bytes) / 1024.0 / 1024.0 / 1024.0; @@ -164,26 +203,94 @@ void encoder_sample(int batch_size, to_seq_len, head_num, size_per_head); - encoder_transformer_->initialize(encoder_param); - int ite = 200; - //warp up - for (int i = 0; i < ite; ++i) - encoder_transformer_->forward(); + //warm up + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + for (int i = 0; i < 2; ++i) + { + if(is_remove_padding == true) + { + cudaMemcpyAsync(d_sequence_length, h_sequence_length, sizeof(int) * batch_size, cudaMemcpyHostToDevice, stream); + int* h_valid_word_num = new int[1]; + build_sequence_length_padding_offset_kernelLauncher(d_sequence_length, + batch_size, seq_len, d_valid_word_num, d_tmp_sequence_id_offset, stream); + cudaMemcpyAsync(h_valid_word_num, d_valid_word_num, sizeof(int), cudaMemcpyDeviceToHost, stream); + const int valid_word_num = h_valid_word_num[0]; + delete h_valid_word_num; + + remove_sequence_length_padding_kernelLauncher(d_from_tensor_with_padding, + d_from_tensor, + d_tmp_sequence_id_offset, + d_sequence_id_offset, + valid_word_num, hidden_dim, + stream); + + encoder_param.sequence_id_offset = d_sequence_id_offset; + encoder_param.valid_word_num = valid_word_num; + } + + encoder_transformer_->initialize(encoder_param); + for(int i = 0; i < num_layers; i++) + encoder_transformer_->forward(); + + if(is_remove_padding == true) + { + rebuild_sequence_length_padding_kernelLauncher(d_transformer_out, d_transformer_out_with_padding, + d_sequence_id_offset, + encoder_param.valid_word_num, hidden_dim, + encoder_param.stream); + } + } struct timeval start, end; cudaDeviceSynchronize(); + cudaProfilerStart(); gettimeofday(&start, NULL); + int ite = 50; for (int i = 0; i < ite; ++i) { - for (int j = 0; j < num_layers; ++j) + if(is_remove_padding == true) + { + cudaMemcpyAsync(d_sequence_length, h_sequence_length, sizeof(int) * batch_size, cudaMemcpyHostToDevice, stream); + int* h_valid_word_num = new int[1]; + build_sequence_length_padding_offset_kernelLauncher(d_sequence_length, + batch_size, seq_len, d_valid_word_num, d_tmp_sequence_id_offset, stream); + cudaMemcpyAsync(h_valid_word_num, d_valid_word_num, sizeof(int), cudaMemcpyDeviceToHost, stream); + const int valid_word_num = h_valid_word_num[0]; + delete h_valid_word_num; + + remove_sequence_length_padding_kernelLauncher(d_from_tensor_with_padding, + d_from_tensor, + d_tmp_sequence_id_offset, + d_sequence_id_offset, + valid_word_num, hidden_dim, + stream); + + encoder_param.sequence_id_offset = d_sequence_id_offset; + encoder_param.valid_word_num = valid_word_num; + } + + encoder_transformer_->initialize(encoder_param); + for(int i = 0; i < num_layers; i++) encoder_transformer_->forward(); + + if(is_remove_padding == true) + { + rebuild_sequence_length_padding_kernelLauncher(d_transformer_out, d_transformer_out_with_padding, + d_sequence_id_offset, + encoder_param.valid_word_num, hidden_dim, + encoder_param.stream); + } } + cudaDeviceSynchronize(); gettimeofday(&end, 
NULL); - printf("[batch_size %d seq_len %d %d transformer layers] costs %.2f ms\n", batch_size, seq_len, num_layers, - ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); + cudaProfilerStop(); + + printf("[INFO] batch_size %d seq_len %d layer %d FT-CPP-time %.2f ms \n", batch_size, seq_len, num_layers, + ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); delete encoder_transformer_; return; diff --git a/sample/pytorch/decoder_sample.py b/sample/pytorch/decoder_sample.py new file mode 100644 index 000000000..52d6d1daf --- /dev/null +++ b/sample/pytorch/decoder_sample.py @@ -0,0 +1,145 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import argparse +import timeit +import torch +# import torch.cuda.nvtx as nvtx + +from onmt.utils.misc import sequence_mask +from utils.decoder import DecoderWeights, CustomDecoder, ONMTDecoder, init_op_cache, init_onmt_cache + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('batch_size', type=int, + help='batch size') + parser.add_argument('layer_num', type=int, + help='number of layers') + parser.add_argument('seq_len', type=int, + help='sequence length') + parser.add_argument('head_num', type=int, + help='head number') + parser.add_argument('head_size', type=int, + help='size per head') + parser.add_argument('--step', type=int, default=0, + help='decoding step number') + parser.add_argument('--fp16', action='store_true', + help='is fp16') + parser.add_argument('--time', action='store_true', + help='test the time or not.') + parser.add_argument('--module_path', type=str, default='./', + help='directory containing the th_fastertransformer dynamic lib') + parser.add_argument('--ths', action='store_true', + help='use TorchScript mode') + parser.add_argument('--ths_path', type=str, default='./lib/libths_fastertransformer.so', + help='path of the ths_fastertransformer dynamic lib file') + + args = parser.parse_args() + + hidden_dim = args.head_num * args.head_size + + if args.step <= 0: + step = args.seq_len + else: + step = args.step + + print("\n=============== Argument ===============") + print('batch_size: ' + str(args.batch_size)) + print('layer_num: ' + str(args.layer_num)) + print('seq_len: ' + str(args.seq_len)) + print('head_num: ' + str(args.head_num)) + print('head_size: ' + str(args.head_size)) + print('hidden_dim: ' + str(hidden_dim)) + print('step: ' + str(step)) + print('use_fp16: ' + str(args.fp16)) + print('TorchScript mode: ' + str(args.ths)) + print('test_time: ' + str(args.time)) + print("========================================\n") + + inp = torch.empty(args.batch_size, 1, hidden_dim).cuda() + mem = torch.empty(args.batch_size, args.seq_len, hidden_dim).cuda() + torch.nn.init.uniform_(inp, -1, 1) + torch.nn.init.uniform_(mem, -1, 1) + if args.fp16: + inp = inp.half() + mem = mem.half() + mem_seq_lens = torch.randint(1, args.seq_len+1, 
(args.batch_size,), dtype=torch.int32).cuda() + src_pad_mask = ~sequence_mask(mem_seq_lens, args.seq_len).unsqueeze(1) + + weights = DecoderWeights(args.layer_num, hidden_dim) + onmt_decoder = ONMTDecoder(args.layer_num, args.head_num, args.head_size, weights) + onmt_decoder.cuda() + if args.fp16: + onmt_decoder.half() + onmt_decoder.eval() + + weights.to_cuda() + if args.fp16: + weights.to_half() + if args.ths: + custom_decoder = CustomDecoder(args.layer_num, args.head_num, args.head_size, weights, args.fp16, os.path.abspath(args.ths_path), args.ths) + else: + custom_decoder = CustomDecoder(args.layer_num, args.head_num, args.head_size, weights, args.fp16, os.path.abspath(args.module_path)) + + with torch.no_grad(): + self_cache, mem_cache = init_op_cache(args.layer_num, args.batch_size, 1, args.seq_len, hidden_dim, args.fp16) + cache = init_onmt_cache(args.layer_num, mem) + output1 = inp + output2 = inp + + for i in range(step): + output1 = onmt_decoder(output1, mem, src_pad_mask, cache, 0) + output2, self_cache, mem_cache = custom_decoder(output2, mem, mem_seq_lens, self_cache, mem_cache) + diff = torch.abs((output1 - output2) / output1) + print('step: {} Mean relative diff: {} Max relative diff: {} Min relative diff: {}'.format( + i, torch.mean(diff), torch.max(diff), torch.min(diff))) + + if args.time: + iterations = 10 + + for i in range(iterations): + cache = init_onmt_cache(args.layer_num, mem) + output1 = inp + for i in range(step): + output1 = onmt_decoder(output1, mem, src_pad_mask, cache, 0) + t10 = timeit.default_timer() + for i in range(iterations): + cache = init_onmt_cache(args.layer_num, mem) + output1 = inp + for i in range(step): + output1 = onmt_decoder(output1, mem, src_pad_mask, cache, 0) + t1 = timeit.default_timer() - t10 + + for i in range(iterations): + self_cache, mem_cache = init_op_cache(args.layer_num, args.batch_size, 1, args.seq_len, hidden_dim, args.fp16) + output2 = inp + for i in range(step): + output2, self_cache, mem_cache = custom_decoder(output2, mem, mem_seq_lens, self_cache, mem_cache) + t20 = timeit.default_timer() + for i in range(iterations): + self_cache, mem_cache = init_op_cache(args.layer_num, args.batch_size, 1, args.seq_len, hidden_dim, args.fp16) + output2 = inp + for i in range(step): + output2, self_cache, mem_cache = custom_decoder(output2, mem, mem_seq_lens, self_cache, mem_cache) + t2 = timeit.default_timer() - t20 + print("[INFO] ONMTDecoder time costs: {:.2f} ms".format(t1*1000/iterations)) + print("[INFO] FTDecoder time costs: {:.2f} ms".format(t2*1000/iterations)) + + +if __name__ == '__main__': + main() diff --git a/sample/pytorch/decoding_sample.py b/sample/pytorch/decoding_sample.py new file mode 100644 index 000000000..70418f02d --- /dev/null +++ b/sample/pytorch/decoding_sample.py @@ -0,0 +1,167 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
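decoder_sample.py derives the source padding mask from mem_seq_lens through OpenNMT-py's sequence_mask. For readers without onmt installed, the equivalent in plain PyTorch is roughly the sketch below (sizes are examples; names mirror the script):

import torch

batch_size, seq_len = 8, 32
mem_seq_lens = torch.randint(1, seq_len + 1, (batch_size,), dtype=torch.int32)

pos = torch.arange(seq_len).unsqueeze(0)            # [1, seq_len]
valid = pos < mem_seq_lens.unsqueeze(1)             # True where a real token exists
src_pad_mask = ~valid.unsqueeze(1)                  # [batch, 1, seq_len], True marks padding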
+ +from __future__ import print_function + +import os +import argparse +import timeit +import torch +# import torch.cuda.nvtx as nvtx + +from onmt.utils.misc import sequence_mask +from utils.decoding import DecodingWeights, CustomDecoding, TorchDecoding, ArgHelper + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('batch_size', type=int, + help='batch size') + parser.add_argument('layer_num', type=int, + help='number of layers') + parser.add_argument('seq_len', type=int, + help='sequence length') + parser.add_argument('head_num', type=int, + help='head number') + parser.add_argument('head_size', type=int, + help='size per head') + parser.add_argument('beam_size', type=int, + help='beam size') + parser.add_argument('vocab_size', type=int, + help='vocab size') + parser.add_argument('--fp16', action='store_true', + help='is fp16') + parser.add_argument('--time', action='store_true', + help='test the time or not.') + parser.add_argument('--use_pretrained', action='store_true', + help='use pretrained weights or not.') + parser.add_argument('--module_path', type=str, default='./', + help='directory containing the th_fastertransformer dynamic lib') + parser.add_argument('--ths', action='store_true', + help='use TorchScript mode') + parser.add_argument('--ths_path', type=str, default='./lib/libths_fastertransformer.so', + help='path of the ths_fastertransformer dynamic lib file') + + args = parser.parse_args() + + if args.use_pretrained: + layer_num = 6 + head_num = 8 + head_size = 64 + vocab_size = 31538 + else: + layer_num = args.layer_num + head_num = args.head_num + head_size = args.head_size + vocab_size = args.vocab_size + hidden_dim = head_num * head_size + + print("\n=============== Argument ===============") + print('batch_size: ' + str(args.batch_size)) + print('layer_num: ' + str(layer_num)) + print('seq_len: ' + str(args.seq_len)) + print('head_num: ' + str(head_num)) + print('head_size: ' + str(head_size)) + print('hidden_dim: ' + str(hidden_dim)) + print('beam_size: ' + str(args.beam_size)) + print('vocab_size: ' + str(vocab_size)) + print('use_pretrained: ' + str(args.use_pretrained)) + print('use_fp16: ' + str(args.fp16)) + print('TorchScript mode: ' + str(args.ths)) + print('test_time: ' + str(args.time)) + print("========================================\n") + + decodingargs1 = ArgHelper('torch_decoding', 'fp16' if args.fp16 else 'fp32', + os.path.abspath(args.module_path), args.ths, os.path.abspath(args.ths_path)) + decodingargs2 = ArgHelper('torch_decoding_with_decoder_ext', 'fp16' if args.fp16 else 'fp32', + os.path.abspath(args.module_path), args.ths, os.path.abspath(args.ths_path)) + + mem = torch.empty(args.batch_size, args.seq_len, hidden_dim).cuda() + torch.nn.init.uniform_(mem, -1, 1) + if args.fp16: + mem = mem.half() + mem_seq_lens = torch.randint(1, args.seq_len+1, (args.batch_size,), dtype=torch.int32).cuda() + + if args.use_pretrained: + ckpt = torch.load('./pytorch/translation/models/averaged-10-epoch.pt') + import re + def fix_key(s): + s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.b_2', + r'\1.layer_norm\2.bias', s) + s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.a_2', + r'\1.layer_norm\2.weight', s) + return s + ckpt['model'] = {fix_key(k): v for k, v in ckpt['model'].items()} + weights = DecodingWeights(layer_num, hidden_dim, vocab_size, ckpt) + else: + weights = DecodingWeights(layer_num, hidden_dim, vocab_size) + torch_decoding = TorchDecoding(layer_num, head_num, head_size, vocab_size, 2, 3, weights, args=decodingargs1) + 
torch_decoding_with_decoder_ext = TorchDecoding(layer_num, head_num, head_size, vocab_size, 2, 3, weights, args=decodingargs2) + torch_decoding.cuda() + torch_decoding_with_decoder_ext.cuda() + if args.fp16: + torch_decoding.half() + torch_decoding_with_decoder_ext.half() + torch_decoding.eval() + torch_decoding_with_decoder_ext.eval() + weights.to_cuda() + if args.fp16: + weights.to_half() + custom_decoding = CustomDecoding(layer_num, head_num, head_size, vocab_size, 2, 3, weights, + args=decodingargs1) + + with torch.no_grad(): + output0, lens0 = torch_decoding(args.batch_size, args.beam_size, args.seq_len, mem, mem_seq_lens) + print(output0) + print(lens0) + output1, lens1 = torch_decoding_with_decoder_ext(args.batch_size, args.beam_size, args.seq_len, mem, mem_seq_lens) + print(output1) + print(lens1) + output2, lens2 = custom_decoding(args.batch_size, args.beam_size, args.seq_len, mem, mem_seq_lens) + print(output2) + print(lens2) + # diff = torch.abs((output1 - output2) / output1) + # print('step: {} Mean relative diff: {} Max relative diff: {} Min relative diff: {}'.format( + # i, torch.mean(diff), torch.max(diff), torch.min(diff))) + + if args.time: + iterations = 10 + + for i in range(iterations): + output, lens = torch_decoding(args.batch_size, args.beam_size, args.seq_len, mem, mem_seq_lens) + t00 = timeit.default_timer() + for i in range(iterations): + output, lens = torch_decoding(args.batch_size, args.beam_size, args.seq_len, mem, mem_seq_lens) + t0 = timeit.default_timer() - t00 + + for i in range(iterations): + output, lens = torch_decoding_with_decoder_ext(args.batch_size, args.beam_size, args.seq_len, mem, mem_seq_lens) + t10 = timeit.default_timer() + for i in range(iterations): + output, lens = torch_decoding_with_decoder_ext(args.batch_size, args.beam_size, args.seq_len, mem, mem_seq_lens) + t1 = timeit.default_timer() - t10 + + for i in range(iterations): + output, lens = custom_decoding(args.batch_size, args.beam_size, args.seq_len, mem, mem_seq_lens) + t20 = timeit.default_timer() + for i in range(iterations): + output, lens = custom_decoding(args.batch_size, args.beam_size, args.seq_len, mem, mem_seq_lens) + t2 = timeit.default_timer() - t20 + print("[INFO] TorchDecoding time costs: {:.2f} ms".format(t0*1000/iterations)) + print("[INFO] TorchDecoding (with FTDecoder) time costs: {:.2f} ms".format(t1*1000/iterations)) + print("[INFO] FTDecoding time costs: {:.2f} ms".format(t2*1000/iterations)) + + +if __name__ == '__main__': + main() diff --git a/sample/pytorch/encoder_sample.py b/sample/pytorch/encoder_sample.py new file mode 100644 index 000000000..65b5e2f27 --- /dev/null +++ b/sample/pytorch/encoder_sample.py @@ -0,0 +1,202 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from __future__ import print_function + +import os +import argparse +import timeit +import torch +import torch.cuda.nvtx as nvtx + +from utils.encoder import EncoderWeights, CustomEncoder, CustomEncoder2, HuggingFaceEncoder + + +def sequence_mask(lengths, max_len=None, is_2d=True): + batch_size = lengths.numel() + max_len = max_len or lengths.max() + mask = (torch.arange(0, max_len, device=lengths.device) + .type_as(lengths) + .repeat(batch_size, 1) + .lt(lengths.unsqueeze(1))) + if is_2d: + return mask + else: + mask = mask.view(-1, 1, 1, max_len) + m2 = mask.transpose(2, 3) + return mask * m2 + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('batch_size', type=int, + help='batch size') + parser.add_argument('layer_num', type=int, + help='number of layers') + parser.add_argument('seq_len', type=int, + help='sequence length') + parser.add_argument('head_num', type=int, + help='head number') + parser.add_argument('head_size', type=int, + help='size per head') + parser.add_argument('--fp16', action='store_true', + help='is fp16') + parser.add_argument('--time', action='store_true', + help='test the time or not.') + parser.add_argument('--use_pretrained', action='store_true', + help='use pretrained weights or not.') + parser.add_argument('--remove_padding', action='store_true', + help='Remove the padding of sentences of encoder.') + parser.add_argument('--weight_path', type=str, + default='./pytorch/bert_squad/models/bert-large-uncased-whole-word-masking-finetuned-squad/pytorch_model.bin', + help='path containing the pretrained weights') + parser.add_argument('--module_path', type=str, default='./', + help='directory containing the th_fastertransformer dynamic lib') + parser.add_argument('--ths', action='store_true', + help='use TorchScript mode') + parser.add_argument('--ths_type', type=int, default=0, + help='custom TorchScript type') + parser.add_argument('--ths_path', type=str, default='./lib/libths_fastertransformer.so', + help='path of the ths_fastertransformer dynamic lib file') + parser.add_argument('--ths_path_2', type=str, default='./lib/libths_fastertransformer_op.so', + help='path of the ths_fastertransformer op dynamic lib file') + + args = parser.parse_args() + + batch_size = args.batch_size + seq_len = args.seq_len + if args.use_pretrained: + if 'large' in args.weight_path: + layer_num = 24 + head_num = 16 + head_size = 64 + elif 'base' in args.weight_path: + layer_num = 12 + head_num = 12 + head_size = 64 + else: + layer_num = args.layer_num + head_num = args.head_num + head_size = args.head_size + else: + layer_num = args.layer_num + head_num = args.head_num + head_size = args.head_size + hidden_dim = head_num * head_size + + print("\n=============== Argument ===============") + print('batch_size: ' + str(batch_size)) + print('layer_num: ' + str(layer_num)) + print('seq_len: ' + str(seq_len)) + print('head_num: ' + str(head_num)) + print('head_size: ' + str(head_size)) + print('hidden_dim: ' + str(hidden_dim)) + print('use_pretrained: ' + str(args.use_pretrained)) + print('use_fp16: ' + str(args.fp16)) + print('TorchScript mode: ' + str(args.ths)) + print('TorchScript type: ' + str(args.ths_type)) + print('test_time: ' + str(args.time)) + print('remove_padding: ' + str(args.remove_padding)) + print("========================================\n") + + inp = torch.empty(batch_size, seq_len, hidden_dim).cuda() + torch.nn.init.uniform_(inp, -1, 1) + mem_seq_lens = torch.randint(1, seq_len+1, (batch_size,), dtype=torch.int32).cuda() + mask = 
sequence_mask(mem_seq_lens, args.seq_len, False).to(torch.float) + # mask = torch.randint(0, 2, (batch_size, seq_len, seq_len), dtype=torch.float32).cuda() + if args.fp16: + inp = inp.half() + mask = mask.half() + + if args.use_pretrained: + pretrained_weights = torch.load(args.weight_path) + weights = EncoderWeights(layer_num, hidden_dim, pretrained_weights) + else: + weights = EncoderWeights(layer_num, hidden_dim) + + if args.use_pretrained: + hf_encoder = HuggingFaceEncoder(layer_num, head_num, head_size, pretrained_weights) + else: + hf_encoder = HuggingFaceEncoder(layer_num, head_num, head_size, weights) + hf_encoder.cuda() + if args.fp16: + hf_encoder.half() + hf_encoder.eval() + if args.ths: + hf_encoder = torch.jit.trace(hf_encoder, (inp, mask)) + + weights.to_cuda() + if args.fp16: + weights.to_half() + if args.ths: + if args.ths_type == 0: + custom_encoder = CustomEncoder(layer_num, head_num, head_size, weights, + os.path.abspath(args.ths_path), args.ths, remove_padding=args.remove_padding) + else: + custom_encoder = CustomEncoder2(layer_num, head_num, head_size, weights, + os.path.abspath(args.ths_path_2), remove_padding=args.remove_padding) + else: + custom_encoder = CustomEncoder(layer_num, head_num, head_size, weights, + os.path.abspath(args.module_path), remove_padding=args.remove_padding) + if args.ths: + if args.ths_type == 0: + custom_encoder = torch.jit.script(custom_encoder) + else: + custom_encoder = torch.jit.trace(custom_encoder, (inp, mask, mem_seq_lens)) + + with torch.no_grad(): + output_mask = sequence_mask(mem_seq_lens, args.seq_len).to(mask.dtype).unsqueeze(-1) + output1 = hf_encoder(inp, mask)[0] * output_mask + print(output1) + print(output1.size()) + + output2 = custom_encoder(inp, mask, mem_seq_lens)[0] * output_mask + print(output2) + print(output2.size()) + + diff = torch.abs(output1 - output2) + print('Mean diff: {}'.format(torch.mean(diff))) + print('Max diff: {}'.format(torch.max(diff))) + print('Min diff: {}'.format(torch.min(diff))) + + if args.time: + iterations = 100 + + for i in range(iterations): + output = hf_encoder(inp, mask) + t10 = timeit.default_timer() + # nvtx.range_push("hf") + for i in range(iterations): + # nvtx.range_push("hf"+str(i)) + output = hf_encoder(inp, mask) + # nvtx.range_pop() + # nvtx.range_pop() + t1 = timeit.default_timer() - t10 + + for i in range(iterations): + output = custom_encoder(inp, mask, mem_seq_lens) + t20 = timeit.default_timer() + # nvtx.range_push("ext") + for i in range(iterations): + # nvtx.range_push("ext"+str(i)) + output = custom_encoder(inp, mask, mem_seq_lens) + # nvtx.range_pop() + # nvtx.range_pop() + t2 = timeit.default_timer() - t20 + print("[INFO] HuggingFaceEncoder time costs: {:.2f} ms".format(t1*1000/iterations)) + print("[INFO] FasterTransformer time costs: {:.2f} ms".format(t2*1000/iterations)) + + +if __name__ == '__main__': + main() diff --git a/sample/pytorch/run_glue.py b/sample/pytorch/run_glue.py new file mode 100644 index 000000000..f23c3ec10 --- /dev/null +++ b/sample/pytorch/run_glue.py @@ -0,0 +1,383 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import random +import timeit + +import numpy as np +import torch +from torch.utils.data import DataLoader, SequentialSampler, TensorDataset +from tqdm import tqdm, trange + +from transformers import ( + BertConfig, + BertTokenizer, +) +from utils.modeling_bert import BertForSequenceClassification +from transformers import glue_compute_metrics as compute_metrics +from transformers import glue_convert_examples_to_features as convert_examples_to_features +from transformers import glue_output_modes as output_modes +from transformers import glue_processors as processors + + +logger = logging.getLogger(__name__) + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def evaluate(args, model, tokenizer, prefix=""): + # Loop to handle MNLI double evaluation (matched, mis-matched) + eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,) + eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,) + + results = {} + for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs): + eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True) + + if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: + os.makedirs(eval_output_dir) + + args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) + # Note that DistributedSampler samples randomly + eval_sampler = SequentialSampler(eval_dataset) + eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + + # multi-gpu eval + if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): + model = torch.nn.DataParallel(model) + + # Eval! 
+ logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(" Num examples = %d", len(eval_dataset)) + logger.info(" Batch size = %d", args.eval_batch_size) + # eval_loss = 0.0 + nb_eval_steps = 0 + preds = None + out_label_ids = None + + start_time = timeit.default_timer() + for batch in tqdm(eval_dataloader, desc="Evaluating"): + model.eval() + batch = tuple(t.to(args.device) for t in batch) + + with torch.no_grad(): + inputs = [batch[0], batch[1].half() if args.data_type == 'fp16' else batch[1], batch[2]] + outputs = model(*inputs) + # tmp_eval_loss, logits = outputs[:2] + logits = outputs[0] + + # eval_loss += tmp_eval_loss.mean().item() + nb_eval_steps += 1 + if preds is None: + preds = logits.detach().cpu().numpy() + out_label_ids = batch[3].detach().cpu().numpy() + else: + preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) + out_label_ids = np.append(out_label_ids, batch[3].detach().cpu().numpy(), axis=0) + + evalTime = timeit.default_timer() - start_time + logger.info(" Evaluation for " + eval_task + " done in total %f secs (%f sec per example)", evalTime, evalTime / len(eval_dataset)) + + # eval_loss = eval_loss / nb_eval_steps + if args.output_mode == "classification": + preds = np.argmax(preds, axis=1) + elif args.output_mode == "regression": + preds = np.squeeze(preds) + result = compute_metrics(eval_task, preds, out_label_ids) + results.update(result) + + output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt") + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results {} *****".format(prefix)) + for key in sorted(result.keys()): + logger.info(" %s = %s", key, str(result[key])) + writer.write("%s = %s\n" % (key, str(result[key]))) + + return results + + +def load_and_cache_examples(args, task, tokenizer, evaluate=False): + if args.local_rank not in [-1, 0] and not evaluate: + torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache + + processor = processors[task]() + output_mode = output_modes[task] + # Load data features from cache or dataset file + cached_features_file = os.path.join( + args.data_dir, + "cached_{}_{}_{}_{}".format( + "dev" if evaluate else "train", + list(filter(None, args.model_name_or_path.split("/"))).pop(), + str(args.max_seq_length), + str(task), + ), + ) + if os.path.exists(cached_features_file) and not args.overwrite_cache: + logger.info("Loading features from cached file %s", cached_features_file) + features = torch.load(cached_features_file) + else: + logger.info("Creating features from dataset file at %s", args.data_dir) + label_list = processor.get_labels() + examples = ( + processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) + ) + features = convert_examples_to_features( + examples, + tokenizer, + label_list=label_list, + max_length=args.max_seq_length, + output_mode=output_mode, + pad_on_left=False, + pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], + pad_token_segment_id=0, + ) + if args.local_rank in [-1, 0]: + logger.info("Saving features into cached file %s", cached_features_file) + torch.save(features, cached_features_file) + + if args.local_rank == 0 and not evaluate: + torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache + + # Convert to Tensors and build dataset + all_input_ids = torch.tensor([f.input_ids for f in features], 
dtype=torch.long) + all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) + all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) + if output_mode == "classification": + all_labels = torch.tensor([f.label for f in features], dtype=torch.long) + elif output_mode == "regression": + all_labels = torch.tensor([f.label for f in features], dtype=torch.float) + + dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels) + return dataset + + +def main(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--data_dir", + default=None, + type=str, + required=True, + help="The input data dir. Should contain the .tsv files (or other data files) for the task.", + ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model or shortcut name", + ) + parser.add_argument( + "--task_name", + default=None, + type=str, + required=True, + help="The name of the task to train selected in the list: " + ", ".join(processors.keys()), + ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + + # Other parameters + parser.add_argument( + "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + default="", + type=str, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--cache_dir", + default="", + type=str, + help="Where do you want to store the pre-trained models downloaded from s3", + ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. 
Sequences longer " + "than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") + parser.add_argument( + "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.", + ) + + parser.add_argument( + "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.", + ) + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets", + ) + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + parser.add_argument("--model_type", type=str, help="ori, ths, ext, thsext") + parser.add_argument("--data_type", type=str, help="fp32, fp16") + parser.add_argument('--module_path', type=str, default='./', + help='path containing the th_fastertransformer dynamic lib') + parser.add_argument('--ths_path', type=str, default='./lib/libths_fastertransformer_op.so', + help='path of the ths_fastertransformer dynamic lib file') + + args = parser.parse_args() + + # Setup CUDA, GPU & distributed training + if args.local_rank == -1: + device = torch.device("cuda") + args.n_gpu = torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend="nccl") + args.n_gpu = 1 + args.device = device + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, + ) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s", + args.local_rank, + device, + args.n_gpu, + ) + + # Set seed + set_seed(args) + + # Prepare GLUE task + args.task_name = args.task_name.lower() + if args.task_name not in processors: + raise ValueError("Task not found: %s" % (args.task_name)) + processor = processors[args.task_name]() + args.output_mode = output_modes[args.task_name] + label_list = processor.get_labels() + num_labels = len(label_list) + + # Load pretrained model and tokenizer + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab + + config = BertConfig.from_pretrained( + args.config_name if args.config_name else args.model_name_or_path, + num_labels=num_labels, + finetuning_task=args.task_name, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + tokenizer = BertTokenizer.from_pretrained( + args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, + do_lower_case=args.do_lower_case, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + + if args.local_rank == 0: + torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab + + logger.info("Parameters %s", args) + + # Evaluation + results = {} + if args.do_eval and args.local_rank in [-1, 0]: + logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path) + checkpoints = [args.model_name_or_path] + logger.info("Evaluate the following checkpoints: %s", checkpoints) + for checkpoint in checkpoints: + global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" + prefix = 
checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else "" + use_ths = args.model_type.startswith('ths') + model = BertForSequenceClassification.from_pretrained(checkpoint, torchscript=use_ths) + model.to(args.device) + + if args.data_type == 'fp16': + logger.info("Use fp16") + model.half() + if args.model_type == 'ext': + logger.info("Use custom BERT encoder") + from utils.encoder import EncoderWeights, CustomEncoder + weights = EncoderWeights(model.config.num_hidden_layers, model.config.hidden_size, model.bert.encoder) + weights.to_cuda() + if args.data_type == 'fp16': + weights.to_half() + enc = CustomEncoder(model.config.num_hidden_layers, + model.config.num_attention_heads, + model.config.hidden_size//model.config.num_attention_heads, + weights, + os.path.abspath(args.module_path)) + model.replace_encoder(enc) + if args.model_type == 'thsext': + logger.info("Use custom BERT encoder for TorchScript") + from utils.encoder import EncoderWeights, CustomEncoder2 + weights = EncoderWeights(model.config.num_hidden_layers, model.config.hidden_size, model.bert.encoder) + weights.to_cuda() + if args.data_type == 'fp16': + weights.to_half() + enc = CustomEncoder2(model.config.num_hidden_layers, + model.config.num_attention_heads, + model.config.hidden_size//model.config.num_attention_heads, + weights, + os.path.abspath(args.ths_path)) + fake_inp = torch.zeros(args.per_gpu_eval_batch_size, args.max_seq_length, model.config.hidden_size).cuda() + torch.nn.init.uniform_(fake_inp, -1, 1) + fake_mask = torch.randint(0, 2, (args.per_gpu_eval_batch_size, args.max_seq_length, args.max_seq_length), dtype=torch.float32).cuda() + if args.data_type == 'fp16': + fake_inp = fake_inp.half() + fake_mask = fake_mask.half() + enc_ = torch.jit.trace(enc, (fake_inp, fake_mask)) + model.replace_encoder(enc_) + if use_ths: + logger.info("Use TorchScript mode") + fake_input_id = torch.LongTensor(args.per_gpu_eval_batch_size, args.max_seq_length) + fake_input_id.fill_(1) + fake_input_id = fake_input_id.to(args.device) + fake_mask = torch.ones(args.per_gpu_eval_batch_size, args.max_seq_length).to(args.device) + fake_type_id = fake_input_id.clone().detach() + if args.data_type == 'fp16': + fake_mask = fake_mask.half() + model.eval() + with torch.no_grad(): + model_ = torch.jit.trace(model, (fake_input_id, fake_mask, fake_type_id)) + model = model_ + + result = evaluate(args, model, tokenizer, prefix=prefix) + result = dict((k + "_{}".format(global_step), v) for k, v in result.items()) + results.update(result) + + return results + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/sample/pytorch/run_squad.py b/sample/pytorch/run_squad.py new file mode 100644 index 000000000..cfbcd7523 --- /dev/null +++ b/sample/pytorch/run_squad.py @@ -0,0 +1,489 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Models for question-answering on SQuAD (Bert) modified from HuggingFace transformers .""" + + +import argparse +import logging +import os +import random +import timeit + +import numpy as np +import torch +from torch.utils.data import DataLoader, SequentialSampler +from tqdm import tqdm + +from transformers import ( + BertConfig, + BertTokenizer, + squad_convert_examples_to_features, +) +from utils.modeling_bert import BertForQuestionAnswering +from transformers.data.metrics.squad_metrics import ( + compute_predictions_logits, + squad_evaluate, +) +from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor + + +logger = logging.getLogger(__name__) + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def to_list(tensor): + return tensor.detach().cpu().tolist() + + +def evaluate(args, model, tokenizer, prefix=""): + dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True) + + if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: + os.makedirs(args.output_dir) + + args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) + + # Note that DistributedSampler samples randomly + eval_sampler = SequentialSampler(dataset) + eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + + # multi-gpu evaluate + if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): + model = torch.nn.DataParallel(model) + + # Eval! + logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(" Num examples = %d", len(dataset)) + logger.info(" Batch size = %d", args.eval_batch_size) + + all_results = [] + start_time = timeit.default_timer() + + for batch in tqdm(eval_dataloader, desc="Evaluating"): + model.eval() + batch = tuple(t.to(args.device) for t in batch) + + with torch.no_grad(): + # inputs = { + # "input_ids": batch[0], + # "attention_mask": batch[1].half() if args.data_type == 'fp16' else batch[1], + # "token_type_ids": batch[2], + # } + inputs = [batch[0], batch[1].half() if args.data_type == 'fp16' else batch[1], batch[2]] + + example_indices = batch[3] + + # outputs = model(**inputs) + outputs = model(*inputs) + + for i, example_index in enumerate(example_indices): + eval_feature = features[example_index.item()] + unique_id = int(eval_feature.unique_id) + + output = [to_list(output[i]) for output in outputs] + + # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler" + # models only use two. 
+ if len(output) >= 5: + start_logits = output[0] + start_top_index = output[1] + end_logits = output[2] + end_top_index = output[3] + cls_logits = output[4] + + result = SquadResult( + unique_id, + start_logits, + end_logits, + start_top_index=start_top_index, + end_top_index=end_top_index, + cls_logits=cls_logits, + ) + + else: + start_logits, end_logits = output + result = SquadResult(unique_id, start_logits, end_logits) + + all_results.append(result) + + evalTime = timeit.default_timer() - start_time + logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset)) + + # Compute predictions + output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix)) + output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix)) + + if args.version_2_with_negative: + output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix)) + else: + output_null_log_odds_file = None + + predictions = compute_predictions_logits( + examples, + features, + all_results, + args.n_best_size, + args.max_answer_length, + args.do_lower_case, + output_prediction_file, + output_nbest_file, + output_null_log_odds_file, + args.verbose_logging, + args.version_2_with_negative, + args.null_score_diff_threshold, + tokenizer, + ) + + # Compute the F1 and exact scores. + results = squad_evaluate(examples, predictions) + return results + + +def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False): + if args.local_rank not in [-1, 0] and not evaluate: + # Make sure only the first process in distributed training process the dataset, and the others will use the cache + torch.distributed.barrier() + + # Load data features from cache or dataset file + input_dir = args.data_dir if args.data_dir else "." 
cached_features_file = os.path.join( + input_dir, + "cached_{}_{}_{}".format( + "dev" if evaluate else "train", + list(filter(None, args.model_name_or_path.split("/"))).pop(), + str(args.max_seq_length), + ), + ) + + # Init features and dataset from cache if it exists + if os.path.exists(cached_features_file) and not args.overwrite_cache: + logger.info("Loading features from cached file %s", cached_features_file) + features_and_dataset = torch.load(cached_features_file) + features, dataset, examples = ( + features_and_dataset["features"], + features_and_dataset["dataset"], + features_and_dataset["examples"], + ) + else: + logger.info("Creating features from dataset file at %s", input_dir) + + if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)): + try: + import tensorflow_datasets as tfds + except ImportError: + raise ImportError("If no data_dir is specified, tensorflow_datasets needs to be installed.") + + if args.version_2_with_negative: + logger.warn("tensorflow_datasets does not handle version 2 of SQuAD.") + + tfds_examples = tfds.load("squad") + examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate) + else: + processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor() + if evaluate: + examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file) + else: + examples = processor.get_train_examples(args.data_dir, filename=args.train_file) + + features, dataset = squad_convert_examples_to_features( + examples=examples, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + doc_stride=args.doc_stride, + max_query_length=args.max_query_length, + is_training=not evaluate, + return_dataset="pt", + threads=args.threads, + ) + + if args.local_rank in [-1, 0]: + logger.info("Saving features into cached file %s", cached_features_file) + torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file) + + if args.local_rank == 0 and not evaluate: + # Make sure only the first process in distributed training processes the dataset, and the others will use the cache + torch.distributed.barrier() + + if output_examples: + return dataset, examples, features + return dataset + + +def main(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model or shortcut name", + ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model checkpoints and predictions will be written.", + ) + + # Other parameters + parser.add_argument( + "--data_dir", + default=None, + type=str, + help="The input data dir. Should contain the .json files for the task." + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", + ) + parser.add_argument( + "--train_file", + default=None, + type=str, + help="The input training file. If a data dir is specified, will look for the file there" + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", + ) + parser.add_argument( + "--predict_file", + default=None, + type=str, + help="The input evaluation file.
If a data dir is specified, will look for the file there" + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", + ) + parser.add_argument( + "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" + ) + parser.add_argument( + "--tokenizer_name", + default="", + type=str, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--cache_dir", + default="", + type=str, + help="Where do you want to store the pre-trained models downloaded from s3", + ) + + parser.add_argument( + "--version_2_with_negative", + action="store_true", + help="If true, the SQuAD examples contain some that do not have an answer.", + ) + parser.add_argument( + "--null_score_diff_threshold", + type=float, + default=0.0, + help="If null_score - best_non_null is greater than the threshold predict null.", + ) + + parser.add_argument( + "--max_seq_length", + default=384, + type=int, + help="The maximum total input sequence length after WordPiece tokenization. Sequences " + "longer than this will be truncated, and sequences shorter than this will be padded.", + ) + parser.add_argument( + "--doc_stride", + default=128, + type=int, + help="When splitting up a long document into chunks, how much stride to take between chunks.", + ) + parser.add_argument( + "--max_query_length", + default=64, + type=int, + help="The maximum number of tokens for the question. Questions longer than this will " + "be truncated to this length.", + ) + parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") + parser.add_argument( + "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model." + ) + + parser.add_argument( + "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation." + ) + parser.add_argument( + "--n_best_size", + default=20, + type=int, + help="The total number of n-best predictions to generate in the nbest_predictions.json output file.", + ) + parser.add_argument( + "--max_answer_length", + default=30, + type=int, + help="The maximum length of an answer that can be generated. This is needed because the start " + "and end predictions are not conditioned on one another.", + ) + parser.add_argument( + "--verbose_logging", + action="store_true", + help="If true, all of the warnings related to data processing will be printed. 
" + "A number of warnings are expected for a normal SQuAD evaluation.", + ) + + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") + parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features") + parser.add_argument("--model_type", type=str, help="ori, ths, ext, thsext") + parser.add_argument("--data_type", type=str, help="fp32, fp16") + parser.add_argument('--module_path', type=str, default='./', + help='path containing the th_fastertransformer dynamic lib') + parser.add_argument('--ths_path', type=str, default='./lib/libths_fastertransformer.so', + help='path of the ths_fastertransformer dynamic lib file') + args = parser.parse_args() + + if args.doc_stride >= args.max_seq_length - args.max_query_length: + logger.warning( + "WARNING - You've set a doc stride which may be superior to the document length in some " + "examples. This could result in errors when building features from the examples. Please reduce the doc " + "stride or increase the maximum length to ensure the features are correctly built." + ) + + # Setup CUDA, GPU & distributed training + if args.local_rank == -1: + device = torch.device("cuda") + args.n_gpu = torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend="nccl") + args.n_gpu = 1 + args.device = device + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, + ) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s", + args.local_rank, + device, + args.n_gpu, + ) + + # Set seed + set_seed(args) + + # Load pretrained model and tokenizer + if args.local_rank not in [-1, 0]: + # Make sure only the first process in distributed training will download model & vocab + torch.distributed.barrier() + + config = BertConfig.from_pretrained( + args.config_name if args.config_name else args.model_name_or_path, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + tokenizer = BertTokenizer.from_pretrained( + args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, + do_lower_case=args.do_lower_case, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + + if args.local_rank == 0: + # Make sure only the first process in distributed training will download model & vocab + torch.distributed.barrier() + + logger.info("Parameters %s", args) + + # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory + results = {} + if args.do_eval and args.local_rank in [-1, 0]: + logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path) + checkpoints = [args.model_name_or_path] + + logger.info("Evaluate the following checkpoints: %s", checkpoints) + + for checkpoint in checkpoints: + # Reload the model + global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" + use_ths = args.model_type.startswith('ths') + model = BertForQuestionAnswering.from_pretrained(checkpoint, torchscript=use_ths) # , force_download=True) + 
model.to(args.device) + + if args.data_type == 'fp16': + logger.info("Use fp16") + model.half() + if args.model_type == 'ext': + logger.info("Use custom BERT encoder") + from utils.encoder import EncoderWeights, CustomEncoder + weights = EncoderWeights(model.config.num_hidden_layers, model.config.hidden_size, model.bert.encoder) + weights.to_cuda() + if args.data_type == 'fp16': + weights.to_half() + enc = CustomEncoder(model.config.num_hidden_layers, + model.config.num_attention_heads, + model.config.hidden_size//model.config.num_attention_heads, + weights, + os.path.abspath(args.module_path)) + model.replace_encoder(enc) + if args.model_type == 'thsext': + logger.info("Use custom BERT encoder for TorchScript") + from utils.encoder import EncoderWeights, CustomEncoder + weights = EncoderWeights(model.config.num_hidden_layers, model.config.hidden_size, model.bert.encoder) + weights.to_cuda() + if args.data_type == 'fp16': + weights.to_half() + enc = CustomEncoder(model.config.num_hidden_layers, + model.config.num_attention_heads, + model.config.hidden_size//model.config.num_attention_heads, + weights, + os.path.abspath(args.ths_path), True) + enc_ = torch.jit.script(enc) + model.replace_encoder(enc_) + if use_ths: + logger.info("Use TorchScript mode") + fake_input_id = torch.LongTensor(args.per_gpu_eval_batch_size, args.max_seq_length) + fake_input_id.fill_(1) + fake_input_id = fake_input_id.to(args.device) + fake_mask = torch.ones(args.per_gpu_eval_batch_size, args.max_seq_length).to(args.device) + fake_type_id = fake_input_id.clone().detach() + if args.data_type == 'fp16': + fake_mask = fake_mask.half() + model.eval() + model_ = torch.jit.trace(model, (fake_input_id, fake_mask, fake_type_id)) + model = model_ + + # Evaluate + result = evaluate(args, model, tokenizer, prefix=global_step) + + result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items()) + results.update(result) + + logger.info("Results: {}".format(results)) + + return results + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/sample/pytorch/run_translation.py b/sample/pytorch/run_translation.py new file mode 100644 index 000000000..689d3bb21 --- /dev/null +++ b/sample/pytorch/run_translation.py @@ -0,0 +1,80 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import argparse +import codecs +from onmt.translate import GNMTGlobalScorer +from utils.translation_model import load_test_model +from utils.translator import Translator + + +parser = argparse.ArgumentParser() +parser.add_argument("--batch_size", type=int, default=1, help="batch size") +parser.add_argument("--beam_size", type=int, default=4, help="beam size") +parser.add_argument("--model_type", type=str, help="ori, decoder_ext, decoding_ext, torch_decoding, torch_decoding_with_decoder_ext") +parser.add_argument("--data_type", type=str, help="fp32, fp16") +parser.add_argument('--model_path', type=str, default='./pytorch/translation/models/averaged-10-epoch.pt', + help='path for model checkpoint') +parser.add_argument('--module_path', type=str, default='./', + help='path containing the th_fastertransformer dynamic lib') +parser.add_argument('--ths', action='store_true', help='use custom TorchScript class (only for extensions)') +parser.add_argument('--ths_path', type=str, default='./lib/libths_fastertransformer.so', + help='path of the ths_fastertransformer dynamic lib file') +parser.add_argument('--input_file', type=str, default='./pytorch/translation/data/test.en', + help='input file path') +parser.add_argument('--output_file', type=str, default='', + help='output file path') +args = parser.parse_args() + +opt = argparse.Namespace(models=[args.model_path], + fp32=False, data_type='text', output='/dev/null', report_align=False, report_time=True, + random_sampling_topk=1, random_sampling_temp=1.0, seed=829, + beam_size=args.beam_size, min_length=0, max_length=100, + stepwise_penalty=False, length_penalty='none', ratio=-0.0, coverage_penalty='none', alpha=0.0, beta=-0.0, + block_ngram_repeat=0, ignore_when_blocking=[], replace_unk=False, phrase_table='', + verbose=True, dump_beam='', n_best=1, batch_type='sents', gpu=0) + + +fields, model, model_opt = load_test_model(opt, args) +scorer = GNMTGlobalScorer.from_opt(opt) +out_file = codecs.open(opt.output, 'w+', 'utf-8') +translator = Translator.from_opt( + model, + fields, + opt, + model_opt, + args, + global_scorer=scorer, + out_file=out_file, + report_align=opt.report_align, + report_score=False, + logger=None +) + + +res = [] +n = 1 +with open(args.input_file, 'r') as f: + lines = f.readlines() + lines = [line.strip() for line in lines] + translated = translator.translate(lines, batch_size=args.batch_size) + for i in range(len(translated[1])): + res.append(translated[1][i][0]) + +if args.output_file: + with open(args.output_file, 'w') as f: + for line in res: + f.write(line + '\n') diff --git a/sample/pytorch/scripts/download_translation_model.sh b/sample/pytorch/scripts/download_translation_model.sh new file mode 100644 index 000000000..e0b4d5b94 --- /dev/null +++ b/sample/pytorch/scripts/download_translation_model.sh @@ -0,0 +1,25 @@ +#! /bin/bash +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +MAIN_PATH=$PWD + +mkdir -p $MAIN_PATH/pytorch/translation/models/ + +cd $MAIN_PATH/pytorch/translation/models/ +if [ ! -f "sentencepiece.model" ] || [ ! -f "averaged-10-epoch.pt" ]; then + wget -c https://s3.amazonaws.com/opennmt-models/transformer-ende-wmt-pyOnmt.tar.gz + tar -xzvf transformer-ende-wmt-pyOnmt.tar.gz + rm transformer-ende-wmt-pyOnmt.tar.gz +fi diff --git a/sample/pytorch/scripts/profile_decoder_decoding.sh b/sample/pytorch/scripts/profile_decoder_decoding.sh new file mode 100644 index 000000000..1664af261 --- /dev/null +++ b/sample/pytorch/scripts/profile_decoder_decoding.sh @@ -0,0 +1,63 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# apt-get update +# apt-get install bc +pip install opennmt-py==1.1.1 + +for precision in fp32 fp16; +do + +if [ "$precision" = "fp16" ]; then + echo "Using fp16." + precision_num=1 +else + echo "Using fp32" + precision_num=0 +fi + +logdir="decoding-log-${precision}" +mkdir -p ${logdir} +all_log="${logdir}/all-log.log" +echo -e "| | PyTorch (ms) | Decoder (ms) | Decoding (ms) | Decoder Speedup | Decoding Speedup | " > $all_log +echo -e "|:-----------------------:|:------:|:------:|:------:|:---------:|:---------:| " >> $all_log + +for beam_size in 1 4 ; +do +for batch_size in 1 8 32 64 128 ; +do +for seq_len in 32 64 128 ; +do + ./bin/decoding_gemm ${batch_size} ${beam_size} 8 64 31538 ${seq_len} 512 ${precision_num} + tmp_log=${logdir}/beamsize-${beam_size}-batchsize-${batch_size}-seq-${seq_len}-${precision}-log.log + if [ "$precision" = "fp16" ]; then + python pytorch/decoding_sample.py ${batch_size} 6 ${seq_len} 8 64 ${beam_size} 31538 --fp16 --time 2>&1 | tee $tmp_log + else + python pytorch/decoding_sample.py ${batch_size} 6 ${seq_len} 8 64 ${beam_size} 31538 --time 2>&1 | tee $tmp_log + fi + pt_time=`tail -n 3 ${tmp_log} | head -n 1 | awk '{print $5}'` + decoder_time=`tail -n 2 ${tmp_log} | head -n 1 | awk '{print $7}'` + decoding_o_time=`tail -n 1 ${tmp_log} | awk '{print $5}'` + + speedup_decoder=$(echo "scale=2; $pt_time / $decoder_time" | bc) + speedup_decoding=$(echo "scale=2; $pt_time / $decoding_o_time" | bc) + echo ' ' | awk -v batch_size=$batch_size -v seq_len=$seq_len -v beam_size=$beam_size \ + -v pt_time=$pt_time -v decoder_time=$decoder_time \ + -v decoding_o_time=$decoding_o_time -v speedup_decoder=$speedup_decoder -v speedup_decoding=$speedup_decoding \ + '{print "| <" batch_size ", " seq_len ", " beam_size "> | " pt_time " | " \ + decoder_time " | " decoding_o_time " | " speedup_decoder " | " speedup_decoding " | " }' >> $all_log +done +done +done +done diff --git a/sample/pytorch/scripts/profile_encoder.sh b/sample/pytorch/scripts/profile_encoder.sh new file mode 100644 index 000000000..b61f29efc --- /dev/null +++ b/sample/pytorch/scripts/profile_encoder.sh @@ -0,0 +1,67 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# apt-get update +# apt-get install bc +pip install transformers==2.5.1 + +for precision in fp32 fp16; +do + +if [ "$precision" = "fp16" ]; then + echo "Using fp16." + precision_num=1 +else + echo "Using fp32" + precision_num=0 +fi + +logdir="bert-base-log-${precision}" +mkdir ${logdir} +all_log="${logdir}/all-log.log" +echo -e "| | PyTorch (ms) | TorchScript (ms) | CustomExt (ms) | Speedup (w/ PyTorch) | Speedup (w/ TorchScript) | " > $all_log +echo -e "|:---------------------:|:------:|:------:|:------:|:--------:|:--------:| " >> $all_log + +for batch_size in 1 8 32 64 128 ; +do +for seq_len in 32 64 128 ; +do + ./bin/encoder_gemm ${batch_size} ${seq_len} 12 64 ${precision_num} + + tmp_log_pt=${logdir}/batchsize-${batch_size}-seq-${seq_len}-${precision}-pt-log.log + if [ "$precision" = "fp16" ]; then + python pytorch/encoder_sample.py ${batch_size} 12 ${seq_len} 12 64 --fp16 --time 2>&1 | tee $tmp_log_pt + else + python pytorch/encoder_sample.py ${batch_size} 12 ${seq_len} 12 64 --time 2>&1 | tee $tmp_log_pt + fi + pt_time=`tail -n 2 ${tmp_log_pt} | head -n 1 | awk '{print $5}'` + ft_o_time=`tail -n 1 ${tmp_log_pt} | awk '{print $5}'` + + tmp_log_ths=${logdir}/batchsize-${batch_size}-seq-${seq_len}-${precision}-ths-log.log + if [ "$precision" = "fp16" ]; then + python pytorch/encoder_sample.py ${batch_size} 12 ${seq_len} 12 64 --fp16 --ths --time 2>&1 | tee $tmp_log_ths + else + python pytorch/encoder_sample.py ${batch_size} 12 ${seq_len} 12 64 --ths --time 2>&1 | tee $tmp_log_ths + fi + ths_time=`tail -n 2 ${tmp_log_ths} | head -n 1 | awk '{print $5}'` + + speedup_pt=$(echo "scale=2; $pt_time / $ft_o_time" | bc) + speedup_ths=$(echo "scale=2; $ths_time / $ft_o_time" | bc) + echo ' ' | awk -v batch_size=$batch_size -v seq_len=$seq_len -v pt_time=$pt_time -v ths_time=$ths_time \ + -v ft_o_time=$ft_o_time -v speedup_pt=$speedup_pt -v speedup_ths=$speedup_ths \ + '{print "| <" batch_size ", " seq_len "> | " pt_time " | " \ + ths_time " | " ft_o_time " | " speedup_pt " | " speedup_ths " | " }' >> $all_log +done +done +done diff --git a/sample/pytorch/scripts/run_mrpc.sh b/sample/pytorch/scripts/run_mrpc.sh new file mode 100644 index 000000000..039003138 --- /dev/null +++ b/sample/pytorch/scripts/run_mrpc.sh @@ -0,0 +1,76 @@ +#! /bin/bash + +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +if [ "$1" != "ori" ] && [ "$1" != "ths" ] && [ "$1" != "ext" ] && [ "$1" != "thsext" ]; then + echo "wrong model type" + echo "[Usage]: bash PATH_TO_THIS_SCRIPT model_type[ori, ths, ext, thsext] data_type[fp32, fp16]" + exit 1 +fi +if [ "$2" != "fp32" ] && [ "$2" != "fp16" ]; then + echo "wrong data type" + echo "[Usage]: bash PATH_TO_THIS_SCRIPT model_type[ori, ext] data_type[fp32, fp16]" + exit 1 +fi + +batch_size=8 +seq_len=128 + +MAIN_PATH=$PWD + +mkdir -p $MAIN_PATH/pytorch/bert_mrpc/models/bert-base-cased-finetuned-mrpc +mkdir -p $MAIN_PATH/pytorch/bert_mrpc/data +mkdir -p $MAIN_PATH/pytorch/bert_mrpc/output + +cd $MAIN_PATH/pytorch/bert_mrpc/data +if [ ! -f "dev.tsv" ]; then + python $MAIN_PATH/pytorch/utils/get_mrpc_data.py --data_dir $MAIN_PATH/pytorch/bert_mrpc/data +fi + +cd $MAIN_PATH/pytorch/bert_mrpc/models/bert-base-cased-finetuned-mrpc +if [ ! -f "config.json" ]; then + wget -c https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json + mv bert-base-cased-finetuned-mrpc-config.json config.json +fi +if [ ! -f "pytorch_model.bin" ]; then + wget -c https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin + mv bert-base-cased-finetuned-mrpc-pytorch_model.bin pytorch_model.bin +fi +if [ ! -f "vocab.txt" ]; then + wget -c https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt + mv bert-base-cased-finetuned-mrpc-vocab.txt vocab.txt +fi +cd $MAIN_PATH + +if [ "$1" == "ext" ] || [ "$1" == "thsext" ]; then + if [ "$2" == "fp32" ]; then + $MAIN_PATH/bin/encoder_gemm ${batch_size} ${seq_len} 12 64 0 + else + $MAIN_PATH/bin/encoder_gemm ${batch_size} ${seq_len} 12 64 1 + fi +fi + +python $MAIN_PATH/pytorch/run_glue.py \ + --model_name_or_path $MAIN_PATH/pytorch/bert_mrpc/models/bert-base-cased-finetuned-mrpc \ + --task_name MRPC \ + --do_eval \ + --do_lower_case \ + --data_dir $MAIN_PATH/pytorch/bert_mrpc/data \ + --output_dir $MAIN_PATH/pytorch/bert_mrpc/output/ \ + --cache_dir $MAIN_PATH/pytorch/bert_mrpc/models/ \ + --max_seq_length ${seq_len} \ + --per_gpu_eval_batch_size ${batch_size} \ + --model_type $1 \ + --data_type $2 \ diff --git a/sample/pytorch/scripts/run_squad.sh b/sample/pytorch/scripts/run_squad.sh new file mode 100644 index 000000000..f4860d4e9 --- /dev/null +++ b/sample/pytorch/scripts/run_squad.sh @@ -0,0 +1,75 @@ +#! /bin/bash +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+if [ "$1" != "ori" ] && [ "$1" != "ths" ] && [ "$1" != "ext" ] && [ "$1" != "thsext" ]; then
+    echo "wrong model type"
+    echo "[Usage]: bash PATH_TO_THIS_SCRIPT model_type[ori, ths, ext, thsext] data_type[fp32, fp16]"
+    exit 1
+fi
+if [ "$2" != "fp32" ] && [ "$2" != "fp16" ]; then
+    echo "wrong data type"
+    echo "[Usage]: bash PATH_TO_THIS_SCRIPT model_type[ori, ths, ext, thsext] data_type[fp32, fp16]"
+    exit 1
+fi
+
+batch_size=8
+seq_len=384
+
+MAIN_PATH=$PWD
+
+mkdir -p $MAIN_PATH/pytorch/bert_squad/models/bert-large-uncased-whole-word-masking-finetuned-squad
+mkdir -p $MAIN_PATH/pytorch/bert_squad/squad_data
+mkdir -p $MAIN_PATH/pytorch/bert_squad/output
+
+cd $MAIN_PATH/pytorch/bert_squad/squad_data
+# wget -c https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
+wget -c https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
+
+cd $MAIN_PATH/pytorch/bert_squad/models/bert-large-uncased-whole-word-masking-finetuned-squad
+if [ ! -f "config.json" ]; then
+    wget -c https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json
+    mv bert-large-uncased-whole-word-masking-finetuned-squad-config.json config.json
+fi
+if [ ! -f "pytorch_model.bin" ]; then
+    wget -c https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin
+    mv bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin pytorch_model.bin
+fi
+if [ ! -f "vocab.txt" ]; then
+    wget -c https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt
+    mv bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt vocab.txt
+fi
+cd $MAIN_PATH
+
+if [ "$1" == "ext" ] || [ "$1" == "thsext" ]; then
+    if [ "$2" == "fp32" ]; then
+        $MAIN_PATH/bin/encoder_gemm ${batch_size} ${seq_len} 16 64 0
+    else
+        $MAIN_PATH/bin/encoder_gemm ${batch_size} ${seq_len} 16 64 1
+    fi
+fi
+
+python $MAIN_PATH/pytorch/run_squad.py \
+    --model_name_or_path $MAIN_PATH/pytorch/bert_squad/models/bert-large-uncased-whole-word-masking-finetuned-squad \
+    --do_eval \
+    --do_lower_case \
+    --predict_file $MAIN_PATH/pytorch/bert_squad/squad_data/dev-v1.1.json \
+    --output_dir $MAIN_PATH/pytorch/bert_squad/output/ \
+    --cache_dir $MAIN_PATH/pytorch/bert_squad/models/ \
+    --max_seq_length ${seq_len} \
+    --doc_stride 128 \
+    --max_query_length 64 \
+    --per_gpu_eval_batch_size ${batch_size} \
+    --model_type $1 \
+    --data_type $2
diff --git a/sample/pytorch/utils/__init__.py b/sample/pytorch/utils/__init__.py
new file mode 100644
index 000000000..9e3250071
--- /dev/null
+++ b/sample/pytorch/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/sample/pytorch/utils/decoder.py b/sample/pytorch/utils/decoder.py
new file mode 100644
index 000000000..c94b55a09
--- /dev/null
+++ b/sample/pytorch/utils/decoder.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import sys +import torch + +from onmt.decoders.transformer import TransformerDecoderLayer + + +class DecoderWeights(object): + def __init__(self, layer_num, hidden_dim): + self.layer_num = layer_num + self.w = [[] for _ in range(layer_num)] + for layer_weights in self.w: + layer_weights.append(torch.zeros(hidden_dim)) # self_layernorm_gamma + layer_weights.append(torch.zeros(hidden_dim)) # self_layernorm_beta + layer_weights.append(torch.zeros(hidden_dim, hidden_dim)) # self_kernel_q + layer_weights.append(torch.zeros(hidden_dim, hidden_dim)) # self_kernel_k + layer_weights.append(torch.zeros(hidden_dim, hidden_dim)) # self_kernel_v + layer_weights.append(torch.zeros(hidden_dim)) # self_bias_q + layer_weights.append(torch.zeros(hidden_dim)) # self_bias_k + layer_weights.append(torch.zeros(hidden_dim)) # self_bias_v + layer_weights.append(torch.zeros(hidden_dim, hidden_dim)) # self_output_kernel + layer_weights.append(torch.zeros(hidden_dim)) # self_output_bias + layer_weights.append(torch.zeros(hidden_dim)) # cross_layernorm_gamma + layer_weights.append(torch.zeros(hidden_dim)) # cross_layernorm_beta + layer_weights.append(torch.zeros(hidden_dim, hidden_dim)) # cross_kernel_q + layer_weights.append(torch.zeros(hidden_dim, hidden_dim)) # cross_kernel_k + layer_weights.append(torch.zeros(hidden_dim, hidden_dim)) # cross_kernel_v + layer_weights.append(torch.zeros(hidden_dim)) # cross_bias_q + layer_weights.append(torch.zeros(hidden_dim)) # cross_bias_k + layer_weights.append(torch.zeros(hidden_dim)) # cross_bias_v + layer_weights.append(torch.zeros(hidden_dim, hidden_dim)) # cross_output_kernel + layer_weights.append(torch.zeros(hidden_dim)) # cross_output_bias + layer_weights.append(torch.zeros(hidden_dim)) # ffn_layernorm_gamma + layer_weights.append(torch.zeros(hidden_dim)) # ffn_layernorm_beta + layer_weights.append(torch.zeros(hidden_dim, 4 * hidden_dim)) # inter_kernel + layer_weights.append(torch.zeros(4 * hidden_dim)) # inter_bias + layer_weights.append(torch.zeros(4 * hidden_dim, hidden_dim)) # output_kernel + layer_weights.append(torch.zeros(hidden_dim)) # output_bias + for i in range(len(layer_weights)): + torch.nn.init.uniform_(layer_weights[i], -1, 1) + + def to_cuda(self): + for i in range(self.layer_num): + for j in range(len(self.w[i])): + self.w[i][j] = self.w[i][j].cuda() + + def to_half(self): + for i in range(self.layer_num): + for j in range(len(self.w[i])): + self.w[i][j] = self.w[i][j].half() + + +def init_op_cache(layer_num, batch_size, beam_width, max_seq_len, hidden_dim, is_fp16): + if is_fp16: + self_cache = torch.zeros(layer_num, 2, 0, batch_size * beam_width, hidden_dim, dtype=torch.half).cuda() + mem_cache = torch.zeros(layer_num, 2, batch_size * beam_width, max_seq_len, hidden_dim, dtype=torch.half).cuda() + else: + self_cache = torch.zeros(layer_num, 2, 0, batch_size * beam_width, hidden_dim).cuda() + mem_cache = torch.zeros(layer_num, 2, batch_size * 
beam_width, max_seq_len, hidden_dim).cuda() + return self_cache, mem_cache + +def init_onmt_cache(layer_num, memory_bank): + cache = {} + for i in range(layer_num): + layer_cache = {"memory_keys": None, "memory_values": None} + layer_cache["self_keys"] = None + layer_cache["self_values"] = None + cache[i] = layer_cache + return cache + + +class CustomDecoder(torch.nn.Module): + def __init__(self, layer_num, head_num, head_size, weights, is_fp16, path='./', use_ths=False): + super().__init__() + self.layer_num = layer_num + self.hidden_dim = head_num * head_size + self.fp16 = is_fp16 + self.decoders = [] + if use_ths: + torch.classes.load_library(path) + for i in range(layer_num): + self.decoders.append(torch.classes.FasterTransformerDecoder(head_num, head_size, *weights.w[i])) + else: + sys.path.insert(0, path) + from th_fastertransformer import FasterTransformerDecoder + for i in range(layer_num): + self.decoders.append(FasterTransformerDecoder(head_num, head_size, *weights.w[i])) + + def forward(self, inputs, memory, memory_seq_lens, self_cache, mem_cache): + if self.fp16: + self_cache_tmp = torch.zeros(self.layer_num, 2, 1, self_cache.size(3), self.hidden_dim, dtype=torch.half).cuda() + else: + self_cache_tmp = torch.zeros(self.layer_num, 2, 1, self_cache.size(3), self.hidden_dim).cuda() + self_cache = torch.cat([self_cache, self_cache_tmp], 2) + output = inputs + for i in range(self.layer_num): + output = self.decoders[i].forward(output, memory, memory_seq_lens, self_cache[i], mem_cache[i]) + return output, self_cache, mem_cache + + +class ONMTDecoder(torch.nn.Module): + def __init__(self, layer_num, head_num, head_size, weights): + super().__init__() + self.layer_num = layer_num + self.hidden_dim = head_num * head_size + self.decoders = torch.nn.ModuleList() + for i in range(layer_num): + self.decoders.append(TransformerDecoderLayer(self.hidden_dim, head_num, 4 * self.hidden_dim, 0, 0)) + for i in range(layer_num): + self.decoders[i].layer_norm_1.weight.data = weights.w[i][0] + self.decoders[i].layer_norm_1.bias.data = weights.w[i][1] + self.decoders[i].self_attn.linear_query.weight.data = weights.w[i][2].transpose(-1, -2).contiguous() + self.decoders[i].self_attn.linear_keys.weight.data = weights.w[i][3].transpose(-1, -2).contiguous() + self.decoders[i].self_attn.linear_values.weight.data = weights.w[i][4].transpose(-1, -2).contiguous() + self.decoders[i].self_attn.linear_query.bias.data = weights.w[i][5] + self.decoders[i].self_attn.linear_keys.bias.data = weights.w[i][6] + self.decoders[i].self_attn.linear_values.bias.data = weights.w[i][7] + self.decoders[i].self_attn.final_linear.weight.data = weights.w[i][8].transpose(-1, -2).contiguous() + self.decoders[i].self_attn.final_linear.bias.data = weights.w[i][9] + self.decoders[i].layer_norm_2.weight.data = weights.w[i][10] + self.decoders[i].layer_norm_2.bias.data = weights.w[i][11] + self.decoders[i].context_attn.linear_query.weight.data = weights.w[i][12].transpose(-1, -2).contiguous() + self.decoders[i].context_attn.linear_keys.weight.data = weights.w[i][13].transpose(-1, -2).contiguous() + self.decoders[i].context_attn.linear_values.weight.data = weights.w[i][14].transpose(-1, -2).contiguous() + self.decoders[i].context_attn.linear_query.bias.data = weights.w[i][15] + self.decoders[i].context_attn.linear_keys.bias.data = weights.w[i][16] + self.decoders[i].context_attn.linear_values.bias.data = weights.w[i][17] + self.decoders[i].context_attn.final_linear.weight.data = weights.w[i][18].transpose(-1, -2).contiguous() + 
self.decoders[i].context_attn.final_linear.bias.data = weights.w[i][19] + self.decoders[i].feed_forward.layer_norm.weight.data = weights.w[i][20] + self.decoders[i].feed_forward.layer_norm.bias.data = weights.w[i][21] + self.decoders[i].feed_forward.w_1.weight.data = weights.w[i][22].transpose(-1, -2).contiguous() + self.decoders[i].feed_forward.w_1.bias.data = weights.w[i][23] + self.decoders[i].feed_forward.w_2.weight.data = weights.w[i][24].transpose(-1, -2).contiguous() + self.decoders[i].feed_forward.w_2.bias.data = weights.w[i][25] + + def forward(self, inputs, memory, src_pad_msk, cache, step): + output = inputs + for i in range(self.layer_num): + output, _, _ = self.decoders[i](output, memory, src_pad_msk, None, cache[i], step) + return output diff --git a/sample/pytorch/utils/decoding.py b/sample/pytorch/utils/decoding.py new file mode 100644 index 000000000..bdbd31dc7 --- /dev/null +++ b/sample/pytorch/utils/decoding.py @@ -0,0 +1,626 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import sys +import os +import math +import torch +import torch.nn as nn + +from onmt.modules import Embeddings, AverageAttention +from onmt.decoders.decoder import DecoderBase +from onmt.decoders.transformer import TransformerDecoderLayer +from onmt.utils.misc import tile, sequence_mask + + +class DecodingWeights(object): + def __init__(self, layer_num, hidden_dim, vocab_size, onmtcheckpoint=None, max_step_for_pe=2048): + self.hidden_dim = hidden_dim + self.max_step_for_pe = max_step_for_pe + self.w = [] + if onmtcheckpoint: + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.layer_norm_1.weight'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.layer_norm_1.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.self_attn.linear_query.weight'].transpose(-1, -2) for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.self_attn.linear_keys.weight'].transpose(-1, -2) for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.self_attn.linear_values.weight'].transpose(-1, -2) for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.self_attn.linear_query.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.self_attn.linear_keys.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' 
+ str(i) + '.self_attn.linear_values.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.self_attn.final_linear.weight'].transpose(-1, -2) for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.self_attn.final_linear.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.layer_norm_2.weight'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.layer_norm_2.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.context_attn.linear_query.weight'].transpose(-1, -2) for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.context_attn.linear_keys.weight'].transpose(-1, -2) for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.context_attn.linear_values.weight'].transpose(-1, -2) for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.context_attn.linear_query.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.context_attn.linear_keys.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.context_attn.linear_values.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.context_attn.final_linear.weight'].transpose(-1, -2) for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.context_attn.final_linear.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.feed_forward.layer_norm.weight'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.feed_forward.layer_norm.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.feed_forward.w_1.weight'].transpose(-1, -2) for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.feed_forward.w_1.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' + str(i) + '.feed_forward.w_2.weight'].transpose(-1, -2) for i in range(layer_num)], + 0).contiguous()) + self.w.append(torch.stack( + [onmtcheckpoint['model']['decoder.transformer_layers.' 
+ str(i) + '.feed_forward.w_2.bias'] for i in range(layer_num)], + 0).contiguous()) + self.w.append(onmtcheckpoint['model']['decoder.layer_norm.weight']) + self.w.append(onmtcheckpoint['model']['decoder.layer_norm.bias']) + self.w.append(onmtcheckpoint['model']['decoder.embeddings.make_embedding.emb_luts.0.weight']) + self.w.append(self._get_position_encoding()) # pe_encoding + self.w.append(onmtcheckpoint['generator']['0.weight'].transpose(-1, -2).contiguous()) + self.w.append(onmtcheckpoint['generator']['0.bias']) + else: + self.w.append(torch.zeros(layer_num, hidden_dim)) # self_layernorm_gamma + self.w.append(torch.zeros(layer_num, hidden_dim)) # self_layernorm_beta + self.w.append(torch.zeros(layer_num, hidden_dim, hidden_dim)) # self_kernel_q + self.w.append(torch.zeros(layer_num, hidden_dim, hidden_dim)) # self_kernel_k + self.w.append(torch.zeros(layer_num, hidden_dim, hidden_dim)) # self_kernel_v + self.w.append(torch.zeros(layer_num, hidden_dim)) # self_bias_q + self.w.append(torch.zeros(layer_num, hidden_dim)) # self_bias_k + self.w.append(torch.zeros(layer_num, hidden_dim)) # self_bias_v + self.w.append(torch.zeros(layer_num, hidden_dim, hidden_dim)) # self_output_kernel + self.w.append(torch.zeros(layer_num, hidden_dim)) # self_output_bias + self.w.append(torch.zeros(layer_num, hidden_dim)) # cross_layernorm_gamma + self.w.append(torch.zeros(layer_num, hidden_dim)) # cross_layernorm_beta + self.w.append(torch.zeros(layer_num, hidden_dim, hidden_dim)) # cross_kernel_q + self.w.append(torch.zeros(layer_num, hidden_dim, hidden_dim)) # cross_kernel_k + self.w.append(torch.zeros(layer_num, hidden_dim, hidden_dim)) # cross_kernel_v + self.w.append(torch.zeros(layer_num, hidden_dim)) # cross_bias_q + self.w.append(torch.zeros(layer_num, hidden_dim)) # cross_bias_k + self.w.append(torch.zeros(layer_num, hidden_dim)) # cross_bias_v + self.w.append(torch.zeros(layer_num, hidden_dim, hidden_dim)) # cross_output_kernel + self.w.append(torch.zeros(layer_num, hidden_dim)) # cross_output_bias + self.w.append(torch.zeros(layer_num, hidden_dim)) # ffn_layernorm_gamma + self.w.append(torch.zeros(layer_num, hidden_dim)) # ffn_layernorm_beta + self.w.append(torch.zeros(layer_num, hidden_dim, 4 * hidden_dim)) # inter_kernel + self.w.append(torch.zeros(layer_num, 4 * hidden_dim)) # inter_bias + self.w.append(torch.zeros(layer_num, 4 * hidden_dim, hidden_dim)) # output_kernel + self.w.append(torch.zeros(layer_num, hidden_dim)) # output_bias + self.w.append(torch.zeros(hidden_dim)) # decoding_gamma + self.w.append(torch.zeros(hidden_dim)) # decoding_beta + self.w.append(torch.zeros(vocab_size, hidden_dim)) # embedding_table + self.w.append(self._get_position_encoding()) # pe_encoding + self.w.append(torch.zeros(hidden_dim, vocab_size)) # embedding_kernel + self.w.append(torch.zeros(vocab_size)) # embedding_bias + for i in range(len(self.w)): + torch.nn.init.uniform_(self.w[i], -1, 1) + + def to_cuda(self): + for i in range(len(self.w)): + self.w[i] = self.w[i].cuda() + + def to_half(self): + for i in range(len(self.w) - 1): # embedding_bias is float32 + self.w[i] = self.w[i].half() + + def _get_position_encoding(self): + pe = torch.zeros(self.max_step_for_pe, self.hidden_dim) + position = torch.arange(0, self.max_step_for_pe).unsqueeze(1) + div_term = torch.exp((torch.arange(0, self.hidden_dim, 2, dtype=torch.float) * + -(math.log(10000.0) / self.hidden_dim))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + return pe + + +def 
gather_nd(params, indices): + indices = indices.t().long() + ndim = indices.size(0) + idx = torch.zeros_like(indices[0]).long() + m = 1 + + for i in range(ndim)[::-1]: + idx += indices[i] * m + m *= params.size(i) + + params = params.reshape((-1, *tuple(torch.tensor(params.size()[ndim:])))) + return params[idx] + + +def gather_tree(step_ids, parent_ids, max_sequence_lengths, end_token): + beams = torch.empty_like(step_ids) + beams.fill_(end_token) + max_len = step_ids.size(0) + batch_size = step_ids.size(1) + beam_size = step_ids.size(-1) + batch_beam = batch_size * beam_size + for i in range(batch_beam): + batch = i // beam_size + beam = i % beam_size + max_seq_len_b = min(max_len, max_sequence_lengths[batch]) + if max_seq_len_b <= 0: + continue + beams[max_seq_len_b - 1, batch, beam] = step_ids[max_seq_len_b - 1, batch, beam] + parent = parent_ids[max_seq_len_b - 1, batch, beam] + for level in range(max_seq_len_b - 2, -1, -1): + if parent < 0 or parent > beam_size: + raise ValueError("wrong parent id") + beams[level, batch, beam] = step_ids[level, batch, parent] + parent = parent_ids[level, batch, parent] + finished = False + for time in range(max_seq_len_b): + if finished: + beams[time, batch, beam] = end_token + elif beams[time, batch, beam] == end_token: + finished = True + return beams + + +def finalize(beam_size, output_ids, parent_ids, out_seq_lens, end_id, max_seq_len=None, args=None): + out_seq_lens = torch.reshape(out_seq_lens, (-1, beam_size)) + max_lens = torch.max(out_seq_lens, 1)[0] + if max_seq_len: + shape = (max_seq_len, -1, beam_size) + else: + shape = (torch.max(max_lens), -1, beam_size) + output_ids = torch.reshape(output_ids, shape) + parent_ids = torch.reshape(parent_ids, shape) + if output_ids.is_cuda: + if args.ths: + torch.classes.load_library(args.ths_path) + ids = torch.ops.fastertransformer.gather_tree(output_ids.to(torch.int32), parent_ids.to(torch.int32), max_lens.to(torch.int32), end_id) + else: + sys.path.insert(0, os.path.abspath(args.module_path)) + from th_fastertransformer import gather_tree as gather_tree_cuda + ids = gather_tree_cuda(output_ids.to(torch.int32), parent_ids.to(torch.int32), max_lens.to(torch.int32), end_id) + else: + ids = gather_tree(output_ids, parent_ids, max_lens, end_id) + ids = torch.einsum('ijk->jki', ids) # batch_size, beam_size, max_seq_len + lengths = torch.eq(ids, end_id) + lengths = 1 - lengths.to(output_ids.dtype) + lengths = torch.sum(lengths, -1) + return ids, lengths + + +class FTDecoderLayer(nn.Module): + def __init__(self, head_num, head_size, weights, args): + super().__init__() + self.args = args + if args.ths: + torch.classes.load_library(args.ths_path) + self.dec_layer = torch.classes.FasterTransformerDecoder(head_num, head_size, *weights) + else: + sys.path.insert(0, os.path.abspath(args.module_path)) + from th_fastertransformer import FasterTransformerDecoder + self.dec_layer = FasterTransformerDecoder(head_num, head_size, *weights) + + def forward(self, inputs, memory, memory_seq_lens, self_cache, mem_cache): + if self.args.data_type == 'fp16': + self_cache_tmp = torch.zeros(2, 1, self_cache.size(2), self_cache.size(3), dtype=torch.half).cuda() + else: + self_cache_tmp = torch.zeros(2, 1, self_cache.size(2), self_cache.size(3)).cuda() + self_cache = torch.cat([self_cache, self_cache_tmp], 1) + output = self.dec_layer.forward(inputs, memory, memory_seq_lens, self_cache, mem_cache) + return output, self_cache, mem_cache + + +class TransformerDecoder(DecoderBase): + """The Transformer decoder from "Attention is All 
You Need". + Args: + num_layers (int): number of encoder layers. + d_model (int): size of the model + heads (int): number of heads + d_ff (int): size of the inner FF layer + copy_attn (bool): if using a separate copy attention + self_attn_type (str): type of self-attention scaled-dot, average + dropout (float): dropout in residual, self-attn(dot) and feed-forward + attention_dropout (float): dropout in context_attn (and self-attn(avg)) + embeddings (onmt.modules.Embeddings): + embeddings to use, should have positional encodings + max_relative_positions (int): + Max distance between inputs in relative positions representations + aan_useffn (bool): Turn on the FFN layer in the AAN decoder + full_context_alignment (bool): + whether enable an extra full context decoder forward for alignment + alignment_layer (int): N° Layer to supervise with for alignment guiding + alignment_heads (int): + N. of cross attention heads to use for alignment guiding + """ + + def __init__(self, num_layers, d_model, heads, d_ff, + copy_attn, self_attn_type, dropout, attention_dropout, + embeddings, max_relative_positions, aan_useffn, + full_context_alignment, alignment_layer, + alignment_heads, args): + super(TransformerDecoder, self).__init__() + + self.args = args + self.embeddings = embeddings + + # Decoder State + self.state = {} + + self.transformer_layers = nn.ModuleList( + [TransformerDecoderLayer(d_model, heads, d_ff, dropout, + attention_dropout, self_attn_type=self_attn_type, + max_relative_positions=max_relative_positions, + aan_useffn=aan_useffn, + full_context_alignment=full_context_alignment, + alignment_heads=alignment_heads) + for i in range(num_layers)]) + + # previously, there was a GlobalAttention module here for copy + # attention. But it was never actually used -- the "copy" attention + # just reuses the context attention. 
+ self._copy = copy_attn + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + self.alignment_layer = alignment_layer + + @classmethod + def from_opt(cls, opt, embeddings, args): + """Alternate constructor.""" + return cls( + opt.dec_layers, + opt.dec_rnn_size, + opt.heads, + opt.transformer_ff, + opt.copy_attn, + opt.self_attn_type, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + opt.attention_dropout[0] if type(opt.attention_dropout) + is list else opt.dropout, + embeddings, + opt.max_relative_positions, + opt.aan_useffn, + opt.full_context_alignment, + opt.alignment_layer, + alignment_heads=opt.alignment_heads, + args=args) + + def init_state(self, src, memory_bank, enc_hidden): + """Initialize decoder state.""" + self.state["src"] = src + self.state["cache"] = None + + def map_state(self, fn): + def _recursive_map(struct, batch_dim=0): + for k, v in struct.items(): + if v is not None: + if isinstance(v, dict): + _recursive_map(v, batch_dim) + else: + struct[k] = fn(v, batch_dim) + + self.state["src"] = fn(self.state["src"], 1) + if self.args.model_type == 'ori' or self.args.model_type == 'torch_decoding': + if self.state["cache"] is not None: + _recursive_map(self.state["cache"]) + if self.args.model_type == 'decoder_ext' or self.args.model_type == 'torch_decoding_with_decoder_ext': + if self.state["cache"] is not None: + _recursive_map(self.state["cache"], 2) + + def detach_state(self): + self.state["src"] = self.state["src"].detach() + + def forward(self, tgt, memory_bank, step=None, **kwargs): + """Decode, possibly stepwise.""" + if step == 0: + self._init_cache(memory_bank) + + tgt_words = tgt[:, :, 0].transpose(0, 1) + + emb = self.embeddings(tgt, step=step) + assert emb.dim() == 3 # len x batch x embedding_dim + + output = emb.transpose(0, 1).contiguous() + src_memory_bank = memory_bank.transpose(0, 1).contiguous() + + pad_idx = self.embeddings.word_padding_idx + src_lens = kwargs["memory_lengths"] + + if self.args.model_type == 'ori' or self.args.model_type == 'torch_decoding': + src_max_len = self.state["src"].shape[0] + src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1) + tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt] + + with_align = kwargs.pop('with_align', False) + attn_aligns = [] + + for i, layer in enumerate(self.transformer_layers): + layer_cache = self.state["cache"]["layer_{}".format(i)] \ + if step is not None else None + output, attn, attn_align = layer( + output, + src_memory_bank, + src_pad_mask, + tgt_pad_mask, + layer_cache=layer_cache, + step=step, + with_align=with_align) + if attn_align is not None: + attn_aligns.append(attn_align) + elif self.args.model_type == 'decoder_ext' or self.args.model_type == 'torch_decoding_with_decoder_ext': + src_lens_ = src_lens.to(torch.int) + for i, layer in enumerate(self.transformer_layers): + layer_cache = self.state["cache"]["layer_{}".format(i)] + output, self_cache_, mem_cache_ = layer(output, src_memory_bank, src_lens_, layer_cache['self'], layer_cache['mem']) + layer_cache['self'] = self_cache_ + layer_cache['mem'] = mem_cache_ + + output = self.layer_norm(output) + dec_outs = output.transpose(0, 1).contiguous() + attns = {} + # attn = attn.transpose(0, 1).contiguous() + + # attns = {"std": attn} + # if self._copy: + # attns["copy"] = attn + # if with_align: + # attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)` + # # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg + + # TODO change the way attns is returned dict => list or tuple 
(onnx) + return dec_outs, attns + + def _init_cache(self, memory_bank): + self.state["cache"] = {} + batch_size = memory_bank.size(1) + depth = memory_bank.size(-1) + + if self.args.model_type == 'ori' or self.args.model_type == 'torch_decoding': + for i, layer in enumerate(self.transformer_layers): + layer_cache = {"memory_keys": None, "memory_values": None} + if isinstance(layer.self_attn, AverageAttention): + layer_cache["prev_g"] = torch.zeros((batch_size, 1, depth), + device=memory_bank.device) + else: + layer_cache["self_keys"] = None + layer_cache["self_values"] = None + self.state["cache"]["layer_{}".format(i)] = layer_cache + elif self.args.model_type == 'decoder_ext' or self.args.model_type == 'torch_decoding_with_decoder_ext': + max_seq_len = memory_bank.size(0) + for i in range(len(self.transformer_layers)): + layer_cache = {} + if self.args.data_type == 'fp16': + layer_cache['self'] = torch.zeros(2, 0, batch_size, depth, dtype=torch.half).cuda() + layer_cache['mem'] = torch.zeros(1, 2, batch_size, max_seq_len, depth, dtype=torch.half).cuda() + else: + layer_cache['self'] = torch.zeros(2, 0, batch_size, depth).cuda() + layer_cache['mem'] = torch.zeros(1, 2, batch_size, max_seq_len, depth).cuda() + self.state["cache"]["layer_{}".format(i)] = layer_cache + + def update_dropout(self, dropout, attention_dropout): + self.embeddings.update_dropout(dropout) + for layer in self.transformer_layers: + layer.update_dropout(dropout, attention_dropout) + + +class CustomDecoding(nn.Module): + def __init__(self, layer_num, head_num, head_size, vocab_size, start_id, end_id, weights, beam_search_diversity_rate=0.0, args=None): + super().__init__() + hidden_dim = head_num * head_size + self.end_id = end_id + self.args = args + if args.ths: + torch.classes.load_library(os.path.abspath(args.ths_path)) + self.decoding = torch.classes.FasterTransformerDecoding(head_num, head_size, hidden_dim, layer_num, vocab_size, start_id, end_id, beam_search_diversity_rate, *weights.w) + else: + sys.path.insert(0, os.path.abspath(args.module_path)) + from th_fastertransformer import FasterTransformerDecoding + self.decoding = FasterTransformerDecoding(head_num, head_size, hidden_dim, layer_num, vocab_size, start_id, end_id, beam_search_diversity_rate, *weights.w) + + def forward(self, batch_size, beam_size, max_seq_len, memory, memory_seq_lens): + extended_memory = tile(memory, beam_size) + extended_memory_seq_lens = tile(memory_seq_lens, beam_size) + output_ids, parent_ids, out_seq_lens = self.decoding.forward(batch_size, beam_size, max_seq_len, extended_memory, extended_memory_seq_lens) + parent_ids = parent_ids % beam_size + beams, lengths = finalize(beam_size, output_ids, parent_ids, out_seq_lens, self.end_id, max_seq_len, args=self.args) + return beams, lengths + + +class ArgHelper(object): + def __init__(self, model_type=None, data_type=None, module_path=None, ths=False, ths_path=None): + self.model_type = model_type + self.data_type = data_type + self.module_path = module_path + self.ths = ths + self.ths_path = ths_path + + +class TorchDecoding(nn.Module): + def __init__(self, layer_num, head_num, head_size, vocab_size, start_id, end_id, weights, + beam_search_diversity_rate=0.0, args=None): + super().__init__() + self.layer_num = layer_num + self.hidden_dim = head_num * head_size + self.start_id = start_id + self.end_id = end_id + self.vocab_size = vocab_size + self.diversity_rate = beam_search_diversity_rate + self.args = args + emb = Embeddings(self.hidden_dim, vocab_size, 1, position_encoding=True) + 
self.decoder = TransformerDecoder(layer_num, self.hidden_dim, head_num, 4*self.hidden_dim, + False, 'scaled-dot', 0, 0, emb, 0, False, False, -3, 0, args) + self.generator = nn.Linear(self.hidden_dim, vocab_size) + self.logsoftmax = nn.LogSoftmax(dim=-1) + self.module_path = args.module_path + if args.model_type == 'torch_decoding': + for i in range(layer_num): + self.decoder.transformer_layers[i].layer_norm_1.weight.data = weights.w[0][i] + self.decoder.transformer_layers[i].layer_norm_1.bias.data = weights.w[1][i] + self.decoder.transformer_layers[i].self_attn.linear_query.weight.data = weights.w[2][i].transpose(-1, -2).contiguous() + self.decoder.transformer_layers[i].self_attn.linear_keys.weight.data = weights.w[3][i].transpose(-1, -2).contiguous() + self.decoder.transformer_layers[i].self_attn.linear_values.weight.data = weights.w[4][i].transpose(-1, -2).contiguous() + self.decoder.transformer_layers[i].self_attn.linear_query.bias.data = weights.w[5][i] + self.decoder.transformer_layers[i].self_attn.linear_keys.bias.data = weights.w[6][i] + self.decoder.transformer_layers[i].self_attn.linear_values.bias.data = weights.w[7][i] + self.decoder.transformer_layers[i].self_attn.final_linear.weight.data = weights.w[8][i].transpose(-1, -2).contiguous() + self.decoder.transformer_layers[i].self_attn.final_linear.bias.data = weights.w[9][i] + self.decoder.transformer_layers[i].layer_norm_2.weight.data = weights.w[10][i] + self.decoder.transformer_layers[i].layer_norm_2.bias.data = weights.w[11][i] + self.decoder.transformer_layers[i].context_attn.linear_query.weight.data = weights.w[12][i].transpose(-1, -2).contiguous() + self.decoder.transformer_layers[i].context_attn.linear_keys.weight.data = weights.w[13][i].transpose(-1, -2).contiguous() + self.decoder.transformer_layers[i].context_attn.linear_values.weight.data = weights.w[14][i].transpose(-1, -2).contiguous() + self.decoder.transformer_layers[i].context_attn.linear_query.bias.data = weights.w[15][i] + self.decoder.transformer_layers[i].context_attn.linear_keys.bias.data = weights.w[16][i] + self.decoder.transformer_layers[i].context_attn.linear_values.bias.data = weights.w[17][i] + self.decoder.transformer_layers[i].context_attn.final_linear.weight.data = weights.w[18][i].transpose(-1, -2).contiguous() + self.decoder.transformer_layers[i].context_attn.final_linear.bias.data = weights.w[19][i] + self.decoder.transformer_layers[i].feed_forward.layer_norm.weight.data = weights.w[20][i] + self.decoder.transformer_layers[i].feed_forward.layer_norm.bias.data = weights.w[21][i] + self.decoder.transformer_layers[i].feed_forward.w_1.weight.data = weights.w[22][i].transpose(-1, -2).contiguous() + self.decoder.transformer_layers[i].feed_forward.w_1.bias.data = weights.w[23][i] + self.decoder.transformer_layers[i].feed_forward.w_2.weight.data = weights.w[24][i].transpose(-1, -2).contiguous() + self.decoder.transformer_layers[i].feed_forward.w_2.bias.data = weights.w[25][i] + elif args.model_type == 'torch_decoding_with_decoder_ext': + w = [] + for i in range(layer_num): + w.append([weights.w[j][i].clone().detach() for j in range(26)]) + for i in range(len(w[-1])): + w[-1][i] = w[-1][i].cuda() + if args.data_type == 'fp16': + for i in range(len(w[-1])): + w[-1][i] = w[-1][i].half() + decoder_layers = nn.ModuleList( + [FTDecoderLayer(head_num, head_size, w[i], args) for i in range(layer_num)]) + self.decoder.transformer_layers = decoder_layers + else: + raise ValueError('wrong model_type') + self.decoder.layer_norm.weight.data = weights.w[26] + 
self.decoder.layer_norm.bias.data = weights.w[27] + self.decoder.embeddings.make_embedding.emb_luts[0].weight.data = weights.w[28] + self.generator.weight.data = weights.w[30].transpose(-1, -2).contiguous() + self.generator.bias.data = weights.w[31] + + def forward(self, batch_size, beam_size, max_seq_len, memory, memory_seq_lens): + extended_memory = tile(memory, beam_size) + batchxbeam = extended_memory.size(0) + extended_memory = extended_memory.transpose(0, 1).contiguous() + + extended_memory_seq_lens = tile(memory_seq_lens, beam_size) + start_ids = extended_memory_seq_lens.new_full((batchxbeam,), self.start_id, dtype=torch.int64) + + initial_log_probs = extended_memory.new_full((beam_size,), -float("inf"), dtype=torch.float32) + initial_log_probs[0] = 0. + initial_log_probs = initial_log_probs.repeat(batch_size) + sequence_lengths = extended_memory_seq_lens.new_full((batchxbeam,), 0) + finished = extended_memory_seq_lens.new_full((batchxbeam,), 0, dtype=torch.bool) + + dtype_info = torch.finfo(extended_memory.dtype) + eos_max_prob = extended_memory.new_full((batchxbeam, self.vocab_size), dtype_info.min) + eos_max_prob[:, self.end_id] = dtype_info.max + + self.decoder.init_state(extended_memory, extended_memory, None) + word_ids = start_ids + cum_log_probs = initial_log_probs + + for step in range(max_seq_len): + if not torch.bitwise_not(finished).any(): + break + word_ids = word_ids.view(1, -1, 1) + dec_out, dec_attn = self.decoder(word_ids, extended_memory, memory_lengths=extended_memory_seq_lens, step=step) + logits = self.generator(dec_out.squeeze(0)) + logits = torch.where(finished.view(-1, 1), eos_max_prob, logits).to(torch.float32) + log_probs = self.logsoftmax(logits.to(torch.float32)) + + total_probs = log_probs + torch.unsqueeze(cum_log_probs, 1) + total_probs = total_probs.view(-1, beam_size * self.vocab_size) + + # beamsearch + # _, sample_ids = torch.topk(total_probs, beam_size) + # sample_ids = sample_ids.view(-1) + + #diversesiblingsearch + sibling_score = torch.arange(1, beam_size+1).to(total_probs.dtype).to(extended_memory.device) * self.diversity_rate # [beam_size] + scores, ids = torch.topk(total_probs.view(-1, beam_size, self.vocab_size), beam_size) # [batch size, beam width, beam width] + scores = scores + sibling_score # [batch size, beam width, beam width] + scores = scores.view(-1, beam_size * beam_size) + ids = ids + torch.unsqueeze(torch.unsqueeze(torch.arange(0, beam_size).to(extended_memory.device) * self.vocab_size, 0), -1) + ids = ids.view(-1, beam_size * beam_size) + _, final_ids = torch.topk(scores, beam_size) # [batch size, beam size] + final_ids = final_ids.view(-1, 1) + batch_index = torch.arange(0, batch_size).to(extended_memory.device).view(-1, 1).repeat(1, beam_size).view(-1, 1) + index = torch.cat([batch_index, final_ids], 1) + sample_ids = gather_nd(ids, index) + + word_ids = sample_ids % self.vocab_size # [batch_size * beam_size] + beam_ids = sample_ids // self.vocab_size # [batch_size * beam_size] + beam_indices = (torch.arange(batchxbeam).to(extended_memory.device) // beam_size) * beam_size + beam_ids + + sequence_lengths = torch.where(finished, sequence_lengths, sequence_lengths + 1) + + batch_pos = torch.arange(batchxbeam).to(extended_memory.device) // beam_size + next_cum_log_probs = gather_nd(total_probs, torch.stack([batch_pos, sample_ids], -1)) # [batch_size * beam_size] + finished = finished.index_select(0, beam_indices) + sequence_lengths = sequence_lengths.index_select(0, beam_indices) + + self.decoder.map_state(lambda state, dim: 
state.index_select(dim, beam_indices)) + if step == 0: + parent_ids = beam_ids.view(1, -1) + output_ids = word_ids.view(1, -1) + else: + parent_ids = torch.cat((parent_ids, beam_ids.view(1, -1))) + output_ids = torch.cat((output_ids, word_ids.view(1, -1))) + cum_log_probs = torch.where(finished, cum_log_probs, next_cum_log_probs) + finished = torch.bitwise_or(finished, torch.eq(word_ids, self.end_id)) + + beams, lengths = finalize(beam_size, output_ids, parent_ids, sequence_lengths, self.end_id, args=self.args) + return beams, lengths diff --git a/sample/pytorch/utils/encoder.py b/sample/pytorch/utils/encoder.py new file mode 100644 index 000000000..3ba06de64 --- /dev/null +++ b/sample/pytorch/utils/encoder.py @@ -0,0 +1,173 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +from typing import List + +import sys +import torch + +from transformers import BertConfig +from transformers.modeling_bert import BertEncoder + + +class EncoderWeights(object): + def __init__(self, layer_num, hidden_dim, weights=None): + self.layer_num = layer_num + self.w = [[] for _ in range(layer_num)] + if weights: + if isinstance(weights, dict): + for i in range(layer_num): + pre = 'bert.encoder.layer.' + str(i) + '.' 
+ self.w[i].append(weights[pre + 'attention.self.query.weight'].transpose(-1, -2).contiguous()) + self.w[i].append(weights[pre + 'attention.self.query.bias']) + self.w[i].append(weights[pre + 'attention.self.key.weight'].transpose(-1, -2).contiguous()) + self.w[i].append(weights[pre + 'attention.self.key.bias']) + self.w[i].append(weights[pre + 'attention.self.value.weight'].transpose(-1, -2).contiguous()) + self.w[i].append(weights[pre + 'attention.self.value.bias']) + self.w[i].append(weights[pre + 'attention.output.dense.weight'].transpose(-1, -2).contiguous()) + self.w[i].append(weights[pre + 'attention.output.dense.bias']) + self.w[i].append(weights[pre + 'attention.output.LayerNorm.weight']) + self.w[i].append(weights[pre + 'attention.output.LayerNorm.bias']) + self.w[i].append(weights[pre + 'intermediate.dense.weight'].transpose(-1, -2).contiguous()) + self.w[i].append(weights[pre + 'intermediate.dense.bias']) + self.w[i].append(weights[pre + 'output.dense.weight'].transpose(-1, -2).contiguous()) + self.w[i].append(weights[pre + 'output.dense.bias']) + self.w[i].append(weights[pre + 'output.LayerNorm.weight']) + self.w[i].append(weights[pre + 'output.LayerNorm.bias']) + else: + for i in range(layer_num): + self.w[i].append(weights.layer[i].attention.self.query.weight.data.transpose(-1, -2).contiguous()) + self.w[i].append(weights.layer[i].attention.self.query.bias.data) + self.w[i].append(weights.layer[i].attention.self.key.weight.data.transpose(-1, -2).contiguous()) + self.w[i].append(weights.layer[i].attention.self.key.bias.data) + self.w[i].append(weights.layer[i].attention.self.value.weight.data.transpose(-1, -2).contiguous()) + self.w[i].append(weights.layer[i].attention.self.value.bias.data) + self.w[i].append(weights.layer[i].attention.output.dense.weight.data.transpose(-1, -2).contiguous()) + self.w[i].append(weights.layer[i].attention.output.dense.bias.data) + self.w[i].append(weights.layer[i].attention.output.LayerNorm.weight.data) + self.w[i].append(weights.layer[i].attention.output.LayerNorm.bias.data) + self.w[i].append(weights.layer[i].intermediate.dense.weight.data.transpose(-1, -2).contiguous()) + self.w[i].append(weights.layer[i].intermediate.dense.bias.data) + self.w[i].append(weights.layer[i].output.dense.weight.data.transpose(-1, -2).contiguous()) + self.w[i].append(weights.layer[i].output.dense.bias.data) + self.w[i].append(weights.layer[i].output.LayerNorm.weight.data) + self.w[i].append(weights.layer[i].output.LayerNorm.bias.data) + else: + for layer_weights in self.w: + layer_weights.append(torch.zeros(hidden_dim, hidden_dim)) # q_kernel + layer_weights.append(torch.zeros(hidden_dim)) # q_bias + layer_weights.append(torch.zeros(hidden_dim, hidden_dim)) # k_kernel + layer_weights.append(torch.zeros(hidden_dim)) # k_bias + layer_weights.append(torch.zeros(hidden_dim, hidden_dim)) # v_kernel + layer_weights.append(torch.zeros(hidden_dim)) # v_bias + layer_weights.append(torch.zeros(hidden_dim, hidden_dim)) # attr_output_kernel + layer_weights.append(torch.zeros(hidden_dim)) # attr_output_bias + layer_weights.append(torch.zeros(hidden_dim)) # attr_output_layernorm_beta + layer_weights.append(torch.zeros(hidden_dim)) # attr_output_layernorm_gamma + layer_weights.append(torch.zeros(hidden_dim, 4 * hidden_dim)) # inter_kernel + layer_weights.append(torch.zeros(4 * hidden_dim)) # inter_bias + layer_weights.append(torch.zeros(4 * hidden_dim, hidden_dim)) # output_kernel + layer_weights.append(torch.zeros(hidden_dim)) # output_bias + 
layer_weights.append(torch.zeros(hidden_dim)) # output_layernorm_beta + layer_weights.append(torch.zeros(hidden_dim)) # output_layernorm_gamma + for i in range(len(layer_weights)): + torch.nn.init.uniform_(layer_weights[i], -1, 1) + + def to_cuda(self): + for i in range(self.layer_num): + for j in range(len(self.w[i])): + self.w[i][j] = self.w[i][j].cuda() + + def to_half(self): + for i in range(self.layer_num): + for j in range(len(self.w[i])): + self.w[i][j] = self.w[i][j].half() + + +class CustomEncoder(torch.nn.Module): + def __init__(self, layer_num, head_num, head_size, weights, path='./', use_ths=False, remove_padding=False): + super().__init__() + self.layer_num = layer_num + self.encoders = [] + if use_ths: + torch.classes.load_library(path) + for i in range(layer_num): + self.encoders.append(torch.classes.FasterTransformerEncoder(head_num, head_size, remove_padding, *weights.w[i])) + else: + sys.path.insert(0, path) + from th_fastertransformer import FasterTransformerEncoder + for i in range(layer_num): + self.encoders.append(FasterTransformerEncoder(head_num, head_size, remove_padding, *weights.w[i])) + + def forward(self, hidden_states, attention_mask, sequence_lengths=torch.Tensor(0).to(torch.int).cuda()): + for i in range(self.layer_num): + hidden_states = self.encoders[i].forward(hidden_states, attention_mask, sequence_lengths) + return (hidden_states,) + + +class CustomEncoder2(torch.nn.Module): + w: List[List[torch.Tensor]] + def __init__(self, layer_num, head_num, head_size, weights, path='./', remove_padding=False): + super().__init__() + self.layer_num = layer_num + self.head_num = head_num + self.head_size = head_size + self.remove_padding = remove_padding + self.w = weights.w + torch.ops.load_library(path) + + def forward(self, hidden_states, attention_mask, sequence_lengths=torch.Tensor(0).to(torch.int).cuda()): + for i in range(self.layer_num): + hidden_states = torch.ops.fastertransformer.encoder(self.head_num, self.head_size, self.remove_padding, + *self.w[i], hidden_states, attention_mask, sequence_lengths) + return (hidden_states,) + + +class HuggingFaceEncoder(torch.nn.Module): + def __init__(self, layer_num, head_num, head_size, weights=None): + super().__init__() + hidden_dim = head_num * head_size + conf = BertConfig(hidden_size=hidden_dim, intermediate_size=4*hidden_dim, num_attention_heads=head_num, num_hidden_layers=layer_num) + self.encoder = BertEncoder(conf) + if isinstance(weights, dict): + w = {} + for k, v in weights.items(): + if k.startswith('bert.encoder'): + w[k[13:]] = weights[k] + self.encoder.load_state_dict(w) + else: + for i in range(layer_num): + self.encoder.layer[i].attention.self.query.weight.data = weights.w[i][0].transpose(-1, -2).contiguous() + self.encoder.layer[i].attention.self.query.bias.data = weights.w[i][1] + self.encoder.layer[i].attention.self.key.weight.data = weights.w[i][2].transpose(-1, -2).contiguous() + self.encoder.layer[i].attention.self.key.bias.data = weights.w[i][3] + self.encoder.layer[i].attention.self.value.weight.data = weights.w[i][4].transpose(-1, -2).contiguous() + self.encoder.layer[i].attention.self.value.bias.data = weights.w[i][5] + self.encoder.layer[i].attention.output.dense.weight.data = weights.w[i][6].transpose(-1, -2).contiguous() + self.encoder.layer[i].attention.output.dense.bias.data = weights.w[i][7] + self.encoder.layer[i].attention.output.LayerNorm.weight.data = weights.w[i][8] + self.encoder.layer[i].attention.output.LayerNorm.bias.data = weights.w[i][9] + 
self.encoder.layer[i].intermediate.dense.weight.data = weights.w[i][10].transpose(-1, -2).contiguous() + self.encoder.layer[i].intermediate.dense.bias.data = weights.w[i][11] + self.encoder.layer[i].output.dense.weight.data = weights.w[i][12].transpose(-1, -2).contiguous() + self.encoder.layer[i].output.dense.bias.data = weights.w[i][13] + self.encoder.layer[i].output.LayerNorm.weight.data = weights.w[i][14] + self.encoder.layer[i].output.LayerNorm.bias.data = weights.w[i][15] + self.head_mask = [None] * layer_num + + def forward(self, hidden_states, attention_mask): + extended_attention_mask = (1.0 - attention_mask) * -10000.0 + output = self.encoder(hidden_states, extended_attention_mask, self.head_mask) + return output diff --git a/sample/pytorch/utils/get_mrpc_data.py b/sample/pytorch/utils/get_mrpc_data.py new file mode 100644 index 000000000..79efcbe98 --- /dev/null +++ b/sample/pytorch/utils/get_mrpc_data.py @@ -0,0 +1,71 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import shutil +import argparse +import tempfile +import urllib.request +import zipfile + +MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt' +MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt' + +def format_mrpc(mrpc_dir): + print("Processing MRPC...") + mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt") + mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt") + urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file) + urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file) + assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file + assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file + urllib.request.urlretrieve('https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc', os.path.join(mrpc_dir, "dev_ids.tsv")) + + dev_ids = [] + with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh: + for row in ids_fh: + dev_ids.append(row.strip().split('\t')) + + with open(mrpc_train_file, encoding="utf8") as data_fh, \ + open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \ + open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh: + header = data_fh.readline() + train_fh.write(header) + dev_fh.write(header) + for row in data_fh: + label, id1, id2, s1, s2 = row.strip().split('\t') + if [id1, id2] in dev_ids: + dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) + else: + train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) + + with open(mrpc_test_file, encoding="utf8") as data_fh, \ + open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh: + header = data_fh.readline() + test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") + for idx, row in 
enumerate(data_fh): + label, id1, id2, s1, s2 = row.strip().split('\t') + test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) + print("\tCompleted!") + +def main(arguments): + parser = argparse.ArgumentParser() + parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data') + args = parser.parse_args(arguments) + + format_mrpc(args.data_dir) + +if __name__ == '__main__': + sys.exit(main(sys.argv[1:])) \ No newline at end of file diff --git a/sample/pytorch/utils/modeling_bert.py b/sample/pytorch/utils/modeling_bert.py new file mode 100644 index 000000000..ee2669b19 --- /dev/null +++ b/sample/pytorch/utils/modeling_bert.py @@ -0,0 +1,218 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model modified from HuggingFace transformers. """ + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from transformers.configuration_bert import BertConfig +from transformers.modeling_utils import PreTrainedModel, prune_linear_layer +from transformers.modeling_bert import BertPreTrainedModel, BertEmbeddings, BertEncoder, BertPooler + + +class BertModel(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + + self.init_weights() + self.use_ext_encoder = False + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + if self.use_ext_encoder: + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask + elif attention_mask.dim() == 2: + extended_attention_mask = attention_mask[:, None, :].repeat(1, input_shape[1], 1) + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, 
to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + if self.use_ext_encoder: + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask + ) + else: + head_mask = [None] * self.config.num_hidden_layers + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + outputs = (sequence_output, pooled_output,) + encoder_outputs[ + 1: + ] # add hidden_states and attentions if they are here + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) + + +class BertForQuestionAnswering(BertPreTrainedModel): + def __init__(self, config): + super(BertForQuestionAnswering, self).__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + ): + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + outputs = (start_logits, end_logits,) + return outputs # start_logits, end_logits + + def replace_encoder(self, encoder): + self.bert.use_ext_encoder = True + self.bert.encoder = encoder + +class BertForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + labels=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + ): + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + ) + + pooled_output = outputs[1] + 
+ pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here + + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss,) + outputs + + return outputs # (loss), logits, (hidden_states), (attentions) + + def replace_encoder(self, encoder): + self.bert.use_ext_encoder = True + self.bert.encoder = encoder diff --git a/sample/pytorch/utils/recover_bpe.py b/sample/pytorch/utils/recover_bpe.py new file mode 100644 index 000000000..f1a2b3a0f --- /dev/null +++ b/sample/pytorch/utils/recover_bpe.py @@ -0,0 +1,38 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('infile', type=str) +parser.add_argument('outfile', type=str) +args = parser.parse_args() + +with open(args.infile, 'r') as infile: + with open(args.outfile, 'w') as outfile: + for line in infile.readlines(): + line = line.strip().split() + if line[-1] == '': + line.pop() + if line[0][0] == '▁': + s = line[0][1:] + else: + s = line[0] + for w in line[1:]: + if w[0] == '▁': + s += ' ' + w[1:] + else: + s += w + s += '\n' + outfile.write(s) diff --git a/sample/pytorch/utils/translation_model.py b/sample/pytorch/utils/translation_model.py new file mode 100644 index 000000000..978e42888 --- /dev/null +++ b/sample/pytorch/utils/translation_model.py @@ -0,0 +1,287 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import torch +import torch.nn as nn +from torch.nn.init import xavier_uniform_ + +import onmt.inputters as inputters +import onmt.modules +from onmt.encoders.transformer import TransformerEncoder + +from onmt.modules import Embeddings, VecEmbedding, CopyGenerator +from onmt.modules.util_class import Cast +from onmt.utils.misc import use_gpu +from onmt.utils.parse import ArgumentParser + +from .decoding import FTDecoderLayer, DecodingWeights, CustomDecoding, TorchDecoding, TransformerDecoder + + +def build_embeddings(opt, text_field, for_encoder=True): + """ + Args: + opt: the option in current environment. + text_field(TextMultiField): word and feats field. 
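For reference, recover_bpe.py above reverses SentencePiece-style subword segmentation: a piece that starts with '▁' begins a new word (the marker becomes a space, except for the first piece), and any other piece is glued onto the previous word. A small in-memory sketch of the same joining rule (the helper name and the sample pieces are made up):

def recover_bpe_line(pieces):
    # '▁' marks a word boundary: start a new word, otherwise append to the previous one
    s = pieces[0][1:] if pieces[0].startswith('▁') else pieces[0]
    for w in pieces[1:]:
        s += ' ' + w[1:] if w.startswith('▁') else w
    return s

print(recover_bpe_line(['▁He', 'llo', '▁wor', 'ld']))   # -> 'Hello world'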
+ for_encoder(bool): build Embeddings for encoder or decoder? + """ + emb_dim = opt.src_word_vec_size if for_encoder else opt.tgt_word_vec_size + + if opt.model_type == "vec" and for_encoder: + return VecEmbedding( + opt.feat_vec_size, + emb_dim, + position_encoding=opt.position_encoding, + dropout=(opt.dropout[0] if type(opt.dropout) is list + else opt.dropout), + ) + + pad_indices = [f.vocab.stoi[f.pad_token] for _, f in text_field] + word_padding_idx, feat_pad_indices = pad_indices[0], pad_indices[1:] + + num_embs = [len(f.vocab) for _, f in text_field] + num_word_embeddings, num_feat_embeddings = num_embs[0], num_embs[1:] + + fix_word_vecs = opt.fix_word_vecs_enc if for_encoder \ + else opt.fix_word_vecs_dec + + emb = Embeddings( + word_vec_size=emb_dim, + position_encoding=opt.position_encoding, + feat_merge=opt.feat_merge, + feat_vec_exponent=opt.feat_vec_exponent, + feat_vec_size=opt.feat_vec_size, + dropout=opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + word_padding_idx=word_padding_idx, + feat_padding_idx=feat_pad_indices, + word_vocab_size=num_word_embeddings, + feat_vocab_sizes=num_feat_embeddings, + sparse=opt.optim == "sparseadam", + fix_word_vecs=fix_word_vecs + ) + return emb + + +def load_test_model(opt, args): + model_path = opt.models[0] + checkpoint = torch.load(model_path, + map_location=lambda storage, loc: storage) + + model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt']) + ArgumentParser.update_model_opts(model_opt) + ArgumentParser.validate_model_opts(model_opt) + vocab = checkpoint['vocab'] + if inputters.old_style_vocab(vocab): + fields = inputters.load_old_vocab( + vocab, opt.data_type, dynamic_dict=model_opt.copy_attn + ) + else: + fields = vocab + + model = build_base_model(model_opt, fields, use_gpu(opt), args, checkpoint, + opt.gpu) + if args.data_type == 'fp32': + model.float() + elif args.data_type == 'fp16': + model.half() + else: + raise ValueError('wrong data_type argument {}'.format(args.data_type)) + model.eval() + model.generator.eval() + return fields, model, model_opt + + +def build_base_model(model_opt, fields, gpu, args, checkpoint=None, gpu_id=None): + """Build a model from opts. + + Args: + model_opt: the option loaded from checkpoint. It's important that + the opts have been updated and validated. See + :class:`onmt.utils.parse.ArgumentParser`. + fields (dict[str, torchtext.data.Field]): + `Field` objects for the model. + gpu (bool): whether to use gpu. + checkpoint: the model gnerated by train phase, or a resumed snapshot + model from a stopped training. + gpu_id (int or NoneType): Which GPU to use. + + Returns: + the NMTModel. + """ + + # for back compat when attention_dropout was not defined + try: + model_opt.attention_dropout + except AttributeError: + model_opt.attention_dropout = model_opt.dropout + + # Build embeddings. + if model_opt.model_type == "text" or model_opt.model_type == "vec": + src_field = fields["src"] + src_emb = build_embeddings(model_opt, src_field) + else: + src_emb = None + + # Build encoder. + encoder = TransformerEncoder.from_opt(model_opt, src_emb) + + # Build decoder. + tgt_field = fields["tgt"] + tgt_emb = build_embeddings(model_opt, tgt_field, for_encoder=False) + + # Share the embedding matrix - preprocess with share_vocab required. + if model_opt.share_embeddings: + # src/tgt vocab should be the same if `-share_vocab` is specified. 
+ assert src_field.base_field.vocab == tgt_field.base_field.vocab, \ + "preprocess with -share_vocab if you use share_embeddings" + + tgt_emb.word_lut.weight = src_emb.word_lut.weight + + decoder = TransformerDecoder.from_opt(model_opt, tgt_emb, args) + + # Build NMTModel(= encoder + decoder). + if gpu and gpu_id is not None: + device = torch.device("cuda", gpu_id) + elif gpu and not gpu_id: + device = torch.device("cuda") + elif not gpu: + device = torch.device("cpu") + model = onmt.models.NMTModel(encoder, decoder) + + # Build Generator. + if not model_opt.copy_attn: + if model_opt.generator_function == "sparsemax": + gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1) + else: + gen_func = nn.LogSoftmax(dim=-1) + generator = nn.Sequential( + nn.Linear(model_opt.dec_rnn_size, + len(fields["tgt"].base_field.vocab)), + Cast(torch.float32), + gen_func + ) + if model_opt.share_decoder_embeddings: + generator[0].weight = decoder.embeddings.word_lut.weight + else: + tgt_base_field = fields["tgt"].base_field + vocab_size = len(tgt_base_field.vocab) + pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token] + generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx) + if model_opt.share_decoder_embeddings: + generator.linear.weight = decoder.embeddings.word_lut.weight + + # Load the model states from checkpoint or initialize them. + if checkpoint is not None: + # This preserves backward-compat for models using customed layernorm + def fix_key(s): + s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.b_2', + r'\1.layer_norm\2.bias', s) + s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.a_2', + r'\1.layer_norm\2.weight', s) + return s + + checkpoint['model'] = {fix_key(k): v + for k, v in checkpoint['model'].items()} + # end of patch for backward compatibility + + model.load_state_dict(checkpoint['model'], strict=False) + generator.load_state_dict(checkpoint['generator'], strict=False) + + if args.model_type == 'decoder_ext': + w = [] + for i in range(model_opt.dec_layers): + w.append([ + decoder.transformer_layers[i].layer_norm_1.weight.data, + decoder.transformer_layers[i].layer_norm_1.bias.data, + decoder.transformer_layers[i].self_attn.linear_query.weight.data.transpose(-1, -2).contiguous(), + decoder.transformer_layers[i].self_attn.linear_keys.weight.data.transpose(-1, -2).contiguous(), + decoder.transformer_layers[i].self_attn.linear_values.weight.data.transpose(-1, -2).contiguous(), + decoder.transformer_layers[i].self_attn.linear_query.bias.data, + decoder.transformer_layers[i].self_attn.linear_keys.bias.data, + decoder.transformer_layers[i].self_attn.linear_values.bias.data, + decoder.transformer_layers[i].self_attn.final_linear.weight.data.transpose(-1, -2).contiguous(), + decoder.transformer_layers[i].self_attn.final_linear.bias.data, + decoder.transformer_layers[i].layer_norm_2.weight.data, + decoder.transformer_layers[i].layer_norm_2.bias.data, + decoder.transformer_layers[i].context_attn.linear_query.weight.data.transpose(-1, -2).contiguous(), + decoder.transformer_layers[i].context_attn.linear_keys.weight.data.transpose(-1, -2).contiguous(), + decoder.transformer_layers[i].context_attn.linear_values.weight.data.transpose(-1, -2).contiguous(), + decoder.transformer_layers[i].context_attn.linear_query.bias.data, + decoder.transformer_layers[i].context_attn.linear_keys.bias.data, + decoder.transformer_layers[i].context_attn.linear_values.bias.data, + decoder.transformer_layers[i].context_attn.final_linear.weight.data.transpose(-1, -2).contiguous(), + 
decoder.transformer_layers[i].context_attn.final_linear.bias.data, + decoder.transformer_layers[i].feed_forward.layer_norm.weight.data, + decoder.transformer_layers[i].feed_forward.layer_norm.bias.data, + decoder.transformer_layers[i].feed_forward.w_1.weight.data.transpose(-1, -2).contiguous(), + decoder.transformer_layers[i].feed_forward.w_1.bias.data, + decoder.transformer_layers[i].feed_forward.w_2.weight.data.transpose(-1, -2).contiguous(), + decoder.transformer_layers[i].feed_forward.w_2.bias.data + ]) + for i in range(len(w[-1])): + w[-1][i] = w[-1][i].cuda() + if args.data_type == 'fp16': + for i in range(len(w[-1])): + w[-1][i] = w[-1][i].half() + decoder_layers = nn.ModuleList( + [FTDecoderLayer(model_opt.heads, model_opt.dec_rnn_size // model_opt.heads, w[i], args) for i in range(model_opt.dec_layers)]) + model.decoder.transformer_layers = decoder_layers + elif args.model_type == 'decoding_ext': + vocab_size = len(fields["tgt"].base_field.vocab) + bos_idx = fields["tgt"].base_field.vocab.stoi[fields["tgt"].base_field.init_token] + eos_idx = fields["tgt"].base_field.vocab.stoi[fields["tgt"].base_field.eos_token] + decoding_weights = DecodingWeights(model_opt.dec_layers, model_opt.dec_rnn_size, vocab_size, checkpoint) + decoding_weights.to_cuda() + if args.data_type == 'fp16': + decoding_weights.to_half() + model.decoder = CustomDecoding(model_opt.dec_layers, model_opt.heads, model_opt.dec_rnn_size // model_opt.heads, + vocab_size, bos_idx, eos_idx, decoding_weights, args=args) + elif args.model_type == 'torch_decoding' or args.model_type == 'torch_decoding_with_decoder_ext': + vocab_size = len(fields["tgt"].base_field.vocab) + bos_idx = fields["tgt"].base_field.vocab.stoi[fields["tgt"].base_field.init_token] + eos_idx = fields["tgt"].base_field.vocab.stoi[fields["tgt"].base_field.eos_token] + decoding_weights = DecodingWeights(model_opt.dec_layers, model_opt.dec_rnn_size, vocab_size, checkpoint) + decoding_weights.to_cuda() + if args.data_type == 'fp16': + decoding_weights.to_half() + model.decoder = TorchDecoding(model_opt.dec_layers, model_opt.heads, model_opt.dec_rnn_size // model_opt.heads, + vocab_size, bos_idx, eos_idx, decoding_weights, args=args) + + else: + if model_opt.param_init != 0.0: + for p in model.parameters(): + p.data.uniform_(-model_opt.param_init, model_opt.param_init) + for p in generator.parameters(): + p.data.uniform_(-model_opt.param_init, model_opt.param_init) + if model_opt.param_init_glorot: + for p in model.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + for p in generator.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + + if hasattr(model.encoder, 'embeddings'): + model.encoder.embeddings.load_pretrained_vectors( + model_opt.pre_word_vecs_enc) + if hasattr(model.decoder, 'embeddings'): + model.decoder.embeddings.load_pretrained_vectors( + model_opt.pre_word_vecs_dec) + + model.generator = generator + model.to(device) + if model_opt.model_dtype == 'fp16' and model_opt.optim == 'fusedadam': + model.half() + return model diff --git a/sample/pytorch/utils/translator.py b/sample/pytorch/utils/translator.py new file mode 100644 index 000000000..4e77a93ff --- /dev/null +++ b/sample/pytorch/utils/translator.py @@ -0,0 +1,675 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
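A note on the weight packing in the decoder_ext branch above: every decoder layer contributes a fixed-order list of 26 tensors (the two attention layer norms, the self- and cross-attention Q/K/V and output projections with their biases, and the feed-forward block), and each nn.Linear weight is passed through .transpose(-1, -2).contiguous() before being handed to FTDecoderLayer, presumably because PyTorch stores Linear weights as [out_features, in_features] while the custom op consumes the transposed [in_features, out_features] layout. A tiny sketch of that conversion on a hypothetical projection:

import torch.nn as nn

proj = nn.Linear(512, 2048)                                 # e.g. a feed-forward w_1-style projection
print(proj.weight.shape)                                    # torch.Size([2048, 512]) -- [out_features, in_features]

w_for_ft = proj.weight.data.transpose(-1, -2).contiguous()  # -> torch.Size([512, 2048]), the layout passed to the op
# as in the code above, the tensors are then moved to the GPU (and cast to fp16 when data_type == 'fp16')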
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import codecs +import os +import time +import numpy as np +from itertools import count, zip_longest + +import torch + +import onmt.model_builder +import onmt.inputters as inputters +import onmt.decoders.ensemble +from onmt.translate.beam_search import BeamSearch +from onmt.translate.greedy_search import GreedySearch +from onmt.utils.misc import tile, set_random_seed, report_matrix +from onmt.utils.alignment import extract_alignment, build_align_pharaoh +from onmt.modules.copy_generator import collapse_copy_scores + + +def max_tok_len(new, count, sofar): + """ + In token batching scheme, the number of sequences is limited + such that the total number of src/tgt tokens (including padding) + in a batch <= batch_size + """ + # Maintains the longest src and tgt length in the current batch + global max_src_in_batch # this is a hack + # Reset current longest length at a new batch (count=1) + if count == 1: + max_src_in_batch = 0 + # max_tgt_in_batch = 0 + # Src: [ w1 ... wN ] + max_src_in_batch = max(max_src_in_batch, len(new.src[0]) + 2) + # Tgt: [w1 ... wM ] + src_elements = count * max_src_in_batch + return src_elements + + +class Translator(object): + """Translate a batch of sentences with a saved model. + + Args: + model (onmt.modules.NMTModel): NMT model to use for translation + fields (dict[str, torchtext.data.Field]): A dict + mapping each side to its list of name-Field pairs. + src_reader (onmt.inputters.DataReaderBase): Source reader. + tgt_reader (onmt.inputters.TextDataReader): Target reader. + gpu (int): GPU device. Set to negative for no GPU. + n_best (int): How many beams to wait for. + min_length (int): See + :class:`onmt.translate.decode_strategy.DecodeStrategy`. + max_length (int): See + :class:`onmt.translate.decode_strategy.DecodeStrategy`. + beam_size (int): Number of beams. + random_sampling_topk (int): See + :class:`onmt.translate.greedy_search.GreedySearch`. + random_sampling_temp (int): See + :class:`onmt.translate.greedy_search.GreedySearch`. + stepwise_penalty (bool): Whether coverage penalty is applied every step + or not. + dump_beam (bool): Debugging option. + block_ngram_repeat (int): See + :class:`onmt.translate.decode_strategy.DecodeStrategy`. + ignore_when_blocking (set or frozenset): See + :class:`onmt.translate.decode_strategy.DecodeStrategy`. + replace_unk (bool): Replace unknown token. + data_type (str): Source data type. + verbose (bool): Print/log every translation. + report_time (bool): Print/log total time/frequency. + copy_attn (bool): Use copy attention. + global_scorer (onmt.translate.GNMTGlobalScorer): Translation + scoring/reranking object. + out_file (TextIO or codecs.StreamReaderWriter): Output file. + report_score (bool) : Whether to report scores + logger (logging.Logger or NoneType): Logger. 
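As a worked example of the token-batching rule in max_tok_len above: when batch_type is "tokens", batch_size acts as a token budget, and the cost of a batch is counted as number_of_sentences * (longest_source_length_so_far + 2), padding included. A small sketch of that accounting (the helper name and the budget of 4096 are illustrative):

def fits_in_token_batch(src_lengths, token_budget):
    # mirrors max_tok_len: every sentence is charged for the longest source (+2) seen so far
    longest = max(src_lengths) + 2
    return len(src_lengths) * longest <= token_budget

print(fits_in_token_batch([60] * 66, 4096))   # True  (66 * 62 = 4092)
print(fits_in_token_batch([60] * 67, 4096))   # False (67 * 62 = 4154)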
+ """ + + def __init__( + self, + model, + fields, + src_reader, + tgt_reader, + model_type='ori', + gpu=-1, + n_best=1, + min_length=0, + max_length=100, + ratio=0., + beam_size=30, + random_sampling_topk=1, + random_sampling_temp=1, + stepwise_penalty=None, + dump_beam=False, + block_ngram_repeat=0, + ignore_when_blocking=frozenset(), + replace_unk=False, + phrase_table="", + data_type="text", + verbose=False, + report_time=False, + copy_attn=False, + global_scorer=None, + out_file=None, + report_align=False, + report_score=True, + logger=None, + seed=-1): + self.model = model + self.fields = fields + tgt_field = dict(self.fields)["tgt"].base_field + self._tgt_vocab = tgt_field.vocab + self._tgt_eos_idx = self._tgt_vocab.stoi[tgt_field.eos_token] + self._tgt_pad_idx = self._tgt_vocab.stoi[tgt_field.pad_token] + self._tgt_bos_idx = self._tgt_vocab.stoi[tgt_field.init_token] + self._tgt_unk_idx = self._tgt_vocab.stoi[tgt_field.unk_token] + self._tgt_vocab_len = len(self._tgt_vocab) + + self.model_type = model_type + self._gpu = gpu + self._use_cuda = gpu > -1 + self._dev = torch.device("cuda", self._gpu) \ + if self._use_cuda else torch.device("cpu") + + self.n_best = n_best + self.max_length = max_length + + self.beam_size = beam_size + self.random_sampling_temp = random_sampling_temp + self.sample_from_topk = random_sampling_topk + + self.min_length = min_length + self.ratio = ratio + self.stepwise_penalty = stepwise_penalty + self.dump_beam = dump_beam + self.block_ngram_repeat = block_ngram_repeat + self.ignore_when_blocking = ignore_when_blocking + self._exclusion_idxs = { + self._tgt_vocab.stoi[t] for t in self.ignore_when_blocking} + self.src_reader = src_reader + self.tgt_reader = tgt_reader + self.replace_unk = replace_unk + if self.replace_unk and not self.model.decoder.attentional: + raise ValueError( + "replace_unk requires an attentional decoder.") + self.phrase_table = phrase_table + self.data_type = data_type + self.verbose = verbose + self.report_time = report_time + + self.copy_attn = copy_attn + + self.global_scorer = global_scorer + if self.global_scorer.has_cov_pen and \ + not self.model.decoder.attentional: + raise ValueError( + "Coverage penalty requires an attentional decoder.") + self.out_file = out_file + self.report_align = report_align + self.report_score = report_score + self.logger = logger + + self.use_filter_pred = False + self._filter_pred = None + + set_random_seed(seed, self._use_cuda) + + @classmethod + def from_opt( + cls, + model, + fields, + opt, + model_opt, + args, + global_scorer=None, + out_file=None, + report_align=False, + report_score=True, + logger=None): + """Alternate constructor. + + Args: + model (onmt.modules.NMTModel): See :func:`__init__()`. + fields (dict[str, torchtext.data.Field]): See + :func:`__init__()`. + opt (argparse.Namespace): Command line options + model_opt (argparse.Namespace): Command line options saved with + the model checkpoint. + global_scorer (onmt.translate.GNMTGlobalScorer): See + :func:`__init__()`.. + out_file (TextIO or codecs.StreamReaderWriter): See + :func:`__init__()`. + report_align (bool) : See :func:`__init__()`. + report_score (bool) : See :func:`__init__()`. + logger (logging.Logger or NoneType): See :func:`__init__()`. 
+ """ + + src_reader = inputters.str2reader[opt.data_type].from_opt(opt) + tgt_reader = inputters.str2reader["text"].from_opt(opt) + return cls( + model, + fields, + src_reader, + tgt_reader, + model_type=args.model_type, + gpu=opt.gpu, + n_best=opt.n_best, + min_length=opt.min_length, + max_length=opt.max_length, + ratio=opt.ratio, + beam_size=opt.beam_size, + random_sampling_topk=opt.random_sampling_topk, + random_sampling_temp=opt.random_sampling_temp, + stepwise_penalty=opt.stepwise_penalty, + dump_beam=opt.dump_beam, + block_ngram_repeat=opt.block_ngram_repeat, + ignore_when_blocking=set(opt.ignore_when_blocking), + replace_unk=opt.replace_unk, + phrase_table=opt.phrase_table, + data_type=opt.data_type, + verbose=opt.verbose, + report_time=opt.report_time, + copy_attn=model_opt.copy_attn, + global_scorer=global_scorer, + out_file=out_file, + report_align=report_align, + report_score=report_score, + logger=logger, + seed=opt.seed) + + def _log(self, msg): + if self.logger: + self.logger.info(msg) + else: + print(msg) + + def _gold_score(self, batch, memory_bank, src_lengths, src_vocabs, + use_src_map, enc_states, batch_size, src): + if "tgt" in batch.__dict__: + gs = self._score_target( + batch, memory_bank, src_lengths, src_vocabs, + batch.src_map if use_src_map else None) + self.model.decoder.init_state(src, memory_bank, enc_states) + else: + gs = [0] * batch_size + return gs + + def translate( + self, + src, + tgt=None, + src_dir=None, + batch_size=None, + batch_type="sents", + attn_debug=False, + align_debug=False, + phrase_table=""): + """Translate content of ``src`` and get gold scores from ``tgt``. + + Args: + src: See :func:`self.src_reader.read()`. + tgt: See :func:`self.tgt_reader.read()`. + src_dir: See :func:`self.src_reader.read()` (only relevant + for certain types of data). 
+ batch_size (int): size of examples per mini-batch + attn_debug (bool): enables the attention logging + align_debug (bool): enables the word alignment logging + + Returns: + (`list`, `list`) + + * all_scores is a list of `batch_size` lists of `n_best` scores + * all_predictions is a list of `batch_size` lists + of `n_best` predictions + """ + + if batch_size is None: + raise ValueError("batch_size must be set") + + src_data = {"reader": self.src_reader, "data": src, "dir": src_dir} + tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None} + _readers, _data, _dir = inputters.Dataset.config( + [('src', src_data), ('tgt', tgt_data)]) + + # corpus_id field is useless here + if self.fields.get("corpus_id", None) is not None: + self.fields.pop('corpus_id') + data = inputters.Dataset( + self.fields, readers=_readers, data=_data, dirs=_dir, + sort_key=inputters.str2sortkey[self.data_type], + filter_pred=self._filter_pred + ) + + data_iter = inputters.OrderedIterator( + dataset=data, + device=self._dev, + batch_size=batch_size, + batch_size_fn=max_tok_len if batch_type == "tokens" else None, + train=False, + sort=False, + sort_within_batch=True, + shuffle=False + ) + + xlation_builder = onmt.translate.TranslationBuilder( + data, self.fields, self.n_best, self.replace_unk, tgt, + self.phrase_table + ) + + # Statistics + counter = count(1) + pred_score_total, pred_words_total = 0, 0 + gold_score_total, gold_words_total = 0, 0 + + all_scores = [] + all_predictions = [] + + start_time = time.time() + + for batch in data_iter: + if self.model_type == 'decoding_ext' or self.model_type == 'torch_decoding' or self.model_type == 'torch_decoding_with_decoder_ext': + batch_data = self.translate_batch_ftdecoding(batch, data.src_vocabs) + else: + batch_data = self.translate_batch( + batch, data.src_vocabs, attn_debug + ) + translations = xlation_builder.from_batch(batch_data) + + for trans in translations: + all_scores += [trans.pred_scores[:self.n_best]] + pred_score_total += trans.pred_scores[0] + pred_words_total += len(trans.pred_sents[0]) + if tgt is not None: + gold_score_total += trans.gold_score + gold_words_total += len(trans.gold_sent) + 1 + + n_best_preds = [" ".join(pred) + for pred in trans.pred_sents[:self.n_best]] + if self.report_align: + align_pharaohs = [build_align_pharaoh(align) for align + in trans.word_aligns[:self.n_best]] + n_best_preds_align = [" ".join(align) for align + in align_pharaohs] + n_best_preds = [pred + " ||| " + align + for pred, align in zip( + n_best_preds, n_best_preds_align)] + all_predictions += [n_best_preds] + self.out_file.write('\n'.join(n_best_preds) + '\n') + self.out_file.flush() + + if self.verbose: + sent_number = next(counter) + output = trans.log(sent_number) + if self.logger: + self.logger.info(output) + else: + os.write(1, output.encode('utf-8')) + + if attn_debug: + preds = trans.pred_sents[0] + preds.append('') + attns = trans.attns[0].tolist() + if self.data_type == 'text': + srcs = trans.src_raw + else: + srcs = [str(item) for item in range(len(attns[0]))] + output = report_matrix(srcs, preds, attns) + if self.logger: + self.logger.info(output) + else: + os.write(1, output.encode('utf-8')) + + if align_debug: + if trans.gold_sent is not None: + tgts = trans.gold_sent + else: + tgts = trans.pred_sents[0] + align = trans.word_aligns[0].tolist() + if self.data_type == 'text': + srcs = trans.src_raw + else: + srcs = [str(item) for item in range(len(align[0]))] + output = report_matrix(srcs, tgts, align) + if self.logger: + 
self.logger.info(output) + else: + os.write(1, output.encode('utf-8')) + + end_time = time.time() + + if self.report_score: + msg = self._report_score('PRED', pred_score_total, + pred_words_total) + self._log(msg) + if tgt is not None: + msg = self._report_score('GOLD', gold_score_total, + gold_words_total) + self._log(msg) + + if self.report_time: + total_time = end_time - start_time + print("Total translation time (s): %f" % total_time) + print("Average translation time (s): %f" % ( + total_time / len(all_predictions))) + print("Tokens per second: %f" % ( + pred_words_total / total_time)) + + if self.dump_beam: + import json + json.dump(self.translator.beam_accum, + codecs.open(self.dump_beam, 'w', 'utf-8')) + return all_scores, all_predictions + + def translate_batch(self, batch, src_vocabs, attn_debug): + """Translate a batch of sentences.""" + with torch.no_grad(): + if self.beam_size == 1: + decode_strategy = GreedySearch( + pad=self._tgt_pad_idx, + bos=self._tgt_bos_idx, + eos=self._tgt_eos_idx, + batch_size=batch.batch_size, + min_length=self.min_length, max_length=self.max_length, + block_ngram_repeat=self.block_ngram_repeat, + exclusion_tokens=self._exclusion_idxs, + return_attention=attn_debug or self.replace_unk, + sampling_temp=self.random_sampling_temp, + keep_topk=self.sample_from_topk) + else: + # TODO: support these blacklisted features + assert not self.dump_beam + decode_strategy = BeamSearch( + self.beam_size, + batch_size=batch.batch_size, + pad=self._tgt_pad_idx, + bos=self._tgt_bos_idx, + eos=self._tgt_eos_idx, + n_best=self.n_best, + global_scorer=self.global_scorer, + min_length=self.min_length, max_length=self.max_length, + return_attention=attn_debug or self.replace_unk, + block_ngram_repeat=self.block_ngram_repeat, + exclusion_tokens=self._exclusion_idxs, + stepwise_penalty=self.stepwise_penalty, + ratio=self.ratio) + return self._translate_batch_with_strategy(batch, src_vocabs, + decode_strategy) + + def translate_batch_ftdecoding(self, batch, src_vocabs): + with torch.no_grad(): + use_src_map = self.copy_attn + batch_size = batch.batch_size + + src, enc_states, memory_bank, src_lengths = self._run_encoder(batch) + + results = { + "predictions": None, + "scores": None, + "attention": None, + "batch": batch, + "gold_score": self._gold_score( + batch, memory_bank, src_lengths, src_vocabs, use_src_map, + enc_states, batch_size, src)} + + src_lengths_ = src_lengths.to(torch.int32) + memory_bank_ = memory_bank.transpose(0, 1).contiguous() + output, lengths = self.model.decoder(batch_size, self.beam_size, self.max_length, memory_bank_, src_lengths_) + + results["scores"] = [(0,) for _ in range(batch_size)] + results["predictions"] = output + results["attention"] = [[None] * self.n_best for _ in range(batch_size)] + results["alignment"] = [[] for _ in range(batch_size)] + return results + + def _run_encoder(self, batch): + src, src_lengths = batch.src if isinstance(batch.src, tuple) \ + else (batch.src, None) + + enc_states, memory_bank, src_lengths = self.model.encoder( + src, src_lengths) + if src_lengths is None: + assert not isinstance(memory_bank, tuple), \ + 'Ensemble decoding only supported for text data' + src_lengths = torch.Tensor(batch.batch_size) \ + .type_as(memory_bank) \ + .long() \ + .fill_(memory_bank.size(0)) + return src, enc_states, memory_bank, src_lengths + + def _decode_and_generate( + self, + decoder_in, + memory_bank, + batch, + src_vocabs, + memory_lengths, + src_map=None, + step=None, + batch_offset=None): + if self.copy_attn: + # Turn any 
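A note on translate_batch_ftdecoding above: the ONMT encoder returns memory_bank in time-major layout [src_len, batch, hidden], while the FasterTransformer decoding module is called with a batch-first tensor and int32 lengths, hence the transpose(0, 1).contiguous() and the .to(torch.int32) cast before the call. A rough shape-only sketch under those assumptions (the sizes are made up, and the final call is left commented out because it needs the loaded model):

import torch

src_len, batch_size, hidden = 38, 8, 512
memory_bank = torch.randn(src_len, batch_size, hidden)     # ONMT layout: [src_len, batch, hidden]
src_lengths = torch.randint(1, src_len + 1, (batch_size,))

memory_bank_ = memory_bank.transpose(0, 1).contiguous()    # -> [batch, src_len, hidden]
src_lengths_ = src_lengths.to(torch.int32)                 # the custom decoding module takes int32 lengths
# output_ids, lengths = model.decoder(batch_size, beam_size, max_length,
#                                     memory_bank_, src_lengths_)   # call as in the code above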
copied words into UNKs. + decoder_in = decoder_in.masked_fill( + decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx + ) + + # Decoder forward, takes [tgt_len, batch, nfeats] as input + # and [src_len, batch, hidden] as memory_bank + # in case of inference tgt_len = 1, batch = beam times batch_size + # in case of Gold Scoring tgt_len = actual length, batch = 1 batch + dec_out, dec_attn = self.model.decoder( + decoder_in, memory_bank, memory_lengths=memory_lengths, step=step + ) + + # Generator forward. + if not self.copy_attn: + if "std" in dec_attn: + attn = dec_attn["std"] + else: + attn = None + log_probs = self.model.generator(dec_out.squeeze(0)) + # returns [(batch_size x beam_size) , vocab ] when 1 step + # or [ tgt_len, batch_size, vocab ] when full sentence + else: + attn = dec_attn["copy"] + scores = self.model.generator(dec_out.view(-1, dec_out.size(2)), + attn.view(-1, attn.size(2)), + src_map) + # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab] + if batch_offset is None: + scores = scores.view(-1, batch.batch_size, scores.size(-1)) + scores = scores.transpose(0, 1).contiguous() + else: + scores = scores.view(-1, self.beam_size, scores.size(-1)) + scores = collapse_copy_scores( + scores, + batch, + self._tgt_vocab, + src_vocabs, + batch_dim=0, + batch_offset=batch_offset + ) + scores = scores.view(decoder_in.size(0), -1, scores.size(-1)) + log_probs = scores.squeeze(0).log() + # returns [(batch_size x beam_size) , vocab ] when 1 step + # or [ tgt_len, batch_size, vocab ] when full sentence + return log_probs, attn + + def _translate_batch_with_strategy( + self, + batch, + src_vocabs, + decode_strategy): + """Translate a batch of sentences step by step using cache. + + Args: + batch: a batch of sentences, yield by data iterator. + src_vocabs (list): list of torchtext.data.Vocab if can_copy. + decode_strategy (DecodeStrategy): A decode strategy to use for + generate translation step by step. + + Returns: + results (dict): The translation results. + """ + # (0) Prep the components of the search. + use_src_map = self.copy_attn + parallel_paths = decode_strategy.parallel_paths # beam_size + batch_size = batch.batch_size + + # (1) Run the encoder on the src. + src, enc_states, memory_bank, src_lengths = self._run_encoder(batch) + self.model.decoder.init_state(src, memory_bank, enc_states) + + results = { + "predictions": None, + "scores": None, + "attention": None, + "batch": batch, + "gold_score": self._gold_score( + batch, memory_bank, src_lengths, src_vocabs, use_src_map, + enc_states, batch_size, src)} + + # (2) prep decode_strategy. Possibly repeat src objects. + src_map = batch.src_map if use_src_map else None + fn_map_state, memory_bank, memory_lengths, src_map = \ + decode_strategy.initialize(memory_bank, src_lengths, src_map) + if fn_map_state is not None: + self.model.decoder.map_state(fn_map_state) + + # (3) Begin decoding step by step: + for step in range(decode_strategy.max_length): + decoder_input = decode_strategy.current_predictions.view(1, -1, 1) + + log_probs, attn = self._decode_and_generate( + decoder_input, + memory_bank, + batch, + src_vocabs, + memory_lengths=memory_lengths, + src_map=src_map, + step=step, + batch_offset=decode_strategy.batch_offset) + + decode_strategy.advance(log_probs, attn) + any_finished = decode_strategy.is_finished.any() + if any_finished: + decode_strategy.update_finished() + if decode_strategy.done: + break + + select_indices = decode_strategy.select_indices + + if any_finished: + # Reorder states. 
+ if isinstance(memory_bank, tuple): + memory_bank = tuple(x.index_select(1, select_indices) + for x in memory_bank) + else: + memory_bank = memory_bank.index_select(1, select_indices) + + memory_lengths = memory_lengths.index_select(0, select_indices) + + if src_map is not None: + src_map = src_map.index_select(1, select_indices) + + if parallel_paths > 1 or any_finished: + self.model.decoder.map_state( + lambda state, dim: state.index_select(dim, select_indices)) + + results["scores"] = decode_strategy.scores + results["predictions"] = decode_strategy.predictions + results["attention"] = decode_strategy.attention + results["alignment"] = [[] for _ in range(batch_size)] + return results + + def _score_target(self, batch, memory_bank, src_lengths, + src_vocabs, src_map): + tgt = batch.tgt + tgt_in = tgt[:-1] + + log_probs, attn = self._decode_and_generate( + tgt_in, memory_bank, batch, src_vocabs, + memory_lengths=src_lengths, src_map=src_map) + + log_probs[:, :, self._tgt_pad_idx] = 0 + gold = tgt[1:] + gold_scores = log_probs.gather(2, gold) + gold_scores = gold_scores.sum(dim=0).view(-1) + + return gold_scores + + def _report_score(self, name, score_total, words_total): + if words_total == 0: + msg = "%s No words predicted" % (name,) + else: + avg_score = score_total / words_total + ppl = np.exp(-score_total.item() / words_total) + msg = ("%s AVG SCORE: %.4f, %s PPL: %.4f" % ( + name, avg_score, + name, ppl)) + return msg diff --git a/sample/tensorRT/CMakeLists.txt b/sample/tensorRT/CMakeLists.txt index 22a54b0d3..5d7fe119b 100644 --- a/sample/tensorRT/CMakeLists.txt +++ b/sample/tensorRT/CMakeLists.txt @@ -18,4 +18,4 @@ set(trt_files ) add_executable(transformer_trt ${trt_files}) -target_link_libraries(transformer_trt PRIVATE -lcublas -lcudart -lnvinfer fastertransformer) +target_link_libraries(transformer_trt PRIVATE -lcublas -lcudart -lnvinfer encoder) diff --git a/sample/tensorRT/transformer_trt.cc b/sample/tensorRT/transformer_trt.cc index 06e416d0c..3572d52d0 100644 --- a/sample/tensorRT/transformer_trt.cc +++ b/sample/tensorRT/transformer_trt.cc @@ -130,6 +130,7 @@ void run_bert_transformer(int batch_size, int seq_len, int layers, int head_num, TRT_Transformer* trt_transformer = new TRT_Transformer(batch_size, seq_len, head_num, hidden_dim, layers); trt_transformer->build_engine(params); + trt_transformer->do_inference(batch_size, h_from_tensor, h_attr_mask, h_transformer_out, stream); delete trt_transformer; diff --git a/sample/tensorflow/decoder_sample.py b/sample/tensorflow/decoder_sample.py index 5061a5cf8..9d09498c0 100644 --- a/sample/tensorflow/decoder_sample.py +++ b/sample/tensorflow/decoder_sample.py @@ -12,16 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. +''' +This is a sample code to demonstrate how to replace the decoder transformer +layer of TensorFlow by the decoder of FasterTransformer. + +This sample code builds a decoding model by TensorFlow, and user can replace +the decoder transformer layer of model by the decoder of FasterTransformer. +The other parts, including the embedding lookup, position encoder and beam +search, are still computed by TensorFlow. 
+ +Namely, the baseline model is like: + embedding-lookup -> position encoding -> TensorFlow decoder -> beam search + User can build this model by using "-decoder 0" to set the decoder type +and the new model is like: + embedding-lookup -> position encoding -> FasterTransformer decoder -> beam search + User can build this model by using "-decoder 1" to set the decoder type + +If user wants to verify the correctness of decoder, they can use "-decoder 2", +which will run the both TensorFlow decoder and FasterTransformer in one model, +and compare their difference. + +Users are also able to use this sample code to test the average forward time of +TensorFlow and FasterTransformer. +''' + from __future__ import print_function import numpy as np import tensorflow as tf import argparse -import os -from utils.common import time_test, DecodingArgument -from utils.decoding import tf_decoding, generate_encoder_result +from utils.common import time_test +from utils.common import TransformerArgument +from utils.common import DecodingBeamsearchArgument +from utils.decoding import tf_beamsearch_decoding +from utils.decoding import generate_encoder_result if __name__ == "__main__": - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) parser.add_argument('-batch', '--batch_size', type=int, default=1, metavar='NUMBER', help='batch size (default: 1)') parser.add_argument('-beam', '--beam_width', type=int, default=4, metavar='NUMBER', @@ -39,19 +65,23 @@ parser.add_argument('-v', '--vocab_size', type=int, default=30000, metavar='BOOL', help='vocabulary size. (default: 30000).') parser.add_argument('-d', '--data_type', type=str, default="fp32", metavar='STRING', - help='data type (default: fp32)') + help='data type (default: fp32)', choices=['fp32', 'fp16']) parser.add_argument('-time', '--test_time', type=int, default=0, metavar='BOOL', - help='test the time or not. (default: False (0)), True is 1.') + help='test the time or not. (default: False (0)), True is 1.', + choices=[0, 1]) parser.add_argument('-decoder', '--decoder_type', type=int, default=2, metavar='NUMBER', - help='Decoder type:' - + ' type 0: only run tf decoder;' - + ' type 1: only run op decoder;' - + ' type 2: run both tf and op decoder, and compare the difference.' - + ' default: type 2') + help=''' + Decoder type: + type 0: only run tf decoder; + type 1: only run op decoder; + type 2: run both tf and op decoder, and compare the difference. 
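For reference, a typical invocation of this sample (run from wherever the script and its utils package are importable; the sizes are only examples) is: python decoder_sample.py -batch 1 -beam 4 -d fp16 -decoder 2 -time 1. With -decoder 2 it builds both the TensorFlow decoder and the FasterTransformer decoder in one model and compares their difference, and with -time 1 it also prints the timing line for the combined TF+FT-OP run.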
+ default: type 2 ''', choices=[0, 1, 2]) args = parser.parse_args() print("\n=============== Argument ===============") - print(args) + for key in vars(args): + print("{}: {}".format(key, vars(args)[key])) + print("========================================") start_of_sentence_id = 1 end_of_sentence_id = 2 @@ -72,39 +102,36 @@ vocab_size = args.vocab_size tf_datatype = tf.float32 np_datatype = np.float32 - atol_threshold = 2e-5 if args.data_type == "fp16": tf_datatype = tf.float16 np_datatype = np.float16 - atol_threshold = 2e-2 - decoding_args = DecodingArgument(batch_size=batch_size, - beam_width=beam_width, - head_num=head_num, - size_per_head=size_per_head, - num_layer=num_layer, - max_seq_len=max_seq_len, - vocab_size=vocab_size, - start_id=start_of_sentence_id, - end_id=end_of_sentence_id, - encoder_hidden_dim=memory_hidden_dim, - dtype=tf_datatype) - - embedding_table = np.random.randn(vocab_size, hidden_dim).astype( - np_datatype) * 0.01 # a [vocab_size, hidden_dim] table + decoder_args = TransformerArgument(beam_width=beam_width, + head_num=head_num, + size_per_head=size_per_head, + num_layer=num_layer, + dtype=tf_datatype, + kernel_init_range=kernel_initializer_range, + bias_init_range=bias_initializer_range, + fuse_qkv=False) + + decoding_args = DecodingBeamsearchArgument(vocab_size, + start_of_sentence_id, + end_of_sentence_id, + max_seq_len, + decoder_args, + 0.0) + + embedding_table = np.random.randn(vocab_size, hidden_dim).astype(np_datatype) * 0.01 # a [vocab_size, hidden_dim] table embedding_table = tf.convert_to_tensor(embedding_table) memory, memory_sequence_length = generate_encoder_result( batch_size, max_seq_len, memory_hidden_dim, tf_datatype) - - finalized_tf_output_ids, finalized_tf_sequence_lengths, tf_output_ids, \ - tf_parent_ids, tf_sequence_lengths = tf_decoding(memory, - memory_sequence_length, - embedding_table, - decoding_args, - args.decoder_type, - kernel_initializer_range, - bias_initializer_range, - atol_threshold) + + finalized_tf_output_ids, finalized_tf_sequence_lengths, _, _, _ = tf_beamsearch_decoding(memory, + memory_sequence_length, + embedding_table, + decoding_args, + decoder_type=args.decoder_type) config = tf.ConfigProto() config.gpu_options.allow_growth = True @@ -114,7 +141,10 @@ sess.run(finalized_tf_output_ids) if args.test_time == 1: - time_cost = time_test(sess, finalized_tf_output_ids, iterations=50) - types = ["TF", "OP", "TF+OP"] - print("[INFO] time costs of {} decoder: {} ms.".format( - types[args.decoder_type], time_cost)) + + time_cost = time_test(sess, finalized_tf_output_ids, iterations=10) + types = ["TF-decoding-beamsearch", "FT-OP-decoder", "TF+FT-OP"] + + print("[INFO] batch_size {} beam_width {} head_num {} size_per_head {} seq_len {} " \ + "decoder_layers {} vocab_size {} {}-time {:6.2f} ms.".format(batch_size, beam_width, head_num, size_per_head, + max_seq_len, num_layer, vocab_size, types[args.decoder_type], time_cost)) diff --git a/sample/tensorflow/decoding_sample.py b/sample/tensorflow/decoding_sample.py index 90934a478..ca683b30f 100644 --- a/sample/tensorflow/decoding_sample.py +++ b/sample/tensorflow/decoding_sample.py @@ -12,16 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. +''' +This is a sample code to demonstrate how to use the TensorFlow custom op with +FasterTransformer library in decoding. + +This sample code builds a decoding model by TensorFlow and TensorFlow custom +op. Compare 1. 
the results of TensorFlow decoding with beam search and +the results FasterTransformer decoding with beam search; and 2. the results +of TensorFlow decoding with sampling and the results FasterTransformer decoding +with sampling. + +Users are also able to use this sample code to test the average forward time of +TensorFlow and FasterTransformer. +''' + +import copy import numpy as np import argparse import tensorflow as tf -import os -from utils.common import time_test, DecodingArgument, int_result_cross_check -from utils.decoding import tf_decoding, generate_encoder_result, op_decoding +from utils.common import time_test +from utils.common import DecodingBeamsearchArgument +from utils.common import DecodingSamplingArgument +from utils.common import TransformerArgument +from utils.common import int_result_cross_check +from utils.decoding import tf_beamsearch_decoding +from utils.decoding import op_beamsearch_decoding +from utils.decoding import tf_sampling_decoding +from utils.decoding import op_sampling_decoding +from utils.decoding import generate_encoder_result if __name__ == "__main__": - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) parser.add_argument('-batch', '--batch_size', type=int, default=1, metavar='NUMBER', help='batch size (default: 1)') parser.add_argument('-beam', '--beam_width', type=int, default=4, metavar='NUMBER', @@ -39,21 +61,34 @@ parser.add_argument('-v', '--vocab_size', type=int, default=30000, metavar='BOOL', help='vocabulary size. (default: 30000).') parser.add_argument('-d', '--data_type', type=str, default="fp32", metavar='STRING', - help='data type (default: fp32)') + help='data type (default: fp32)', choices=['fp32', 'fp16']) parser.add_argument('-x', '--use_XLA', type=int, default=0, metavar='BOOL', - help='use XLA (default: False 0)') - parser.add_argument('-time', '--test_time', type=int, default=0, metavar='BOOL', - help='test the time or not. (default: False (0)), True is 1.') + help='use XLA (default: False 0)', choices=[0, 1]) + parser.add_argument('-time', '--test_time', type=str, default='', metavar='STRING', + help=''' + Test the time of which one (default: '' (not test anyone) ); + '': not test anyone + '0': test tf_decoding_beamsearch + '1': test op_decoding_beamsearch + '2': test tf_decoding_sampling + '3': test op_decoding_sampling + 'e.g., if you want to test tf_decoding_beamsearch and op_decoding_sampling, + then you need to use -time '02' ''') parser.add_argument('-check', '--cross_check', type=int, default=1, metavar='BOOL', - help='cross check the answer of TF and OP. (default: True (1)), False is 0.') - parser.add_argument('-op_time', '--test_op_time', type=int, default=0, metavar='BOOL', - help='test the op time or not. (default: False (0)), True is 1.') - parser.add_argument('-tf_time', '--test_tf_time', type=int, default=0, metavar='BOOL', - help='test the tf time or not. (default: False (0)), True is 1.') - + help='cross check the answer of TF and OP. (default: True (1)), False is 0.', + choices=[0, 1]) + parser.add_argument('-diversity_rate', '--beam_search_diversity_rate', type=float, default=0.0, metavar='NUMBER', + help='deviersity rate of beam search. default is 0. When diversity rate = 0, it is equivalent to the naive beams earch.') + parser.add_argument('-topk', '--sampling_topk', type=int, default=1, metavar='NUMBER', + help='Candidate (k) value of top k sampling in decoding. 
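For reference, the -time option above now takes a string of digits rather than a boolean, so several measurements can be combined in one run: by the mapping given in the help text, '0' and '1' select the TensorFlow and FasterTransformer beam-search decodings, while '2' and '3' select the TensorFlow and FasterTransformer sampling decodings. For example, -time '13' times both FasterTransformer paths in a single invocation, and -check 1 additionally keeps the TF/OP cross-check enabled.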
Default is 1.') + parser.add_argument('-topp', '--sampling_topp', type=float, default=0.0, metavar='NUMBER', + help='Probability (p) value of top p sampling in decoding. Default is 0.0. ') + args = parser.parse_args() print("\n=============== Argument ===============") - print(args) + for key in vars(args): + print("{}: {}".format(key, vars(args)[key])) + print("========================================") start_of_sentence_id = 1 end_of_sentence_id = 2 @@ -75,44 +110,72 @@ tf_datatype = tf.float16 np_datatype = np.float16 use_XLA = args.use_XLA + beam_search_diversity_rate = args.beam_search_diversity_rate + sampling_topk = args.sampling_topk + sampling_topp = args.sampling_topp hidden_dim = head_num * size_per_head memory_hidden_dim = args.memory_hidden_dim - - decoding_args = DecodingArgument(batch_size=batch_size, - beam_width=beam_width, - head_num=head_num, - size_per_head=size_per_head, - num_layer=num_layer, - max_seq_len=max_seq_len, - vocab_size=vocab_size, - start_id=start_of_sentence_id, - end_id=end_of_sentence_id, - encoder_hidden_dim=memory_hidden_dim, - dtype=tf_datatype) + + decoder_args = TransformerArgument(beam_width=beam_width, + head_num=head_num, + size_per_head=size_per_head, + num_layer=num_layer, + dtype=tf_datatype, + kernel_init_range=kernel_initializer_range, + bias_init_range=bias_initializer_range) + + decoding_args = DecodingBeamsearchArgument(vocab_size, + start_of_sentence_id, + end_of_sentence_id, + max_seq_len, + decoder_args, + beam_search_diversity_rate) + + decoder_args_2 = copy.deepcopy(decoder_args) # for beam search + decoder_args_2.__dict__ = copy.deepcopy(decoder_args.__dict__) + decoder_args_2.beam_width = 1 # for sampling + + decoding_sampling_args = DecodingSamplingArgument(vocab_size, + start_of_sentence_id, + end_of_sentence_id, + max_seq_len, + decoder_args_2, + sampling_topk, + sampling_topp) embedding_table = np.random.rand(vocab_size, hidden_dim).astype( np_datatype) # a [vocab_size, hidden_dim] table embedding_table = tf.convert_to_tensor(embedding_table) memory, memory_sequence_length = generate_encoder_result( batch_size, max_seq_len, memory_hidden_dim, tf_datatype) - + finalized_tf_output_ids, finalized_tf_sequence_lengths, tf_output_ids, \ - tf_parent_ids, tf_sequence_lengths = tf_decoding(memory, - memory_sequence_length, - embedding_table, - decoding_args, - 0, - kernel_initializer_range, - bias_initializer_range) + tf_parent_ids, tf_sequence_lengths = tf_beamsearch_decoding(memory, + memory_sequence_length, + embedding_table, + decoding_args, + decoder_type=0) all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) finalized_op_output_ids, finalized_op_sequence_lengths, op_output_ids, \ - op_parent_ids, op_sequence_lengths = op_decoding(memory, + op_parent_ids, op_sequence_lengths = op_beamsearch_decoding(memory, memory_sequence_length, embedding_table, all_vars, decoding_args) + + tf_sampling_target_ids, tf_sampling_target_length = tf_sampling_decoding(memory, + memory_sequence_length, + embedding_table, + decoding_sampling_args, + decoder_type=0) + + op_sampling_target_ids, op_sampling_target_length = op_sampling_decoding(memory, + memory_sequence_length, + embedding_table, + all_vars, + decoding_sampling_args) config = tf.ConfigProto() config.gpu_options.allow_growth = True @@ -121,7 +184,7 @@ with tf.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.tables_initializer()) - + if args.cross_check == 1: finalized_tf_output_ids_result, tf_output_ids_result, tf_parent_ids_result, \ 
tf_sequence_lengths_result = sess.run( @@ -129,26 +192,55 @@ finalized_op_output_ids_result, op_output_ids_result, op_parent_ids_result, \ op_sequence_lengths_result = sess.run( [finalized_op_output_ids, op_output_ids, op_parent_ids, op_sequence_lengths]) - + + print("[INFO] BeamSearch cross check:") int_result_cross_check("Output ids", tf_output_ids_result, op_output_ids_result, - shape=[max_seq_len, batch_size * beam_width]) + shape=[batch_size, beam_width, max_seq_len]) int_result_cross_check("Parent ids", tf_parent_ids_result, op_parent_ids_result, - shape=[max_seq_len, batch_size * beam_width]) + shape=[batch_size, beam_width, max_seq_len]) int_result_cross_check("Sequence lengths", tf_sequence_lengths_result, - op_sequence_lengths_result, shape=[1, batch_size * beam_width]) + op_sequence_lengths_result, shape=[batch_size, beam_width, 1]) int_result_cross_check("Finalized output ids", finalized_tf_output_ids_result.T, finalized_op_output_ids_result.T, - shape=[max_seq_len, batch_size * beam_width]) - - if args.test_time == 1 or args.test_tf_time == 1 or args.test_op_time == 1: - if args.test_time == 1 or args.test_tf_time == 1: - tf_time_result = time_test( - sess, finalized_tf_output_ids, iterations=50, warmup=True) - if args.test_time == 1 or args.test_op_time == 1: - op_time_result = time_test( - sess, finalized_op_output_ids, iterations=50, warmup=True) + shape=[batch_size, beam_width, max_seq_len]) + + tf_sampling_ids, tf_sampling_length = sess.run([tf_sampling_target_ids, + tf_sampling_target_length]) + op_sampling_ids, op_sampling_length = sess.run([op_sampling_target_ids, + op_sampling_target_length]) + print("[INFO] Sampling cross check:") + int_result_cross_check("Output ids", tf_sampling_ids, op_sampling_ids, + shape=[batch_size, max_seq_len]) + int_result_cross_check("Sequence length", tf_sampling_length, op_sampling_length, + shape=[batch_size]) + - if args.test_time == 1 or args.test_tf_time == 1: - print("[INFO] TF execution time: {} ms".format(tf_time_result)) - if args.test_time == 1 or args.test_op_time == 1: - print("[INFO] OP execution time: {} ms".format(op_time_result)) + time_args = args.test_time + test_lists = [] + test_names = [] + if time_args.find("0") != -1: + test_lists.append(finalized_tf_output_ids) + test_names.append("TF-decoding-beamsearch") + if time_args.find("1") != -1: + test_lists.append(finalized_op_output_ids) + test_names.append("FT-OP-decoding-beamsearch") + if time_args.find("2") != -1: + test_lists.append(tf_sampling_target_ids) + test_names.append("TF-decoding-sampling") + if time_args.find("3") != -1: + test_lists.append(op_sampling_target_ids) + test_names.append("FT-OP-decoding-sampling") + + test_time_result = [] + for op in test_lists: + test_time_result.append(time_test(sess, op, iterations=10, warmup=True)) + + for name, t_result in zip(test_names, test_time_result): + if name.find("beamsearch") != -1: + print("[INFO] batch_size {} beam_width {} head_num {} size_per_head {} seq_len {} " \ + "decoder_layers {} vocab_size {} {}-time {:6.2f} ms.".format(batch_size, beam_width, head_num, size_per_head, + max_seq_len, num_layer, vocab_size, name, t_result)) + elif name.find("sampling") != -1: + print("[INFO] batch_size {} topk {} topp {} head_num {} size_per_head {} seq_len {} " \ + "decoder_layers {} vocab_size {} {}-time {:6.2f} ms.".format(batch_size, sampling_topk, sampling_topp, head_num, size_per_head, + max_seq_len, num_layer, vocab_size, name, t_result)) diff --git a/sample/tensorflow/encoder_decoder_sample.py 
b/sample/tensorflow/encoder_decoder_sample.py index e4e096cd8..986a9d260 100644 --- a/sample/tensorflow/encoder_decoder_sample.py +++ b/sample/tensorflow/encoder_decoder_sample.py @@ -16,9 +16,13 @@ import numpy as np import argparse import numpy as np -from utils.common import DecodingArgument, TransformerArgument -from utils.decoding import tf_decoding -from utils.encoder import tf_encoder, op_encoder +from utils.common import TransformerArgument +from utils.common import DecodingBeamsearchArgument +from utils.encoder import tf_encoder +from utils.encoder import op_encoder +from utils.encoder import build_sequence_mask +from utils.decoding import tf_beamsearch_decoding +from utils.decoding import generate_encoder_result if __name__ == "__main__": @@ -51,10 +55,16 @@ + ' type 1: only run op decoder;' + ' type 2: run both tf and op decoder, and compare the difference.' + ' default: type 2') + parser.add_argument("-remove_padding", "--remove_padding", type=str, default="False", metavar="BOOL", + choices=["True", "False"], + help="remove the padding of sentence or not. This brings speedups when the average of \ + sequence length is smaller than the maximum sequence length.") args = parser.parse_args() print("\n=============== Argument ===============") - print(args) + for key in vars(args): + print("{}: {}".format(key, vars(args)[key])) + print("========================================") start_of_sentence_id = 1 end_of_sentence_id = 2 @@ -77,6 +87,7 @@ encoder_hidden_dim = encoder_head_num * encoder_size_per_head decoder_hidden_dim = decoder_head_num * decoder_size_per_head vocab_size = args.vocab_size + remove_padding = True if args.remove_padding.lower() == "true" else False tf_datatype = tf.float32 np_datatype = np.float32 atol_threshold = 2e-5 @@ -85,37 +96,38 @@ np_datatype = np.float16 atol_threshold = 2e-2 - initializer_range = 0.02 - from_data = np.random.randn(batch_size, max_seq_len, encoder_hidden_dim) + from_data = np.random.randn(batch_size, max_seq_len, encoder_hidden_dim) * initializer_range from_tensor = tf.convert_to_tensor(from_data, dtype=tf_datatype) memory_sequence_length = np.random.randint( 1, max_seq_len + 1, size=batch_size).astype(np.int32) - embedding_table = np.random.randn(vocab_size, decoder_hidden_dim).astype( - np_datatype) # a [vocab_size, decoder_hidden_dim] table + memory_sequence_length[np.random.randint(0, batch_size)] = max_seq_len + embedding_table = np.random.randn(vocab_size, decoder_hidden_dim).astype(np_datatype) * initializer_range # a [vocab_size, decoder_hidden_dim] table embedding_table = tf.convert_to_tensor(embedding_table) - - mask = np.random.randint(2, size=(batch_size, max_seq_len, max_seq_len)) - attention_mask = tf.convert_to_tensor(mask, dtype=tf_datatype) - - encoder_args = TransformerArgument(batch_size=batch_size, - beam_width=1, - head_num=encoder_head_num, - size_per_head=encoder_size_per_head, - num_layer=encoder_num_layer, - max_seq_len=max_seq_len, - dtype=tf_datatype) - - decoding_args = DecodingArgument(batch_size=batch_size, - beam_width=beam_width, - head_num=decoder_head_num, - size_per_head=decoder_size_per_head, - num_layer=decoder_num_layer, - max_seq_len=max_seq_len, - vocab_size=vocab_size, - start_id=start_of_sentence_id, - end_id=end_of_sentence_id, - encoder_hidden_dim=encoder_head_num * encoder_size_per_head, - dtype=tf_datatype) + + attention_mask = build_sequence_mask(memory_sequence_length, num_heads=encoder_head_num, maximum_length=max_seq_len, dtype=tf_datatype) + + encoder_args = 
TransformerArgument(beam_width=1, + head_num=encoder_head_num, + size_per_head=encoder_size_per_head, + num_layer=encoder_num_layer, + dtype=tf_datatype, + remove_padding=remove_padding) + + decoder_args = TransformerArgument(beam_width=beam_width, + head_num=decoder_head_num, + size_per_head=decoder_size_per_head, + num_layer=decoder_num_layer, + dtype=tf_datatype, + kernel_init_range=kernel_initializer_range, + bias_init_range=bias_initializer_range, + fuse_qkv=False) + + decoding_args = DecodingBeamsearchArgument(vocab_size, + start_of_sentence_id, + end_of_sentence_id, + max_seq_len, + decoder_args, + 0.0) tf_encoder_result = tf_encoder(input_tensor=from_tensor, encoder_args=encoder_args, @@ -123,31 +135,31 @@ tf_encoder_result = tf.reshape( tf_encoder_result, [batch_size, max_seq_len, encoder_hidden_dim]) - tf_decoding_result, _, _, _, _ = tf_decoding(tf_encoder_result, - memory_sequence_length, - embedding_table, - decoding_args, - args.decoder_type, - kernel_initializer_range, - bias_initializer_range, - atol_threshold) - - encoder_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) + tf_encoder_result = tf_encoder_result * tf.expand_dims(tf.sequence_mask(memory_sequence_length, maxlen=max_seq_len, dtype=tf_datatype), axis=-1) + tf_decoding_result, _, _, _, _ = tf_beamsearch_decoding(tf_encoder_result, + memory_sequence_length, + embedding_table, + decoding_args, + decoder_type=args.decoder_type) + + encoder_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) + encoder_variables_dict = {} + for v in encoder_vars: + encoder_variables_dict[v.name] = v op_encoder_result = op_encoder(inputs=from_tensor, encoder_args=encoder_args, - encoder_vars=encoder_variables, - attention_mask=attention_mask) + attention_mask=attention_mask, + encoder_vars_dict=encoder_variables_dict, + sequence_length=memory_sequence_length) op_encoder_result = tf.reshape( op_encoder_result, [batch_size, max_seq_len, encoder_hidden_dim]) - - op_decoding_result, _, _, _, _ = tf_decoding(op_encoder_result, - memory_sequence_length, - embedding_table, - decoding_args, - args.decoder_type, - kernel_initializer_range, - bias_initializer_range, - atol_threshold) + op_encoder_result = op_encoder_result * tf.expand_dims(tf.sequence_mask(memory_sequence_length, maxlen=max_seq_len, dtype=tf_datatype), axis=-1) + + op_decoding_result, _, _, _, _ = tf_beamsearch_decoding(op_encoder_result, + memory_sequence_length, + embedding_table, + decoding_args, + decoder_type=args.decoder_type) config = tf.ConfigProto() config.gpu_options.allow_growth = True diff --git a/sample/tensorflow/encoder_decoding_sample.py b/sample/tensorflow/encoder_decoding_sample.py index 0d68c1b7e..d7a08d0a8 100644 --- a/sample/tensorflow/encoder_decoding_sample.py +++ b/sample/tensorflow/encoder_decoding_sample.py @@ -15,9 +15,16 @@ import tensorflow as tf import numpy as np import argparse -from utils.common import time_test, DecodingArgument, int_result_cross_check, TransformerArgument -from utils.decoding import tf_decoding, op_decoding -from utils.encoder import tf_encoder, op_encoder +from utils.common import int_result_cross_check +from utils.common import time_test +from utils.common import TransformerArgument +from utils.common import DecodingBeamsearchArgument +from utils.encoder import tf_encoder +from utils.encoder import op_encoder +from utils.encoder import build_sequence_mask +from utils.decoding import tf_beamsearch_decoding +from utils.decoding import op_beamsearch_decoding + if __name__ == "__main__": @@ -46,10 +53,16 @@ 
help='data type (default: fp32)') parser.add_argument('-time', '--test_time', type=int, default=0, metavar='BOOL', help='test the time or not. (default: False (0)), True is 1.') + parser.add_argument("-remove_padding", "--remove_padding", type=str, default="False", metavar="BOOL", + choices=["True", "False"], + help="remove the padding of sentence or not. This brings speedups when the average of \ + sequence length is smaller than the maximum sequence length.") args = parser.parse_args() print("\n=============== Argument ===============") - print(args) + for key in vars(args): + print("{}: {}".format(key, vars(args)[key])) + print("========================================") start_of_sentence_id = 1 end_of_sentence_id = 2 @@ -73,6 +86,7 @@ encoder_hidden_dim = encoder_head_num * encoder_size_per_head decoder_hidden_dim = decoder_head_num * decoder_size_per_head vocab_size = args.vocab_size + remove_padding = True if args.remove_padding.lower() == "true" else False tf_datatype = tf.float32 np_datatype = np.float32 atol_threshold = 2e-5 @@ -81,36 +95,37 @@ np_datatype = np.float16 atol_threshold = 2e-2 - initializer_range = 0.02 from_data = np.random.randn(batch_size, seq_len, encoder_hidden_dim) from_tensor = tf.convert_to_tensor(from_data, dtype=tf_datatype) memory_sequence_length = np.random.randint( 1, max_seq_len + 1, size=batch_size).astype(np.int32) - embedding_table = np.random.randn(vocab_size, decoder_hidden_dim).astype( - np_datatype) # a [vocab_size, decoder_hidden_dim] table - - mask = np.random.randint(2, size=(batch_size, seq_len, seq_len)) - attention_mask = tf.convert_to_tensor(mask, dtype=tf_datatype) - - encoder_args = TransformerArgument(batch_size=batch_size, - beam_width=1, - head_num=encoder_head_num, - size_per_head=encoder_size_per_head, - num_layer=encoder_num_layer, - max_seq_len=max_seq_len, - dtype=tf_datatype) - - decoding_args = DecodingArgument(batch_size=batch_size, - beam_width=beam_width, - head_num=decoder_head_num, - size_per_head=decoder_size_per_head, - num_layer=decoder_num_layer, - max_seq_len=max_seq_len, - vocab_size=vocab_size, - start_id=start_of_sentence_id, - end_id=end_of_sentence_id, - encoder_hidden_dim=encoder_head_num * encoder_size_per_head, - dtype=tf_datatype) + memory_sequence_length[np.random.randint(0, batch_size)] = max_seq_len + embedding_table = np.random.randn(vocab_size, decoder_hidden_dim).astype(np_datatype) * initializer_range # a [vocab_size, decoder_hidden_dim] table + + attention_mask = build_sequence_mask(memory_sequence_length, num_heads=encoder_head_num, maximum_length=max_seq_len, dtype=tf_datatype) + + encoder_args = TransformerArgument(beam_width=1, + head_num=encoder_head_num, + size_per_head=encoder_size_per_head, + num_layer=encoder_num_layer, + dtype=tf_datatype, + remove_padding=remove_padding) + + decoder_args = TransformerArgument(beam_width=beam_width, + head_num=decoder_head_num, + size_per_head=decoder_size_per_head, + num_layer=decoder_num_layer, + dtype=tf_datatype, + kernel_init_range=kernel_initializer_range, + bias_init_range=bias_initializer_range, + fuse_qkv=False) + + decoding_args = DecodingBeamsearchArgument(vocab_size, + start_of_sentence_id, + end_of_sentence_id, + max_seq_len, + decoder_args, + 0.0) tf_encoder_result = tf_encoder(input_tensor=from_tensor, encoder_args=encoder_args, @@ -118,34 +133,43 @@ tf_encoder_result = tf.reshape( tf_encoder_result, [batch_size, max_seq_len, encoder_hidden_dim]) + tf_encoder_result = tf_encoder_result * tf.expand_dims(tf.sequence_mask(memory_sequence_length, 
maxlen=max_seq_len, dtype=tf_datatype), axis=-1) + finalized_tf_output_ids, finalized_tf_sequence_lengths, tf_output_ids, \ - tf_parent_ids, tf_sequence_lengths = tf_decoding(tf_encoder_result, - memory_sequence_length, - embedding_table, - decoding_args, - 0, - kernel_initializer_range, - bias_initializer_range) + tf_parent_ids, tf_sequence_lengths = tf_beamsearch_decoding(tf_encoder_result, + memory_sequence_length, + embedding_table, + decoding_args, + decoder_type=0) + all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) decoder_var_start_id = 0 - while all_vars[decoder_var_start_id].name.find("transformer/decoding") == -1: + while all_vars[decoder_var_start_id].name.find("transformer/decoder") == -1: decoder_var_start_id += 1 + encoder_variables = all_vars[:decoder_var_start_id] decoder_variables = all_vars[decoder_var_start_id:] + + encoder_variables_dict = {} + for v in encoder_variables: + encoder_variables_dict[v.name] = v + op_encoder_result = op_encoder(inputs=from_tensor, encoder_args=encoder_args, - encoder_vars=encoder_variables, - attention_mask=attention_mask) + attention_mask=attention_mask, + encoder_vars_dict=encoder_variables_dict, + sequence_length=memory_sequence_length) op_encoder_result = tf.reshape( op_encoder_result, [batch_size, max_seq_len, encoder_hidden_dim]) - + op_encoder_result = op_encoder_result * tf.expand_dims(tf.sequence_mask(memory_sequence_length, maxlen=max_seq_len, dtype=tf_datatype), axis=-1) + finalized_op_output_ids, finalized_op_sequence_lengths, op_output_ids, \ - op_parent_ids, op_sequence_lengths = op_decoding(op_encoder_result, - memory_sequence_length, - embedding_table, - decoder_variables, - decoding_args) + op_parent_ids, op_sequence_lengths = op_beamsearch_decoding(op_encoder_result, + memory_sequence_length, + embedding_table, + decoder_variables, + decoding_args) config = tf.ConfigProto() config.gpu_options.allow_growth = True diff --git a/sample/tensorflow/encoder_sample.py b/sample/tensorflow/encoder_sample.py index 29041322e..45c7b37d9 100644 --- a/sample/tensorflow/encoder_sample.py +++ b/sample/tensorflow/encoder_sample.py @@ -12,11 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +''' +This is a sample code to demonstrate how to use the TensorFlow custom op with +FasterTransformer library in encoder. + +This sample code builds a BERT transformer model by TensorFlow and TensorFlow +custom op. Then compare the maximum difference of them to verify the correctness +of FasterTransformer. + +Users are also able to use this sample code to test the average forward time of +TensorFlow and FasterTransformer. 
+''' + import tensorflow as tf import numpy as np import argparse -from utils.common import TransformerArgument, time_test, cross_check -from utils.encoder import tf_encoder, op_encoder +from utils.common import TransformerArgument +from utils.common import time_test +from utils.common import cross_check +from utils.encoder import tf_encoder +from utils.encoder import op_encoder +from utils.encoder import build_sequence_mask if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -24,20 +40,29 @@ help='batch size (default: 1)') parser.add_argument('-l', '--num_layer', type=int, default=12, metavar='NUMBER', help='number of layers (default: 12)') - parser.add_argument('-s', '--seq_len', type=int, default=32, metavar='NUMBER', - help='sequence length (default: 32)') + parser.add_argument('-s', '--max_seq_len', type=int, default=32, metavar='NUMBER', + help='max sequence length (default: 32)') parser.add_argument('-n', '--head_number', type=int, default=12, metavar='NUMBER', help='head number (default: 12)') parser.add_argument('-size', '--size_per_head', type=int, default=64, metavar='NUMBER', help='size per head (default: 64)') parser.add_argument('-d', '--data_type', type=str, default="fp32", metavar='STRING', - help='data type (default: fp32)') + help='data type (default: fp32)', choices=['fp32', 'fp16']) parser.add_argument('-time', '--test_time', type=int, default=0, metavar='BOOL', - help='test the time or not. (default: False (0)), True is 1.') + help='test the time or not. (default: False (0)), True is 1.', + choices=[0, 1]) + parser.add_argument("-remove_padding", "--remove_padding", type=str, default="False", metavar="BOOL", + choices=["True", "False"], + help="remove the padding of sentence or not. This brings speedups when the average of \ + sequence length is smaller than the maximum sequence length.") + parser.add_argument('-avg_seq', '--avg_seq_len', type=int, default=-1, metavar='NUMBER', + help='average sequence length (default: -1)') args = parser.parse_args() print("\n=============== Argument ===============") - print(args) + for key in vars(args): + print("{}: {}".format(key, vars(args)[key])) + print("========================================") np.random.seed(1) tf.set_random_seed(1) @@ -46,64 +71,108 @@ batch_size = args.batch_size num_layer = args.num_layer - seq_len = args.seq_len + max_seq_len = args.max_seq_len + avg_seq_len = args.avg_seq_len head_num = args.head_number size_per_head = args.size_per_head + remove_padding = True if args.remove_padding.lower() == "true" else False tf_datatype = tf.float32 np_datatype = np.float32 - atol_threshold = 2e-5 + atol_threshold = 3e-5 if args.data_type == "fp16": tf_datatype = tf.float16 np_datatype = np.float16 - atol_threshold = 2e-2 + atol_threshold = 3e-2 hidden_dim = head_num * size_per_head initializer_range = 0.02 - from_data = np.random.randn(batch_size, seq_len, hidden_dim) - from_tensor = tf.convert_to_tensor(from_data, dtype=tf_datatype) - - mask = np.random.randint(2, size=(batch_size, seq_len, seq_len)) - attention_mask = tf.convert_to_tensor(mask, dtype=tf_datatype) - encoder_args = TransformerArgument(batch_size=batch_size, - beam_width=1, + sequence_length = np.random.randint(1, max_seq_len + 1, size=batch_size).astype(np.int32) + if avg_seq_len != -1 and remove_padding == True: + # This means we use "remove_padding" and set a smaller average sequence length + sequence_length = np.ones(batch_size) * avg_seq_len + + from_data = np.random.randn(batch_size, max_seq_len, hidden_dim) + from_tensor = 
tf.convert_to_tensor(from_data, dtype=tf_datatype) + + attention_mask = build_sequence_mask(sequence_length, num_heads=head_num, maximum_length=max_seq_len, dtype=tf_datatype) + + encoder_args = TransformerArgument(beam_width=1, head_num=head_num, size_per_head=size_per_head, num_layer=num_layer, - max_seq_len=seq_len, - dtype=tf_datatype) + dtype=tf_datatype, + remove_padding=remove_padding) tf_encoder_result = tf_encoder(input_tensor=from_tensor, encoder_args=encoder_args, attention_mask=attention_mask) - encoder_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) + encoder_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) + encoder_variables_dict = {} + for v in encoder_vars: + encoder_variables_dict[v.name] = v + op_encoder_result = op_encoder(inputs=from_tensor, encoder_args=encoder_args, - encoder_vars=encoder_variables, - attention_mask=attention_mask) + attention_mask=attention_mask, + encoder_vars_dict=encoder_variables_dict, + sequence_length=sequence_length) + + ''' + Because FasterTransformer skip some computation for the padding parts, + if we do not mask these parts, the cross check result would be wrong. + ''' + tf_encoder_result = tf_encoder_result * tf.expand_dims(tf.sequence_mask(sequence_length, maxlen=max_seq_len, dtype=tf_datatype), axis=-1) + op_encoder_result = op_encoder_result * tf.expand_dims(tf.sequence_mask(sequence_length, maxlen=max_seq_len, dtype=tf_datatype), axis=-1) config = tf.ConfigProto() + config.gpu_options.allow_growth = True config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 + with tf.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() - for idx, var in enumerate(encoder_variables): - print((str(idx) + " " + str(var.name) + " " + - str(var.shape)) + " " + str(var.dtype)) - + for idx, name in enumerate(encoder_variables_dict): + print((str(idx) + " " + str(name) + " " + + str(encoder_variables_dict[name].shape)) + " " + str(encoder_variables_dict[name].dtype)) + print("#################################") tf_encoder_result_val = sess.run(tf_encoder_result) op_encoder_result_val = sess.run(op_encoder_result) - cross_check("Encoder", tf_encoder_result_val, + + cross_check("Encoder TF v.s. FT with tensor input", tf_encoder_result_val, op_encoder_result_val, atol_threshold) + + + ''' + Use the numpy array as inputs of FasterTransformer OP. + + This method require more time for the op initialization (especially for FP16), + but the inference time would be little faster than using tensor as input. + ''' + encoder_variables_dict_2 = {} + for var, val in zip(encoder_vars, sess.run(encoder_vars)): + encoder_variables_dict_2[var.name] = val + + # op_encoder_result_2 = op_encoder(inputs=from_tensor, + # encoder_args=encoder_args, + # attention_mask=attention_mask, + # encoder_vars_dict=encoder_variables_dict_2, + # sequence_length=sequence_length) + # op_encoder_result_val_2 = sess.run(op_encoder_result_2) + # cross_check("Encoder TF v.s. 
FT with numpy input", tf_encoder_result_val, + # op_encoder_result_val_2, atol_threshold) if args.test_time == 1: - ite = 100 + + ite = 50 tf_time = time_test(sess, tf_encoder_result, ite) op_time = time_test(sess, op_encoder_result, ite) - - print("[INFO] TF encoder time costs: {} ms".format(tf_time)) - print("[INFO] OP encoder time costs: {} ms".format(op_time)) + # op_time_2 = time_test(sess, op_encoder_result_2, ite) + + print("[INFO] batch_size {} max_seq_len {} {} layer TF-time {:6.2f} ms".format(batch_size, max_seq_len, num_layer, tf_time)) + print("[INFO] batch_size {} max_seq_len {} {} layer FT-OP-tensor-time {:6.2f} ms".format(batch_size, max_seq_len, num_layer, op_time)) + # print("[INFO] batch_size {} max_seq_len {} {} layer FT-OP-numpy-time {:6.2f} ms".format(batch_size, max_seq_len, num_layer, op_time_2)) \ No newline at end of file diff --git a/sample/tensorflow/scripts/profile_decoder_op_performance.sh b/sample/tensorflow/scripts/profile_decoder_op_performance.sh deleted file mode 100644 index 10eea82fa..000000000 --- a/sample/tensorflow/scripts/profile_decoder_op_performance.sh +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -logdir="decoder-log" -mkdir ${logdir} -export CUDA_VISIBLE_DEVICES=1 -all_log="${logdir}/all-log.log" -echo -e "Type \t Batch_size \t Sequence_length \t DataType \t Time " > $all_log - -for batch in 1 32 64 128 256; -do - # For FP32 - tmp_log=${logdir}/batchsize-${batch}-seq-32-fp32-log.log - ./bin/decoding_gemm ${batch} 1 8 64 30000 32 768 0 - python decoder_sample.py \ - --batch_size ${batch} \ - --beam_width 1 \ - --max_seq_len 32 \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp32 \ - --decoder_type 1 \ - --test_time 1 2>&1 | tee ${tmp_log} - tail ${tmp_log} -n 1 | awk -v batch_size=$batch '{print $5 "\t" batch_size "\t" 32 "\t" "FP32" "\t" $7 " " $8 }' >> $all_log - - # For FP16 - tmp_log=${logdir}/batchsize-${batch}-seq-32-fp16-log.log - ./bin/decoding_gemm ${batch} 1 8 64 30000 32 768 1 - python decoder_sample.py \ - --batch_size ${batch} \ - --beam_width 1 \ - --max_seq_len 32 \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp16 \ - --decoder_type 1 \ - --test_time 1 2>&1 | tee ${tmp_log} - tail ${tmp_log} -n 1 | awk -v batch_size=$batch '{print $5 "\t" batch_size "\t" 32 "\t" "FP16" "\t" $7 " " $8 }' >> $all_log -done - -for sequence_length in 64 128; -do - # For FP32 - tmp_log=${logdir}/batchsize-1-seq-${sequence_length}-fp32-log.log - ./bin/decoding_gemm 1 1 8 64 30000 $sequence_length 768 0 - python decoder_sample.py \ - --batch_size 1 \ - --beam_width 1 \ - --max_seq_len $sequence_length \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp32 \ - --decoder_type 1 \ - --test_time 1 2>&1 | tee ${tmp_log} - tail ${tmp_log} -n 1 | awk -v seq=$sequence_length '{print $5 "\t" 1 "\t" seq 
"\t" "FP32" "\t" $7 " " $8 }' >> $all_log - - # For FP16 - tmp_log=${logdir}/batchsize-1-seq-$sequence_length-fp16-log.log - ./bin/decoding_gemm 1 1 8 64 30000 $sequence_length 768 1 - python decoder_sample.py \ - --batch_size 1 \ - --beam_width 1 \ - --max_seq_len $sequence_length \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp16 \ - --decoder_type 1 \ - --test_time 1 2>&1 | tee ${tmp_log} - tail ${tmp_log} -n 1 | awk -v seq=$sequence_length '{print $5 "\t" 1 "\t" seq "\t" "FP16" "\t" $7 " " $8 }' >> $all_log -done diff --git a/sample/tensorflow/scripts/profile_decoder_performance.sh b/sample/tensorflow/scripts/profile_decoder_performance.sh new file mode 100644 index 000000000..ec71f8427 --- /dev/null +++ b/sample/tensorflow/scripts/profile_decoder_performance.sh @@ -0,0 +1,74 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +for precision in fp32 fp16; +do + +if [ "$precision" = "fp16" ]; then + echo "Using fp16." + precision_num=1 +else + echo "Using fp32" + precision_num=0 +fi + +logdir="decoder-log-${precision}" +mkdir ${logdir} +all_log="${logdir}/all-log.log" +echo -e "| | TF (ms) | FT-OP (ms) | FT-OP Speedup | " > $all_log +echo -e "|:---------------------------------:|:-------:|:----------:|:-------------:| " >> $all_log + +for batch_size in 1 8 32 64 128 ; +do +for beam_width in 1 4 ; +do +for seq_len in 32 64 128 ; +do + tmp_log_tf=${logdir}/batchsize-${batch_size}-beamwidth-${beam_width}-seq-${seq_len}-${precision}-tf-log.log + tmp_log_ft=${logdir}/batchsize-${batch_size}-beamwidth-${beam_width}-seq-${seq_len}-${precision}-ft-log.log + + ./bin/decoding_gemm ${batch_size} ${beam_width} 8 64 30000 ${seq_len} 512 ${precision_num} + python tensorflow/decoder_sample.py \ + --batch_size ${batch_size} \ + --beam_width ${beam_width} \ + --max_seq_len ${seq_len} \ + --head_number 8 \ + --size_per_head 64 \ + --memory_hidden_dim 512 \ + --num_layer 6 \ + --data_type ${precision} \ + --decoder_type 0 \ + --test_time 1 2>&1 | tee ${tmp_log_tf} + python tensorflow/decoder_sample.py \ + --batch_size ${batch_size} \ + --beam_width ${beam_width} \ + --max_seq_len ${seq_len} \ + --head_number 8 \ + --size_per_head 64 \ + --memory_hidden_dim 512 \ + --num_layer 6 \ + --data_type ${precision} \ + --decoder_type 1 \ + --test_time 1 2>&1 | tee ${tmp_log_ft} + + ft_time=`tail -n 1 ${tmp_log_ft} | awk '{print $17}'` + tf_time=`tail -n 1 ${tmp_log_tf} | awk '{print $17}'` + ft_speedup=$(echo "scale=2; $tf_time / $ft_time" | bc) + tail -n 1 ${tmp_log_tf} | awk -v tf_time=$tf_time -v ft_time=$ft_time -v ft_speedup=$ft_speedup \ + '{print "| <" $3 ", " $5 ", " $11 "> | " tf_time " | " \ + ft_time " | " ft_speedup " | " }' >> $all_log +done # for seq_len +done # for beam_width +done # for batch_size +done # for precision \ No newline at end of file diff --git a/sample/tensorflow/scripts/profile_decoding_op_performance.sh b/sample/tensorflow/scripts/profile_decoding_op_performance.sh 
deleted file mode 100644 index f512a6bec..000000000 --- a/sample/tensorflow/scripts/profile_decoding_op_performance.sh +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -logdir="decoding-log" -mkdir ${logdir} -export CUDA_VISIBLE_DEVICES=1 -all_log="${logdir}/all-log.log" -echo -e "Type \t Batch_size \t Sequence_length \t DataType \t Time " > $all_log - -for batch in 1 32 64 128 256; -do - # For FP32 - tmp_log=${logdir}/batchsize-${batch}-seq-32-fp32-log.log - ./bin/decoding_gemm ${batch} 4 8 64 30000 32 768 0 - python decoding_sample.py \ - --batch_size ${batch} \ - --beam_width 4 \ - --max_seq_len 32 \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp32 \ - --test_time 1 2>&1 | tee ${tmp_log} - tail ${tmp_log} -n 2 | awk -v batch_size=$batch '{print $2 "\t" batch_size "\t" 32 "\t" "FP32" "\t" $5 " " $6 }' >> $all_log - - # For FP16 - tmp_log=${logdir}/batchsize-${batch}-seq-32-fp16-log.log - ./bin/decoding_gemm ${batch} 4 8 64 30000 32 768 1 - python decoding_sample.py \ - --batch_size ${batch} \ - --beam_width 4 \ - --max_seq_len 32 \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp16 \ - --test_time 1 2>&1 | tee ${tmp_log} - tail ${tmp_log} -n 2 | awk -v batch_size=$batch '{print $2 "\t" batch_size "\t" 32 "\t" "FP16" "\t" $5 " " $6 }' >> $all_log -done - -for sequence_length in 64 128; -do - # For FP32 - tmp_log=${logdir}/batchsize-1-seq-${sequence_length}-fp32-log.log - ./bin/decoding_gemm 1 4 8 64 30000 $sequence_length 768 0 - python decoding_sample.py \ - --batch_size 1 \ - --beam_width 4 \ - --max_seq_len $sequence_length \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp32 \ - --test_time 1 2>&1 | tee ${tmp_log} - tail ${tmp_log} -n 2 | awk -v seq=$sequence_length '{print $2 "\t" 1 "\t" seq "\t" "FP32" "\t" $5 " " $6 }' >> $all_log - - # For FP16 - tmp_log=${logdir}/batchsize-1-seq-$sequence_length-fp16-log.log - ./bin/decoding_gemm 1 4 8 64 30000 $sequence_length 768 1 - python decoding_sample.py \ - --batch_size 1 \ - --beam_width 4 \ - --max_seq_len $sequence_length \ - --head_number 8 \ - --size_per_head 64 \ - --memory_hidden_dim 768 \ - --num_layer 6 \ - --data_type fp16 \ - --test_time 1 2>&1 | tee ${tmp_log} - tail ${tmp_log} -n 2 | awk -v seq=$sequence_length '{print $2 "\t" 1 "\t" seq "\t" "FP16" "\t" $5 " " $6 }' >> $all_log -done - diff --git a/sample/tensorflow/scripts/profile_decoding_performance.sh b/sample/tensorflow/scripts/profile_decoding_performance.sh new file mode 100644 index 000000000..a1fe130f1 --- /dev/null +++ b/sample/tensorflow/scripts/profile_decoding_performance.sh @@ -0,0 +1,65 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +for precision in fp32 fp16; +do + +if [ "$precision" = "fp16" ]; then + echo "Using fp16." + precision_num=1 +else + echo "Using fp32" + precision_num=0 +fi + +logdir="decoding-log-${precision}" +mkdir ${logdir} +all_log="${logdir}/all-log.log" +echo -e "| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | " > $all_log +echo -e "|:---------------------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| " >> $all_log + +for batch_size in 1 8 32 64 128 ; +do +for beam_width in 1 4 ; +do +for seq_len in 32 64 128 ; +do + tmp_log_cpp=${logdir}/batchsize-${batch_size}-beamwidth-${beam_width}-seq-${seq_len}-${precision}-cpp-log.log + tmp_log_tf=${logdir}/batchsize-${batch_size}-beamwidth-${beam_width}-seq-${seq_len}-${precision}-tf-log.log + + ./bin/decoding_gemm ${batch_size} ${beam_width} 8 64 30000 ${seq_len} 512 ${precision_num} + ./bin/decoding_beamsearch_sample ${batch_size} ${beam_width} 8 64 30000 ${seq_len} 6 512 ${precision_num} 2>&1 | tee ${tmp_log_cpp} + python tensorflow/decoding_sample.py \ + --batch_size ${batch_size} \ + --beam_width ${beam_width} \ + --max_seq_len ${seq_len} \ + --head_number 8 \ + --size_per_head 64 \ + --memory_hidden_dim 512 \ + --num_layer 6 \ + --data_type ${precision} \ + --test_time 01 2>&1 | tee ${tmp_log_tf} + ft_c_time=`tail -n 1 ${tmp_log_cpp} | awk '{print $17}'` + ft_o_time=`tail -n 1 ${tmp_log_tf} | awk '{print $17}'` + tf_time=`tail -n 2 ${tmp_log_tf} | head -n 1 | awk '{print $17}'` + ft_o_speedup=$(echo "scale=2; $tf_time / $ft_o_time" | bc) + ft_c_speedup=$(echo "scale=2; $tf_time / $ft_c_time" | bc) + tail -n 1 ${tmp_log_cpp} | awk -v tf_time=$tf_time -v ft_o_time=$ft_o_time \ + -v ft_c_time=$ft_c_time -v ft_o_speedup=$ft_o_speedup -v ft_c_speedup=$ft_c_speedup \ + '{print "| <" $3 ", " $5 ", " $11 "> | " tf_time " | " \ + ft_o_time " | " ft_o_speedup " | " ft_c_time " | " ft_c_speedup " | " }' >> $all_log +done # for seq_len +done # for beam_width +done # for batch_size +done # for precision \ No newline at end of file diff --git a/sample/tensorflow/scripts/profile_effective_transformer_performance.sh b/sample/tensorflow/scripts/profile_effective_transformer_performance.sh new file mode 100644 index 000000000..5a2bcdda0 --- /dev/null +++ b/sample/tensorflow/scripts/profile_effective_transformer_performance.sh @@ -0,0 +1,57 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +for precision in fp32 fp16; +do + +if [ "$precision" = "fp16" ]; then + echo "Using fp16." 
+ precision_num=1 +else + echo "Using fp32" + precision_num=0 +fi + +logdir="effective-transformer-log-${precision}" +mkdir ${logdir} +all_log="${logdir}/all-log.log" +echo -e "| | TF (ms) | FT-OP (ms) | Effective FT (ms) | TF Speedup (ms) | FT-OP Speedup | " > $all_log +echo -e "|:---------------------:|:-------:|:----------:|:-----------------:|:---------------:|:-------------:| " >> $all_log + +for batch_size in 1 8 32 64 128 ; +do +for seq_len in 32 64 128 ; +do + ./bin/encoder_gemm ${batch_size} ${seq_len} 12 64 ${precision_num} + + tmp_log_tf=${logdir}/batchsize-${batch_size}-seq-${seq_len}-${precision}-tf-log.log + tmp_log_tf_2=${logdir}/batchsize-${batch_size}-seq-${seq_len}-${precision}-eff-log.log + python tensorflow/encoder_sample.py -batch ${batch_size} -s ${seq_len} -time 1 -d ${precision} 2>&1 | tee $tmp_log_tf + avg_seq_len=$(echo "scale=0; $seq_len / 2" | bc) + python tensorflow/encoder_sample.py -batch ${batch_size} -s ${seq_len} --avg_seq_len ${avg_seq_len} -remove_padding True -time 1 -d ${precision} 2>&1 | tee $tmp_log_tf_2 + + ft_o_time=`tail -n 1 ${tmp_log_tf} | awk '{print $9}'` + tf_time=`tail -n 2 ${tmp_log_tf} | head -n 1 | awk '{print $9}'` + eff_time=`tail -n 1 ${tmp_log_tf_2} | awk '{print $9}'` + eff_tf_speedup=$(echo "scale=2; $tf_time / $eff_time" | bc) + eff_ft_speedup=$(echo "scale=2; $ft_o_time / $eff_time" | bc) + + tail -n 1 ${tmp_log_tf_2} | awk -v batch_size=${batch_size} -v seq_len=${seq_len} -v avg_seq_len=${avg_seq_len} \ + -v tf_time=$tf_time -v ft_o_time=$ft_o_time -v eff_time=$eff_time \ + -v eff_tf_speedup=$eff_tf_speedup -v eff_ft_speedup=$eff_ft_speedup \ + '{print "| <" batch_size ", " seq_len ", " avg_seq_len "> | " tf_time " | " \ + ft_o_time " | " eff_time " | " eff_tf_speedup " | " eff_ft_speedup " | " }' >> $all_log +done +done +done \ No newline at end of file diff --git a/sample/tensorflow/scripts/profile_encoder_performance.sh b/sample/tensorflow/scripts/profile_encoder_performance.sh new file mode 100644 index 000000000..071265413 --- /dev/null +++ b/sample/tensorflow/scripts/profile_encoder_performance.sh @@ -0,0 +1,54 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +for precision in fp32 fp16; +do + +if [ "$precision" = "fp16" ]; then + echo "Using fp16." 
+ precision_num=1 +else + echo "Using fp32" + precision_num=0 +fi + +logdir="bert-base-log-${precision}" +mkdir ${logdir} +all_log="${logdir}/all-log.log" +echo -e "| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | " > $all_log +echo -e "|:---------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| " >> $all_log + +for batch_size in 1 8 32 64 128 ; +do +for seq_len in 32 64 128 ; +do + tmp_log_cpp=${logdir}/batchsize-${batch_size}-seq-${seq_len}-${precision}-cpp-log.log + ./bin/encoder_gemm ${batch_size} ${seq_len} 12 64 ${precision_num} + ./bin/encoder_sample ${batch_size} 12 ${seq_len} 12 64 ${precision_num} 0 2>&1 | tee $tmp_log_cpp + + tmp_log_tf=${logdir}/batchsize-${batch_size}-seq-${seq_len}-${precision}-tf-log.log + python tensorflow/encoder_sample.py -batch ${batch_size} -s ${seq_len} -time 1 -d ${precision} 2>&1 | tee $tmp_log_tf + + ft_c_time=`tail -n 1 ${tmp_log_cpp} | awk '{print $9}'` + ft_o_time=`tail -n 1 ${tmp_log_tf} | awk '{print $9}'` + tf_time=`tail -n 2 ${tmp_log_tf} | head -n 1 | awk '{print $9}'` + ft_o_speedup=$(echo "scale=2; $tf_time / $ft_o_time" | bc) + ft_c_speedup=$(echo "scale=2; $tf_time / $ft_c_time" | bc) + tail -n 1 ${tmp_log_cpp} | awk -v tf_time=$tf_time -v ft_o_time=$ft_o_time \ + -v ft_c_time=$ft_c_time -v ft_o_speedup=$ft_o_speedup -v ft_c_speedup=$ft_c_speedup \ + '{print "| <" $3 ", " $5 "> | " tf_time " | " \ + ft_o_time " | " ft_o_speedup " | " ft_c_time " | " ft_c_speedup " | " }' >> $all_log +done +done +done \ No newline at end of file diff --git a/sample/tensorflow/translate_sample.py b/sample/tensorflow/translate_sample.py index 5f7ce332c..5a13cb37e 100644 --- a/sample/tensorflow/translate_sample.py +++ b/sample/tensorflow/translate_sample.py @@ -12,34 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. +''' +This is a sample code to demonstrate how to use the Fastertransformer op +to translate sentence from English to German. + +This sample code builds then encoder model by TensorFlow, which has the same +model structure to OpenNMT-tf encoder. Next, building the decoder model by +TensorFlow and FasterTransformer op, which has the same model structure to +OpenNMT-tf decoder. So, we can restore the checkpoint of OpenNMT-tf +transformer model directly. + +We compare the bleu scores and the times of translating all sentences in test +dataset of TensorFlow and FasterTransformer op. 
+''' + from __future__ import print_function +import copy +from datetime import datetime import tensorflow as tf import numpy as np import argparse -from utils.common import DecodingArgument -from utils.decoding import tf_decoding, op_decoding +import os +from utils.common import TransformerArgument +from utils.common import DecodingSamplingArgument +from utils.common import DecodingBeamsearchArgument +from utils.encoder import tf_encoder_opennmt +from utils.decoding import tf_beamsearch_decoding +from utils.decoding import tf_sampling_decoding +from utils.decoding import op_beamsearch_decoding +from utils.decoding import op_sampling_decoding +from utils.bleu_score import bleu_score from opennmt.utils import misc -from opennmt.encoders.self_attention_encoder import SelfAttentionEncoder -from opennmt.decoders.self_attention_decoder import SelfAttentionDecoder from opennmt.inputters import WordEmbedder from opennmt.inputters import ExampleInputter -def restore_model_by_pkl(sess, variables): - import pickle as pkl - with open("model.pkl", 'rb') as model_file: - model_dict = pkl.load(model_file) - - assign_op_list = [] - for var in variables: - print(var.name, end=' ') - if var.name in model_dict: - print("restore", end=' ') - assign_op_list.append(tf.assign(var, np.reshape(model_dict[var.name], var.shape))) - print("mean: {} , var: {} . ".format(np.mean(model_dict[var.name]), np.std(model_dict[var.name])), end=' ') - print() - assert(len(assign_op_list) == len(variables)) - sess.run(assign_op_list) - if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -50,7 +56,7 @@ def restore_model_by_pkl(sess, variables): parser.add_argument('-s', '--max_seq_len', type=int, default=200, metavar='NUMBER', help='max sequence length (default: 200)') parser.add_argument('-encoder_head', '--encoder_head_number', type=int, default=8, metavar='NUMBER', - help='encoder head number (default: 12)') + help='encoder head number (default: 8)') parser.add_argument('-encoder_size', '--encoder_size_per_head', type=int, default=64, metavar='NUMBER', help='encoder size per head (default: 64)') parser.add_argument('-decoder_head', '--decoder_head_number', type=int, default=8, metavar='NUMBER', @@ -62,17 +68,44 @@ def restore_model_by_pkl(sess, variables): parser.add_argument('-decoder_layer', '--decoder_num_layer', type=int, default=6, metavar='NUMBER', help='number of layers (default: 6)') parser.add_argument('-d', '--data_type', type=str, default="fp32", metavar='STRING', - help='data type (default: fp32)') + help='data type (default: fp32)', choices=['fp32', 'fp16']) + parser.add_argument('-time', '--test_time', type=str, default='', metavar='STRING', + help=''' + Test the time of which one (default: '' (not test anyone) ); + '': not test anyone + '0': test tf_decoding_beamsearch + '1': test op_decoder_beamsearch + '2': test op_decoding_beamsearch + '3': test tf_decoding_sampling + '4': test op_decoder_sampling + '5': test op_decoding_sampling + 'e.g., if you want to test op_decoder_beamsearch and op_decoding_sampling, + then you need to use -time '15' ''') + parser.add_argument('-diversity_rate', '--beam_search_diversity_rate', type=float, default=0.0, metavar='NUMBER', + help='deviersity rate of beam search. default is 0. When diversity rate = 0, it is equivalent to the naive beams earch.') + parser.add_argument('-topk', '--sampling_topk', type=int, default=1, metavar='NUMBER', + help='Candidate (k) value of top k sampling in decoding. 
Default is 1.') + parser.add_argument('-topp', '--sampling_topp', type=float, default=0.0, metavar='NUMBER', + help='Probability (p) value of top p sampling in decoding. Default is 0.0. ') + + parser.add_argument('--source_vocabulary', type=str, default="./tensorflow/utils/translation/wmtende.vocab", metavar='STRING', + help='Source vocabulary file path. Default is ./tensorflow/utils/translation/wmtende.vocab ') + parser.add_argument('--target_vocabulary', type=str, default="./tensorflow/utils/translation/wmtende.vocab", metavar='STRING', + help='Target vocabulary file path. Default is ./tensorflow/utils/translation/wmtende.vocab ') + parser.add_argument('--source', type=str, default="./tensorflow/utils/translation/test.en", metavar='STRING', + help='Source file path. Default is ./tensorflow/utils/translation/test.en ') + parser.add_argument('--target', type=str, default="./tensorflow/utils/translation/test.de", metavar='STRING', + help='Target file path. Default is ./tensorflow/utils/translation/test.de ') args = parser.parse_args() print("\n=============== Argument ===============") - print(args) + for key in vars(args): + print("{}: {}".format(key, vars(args)[key])) + print("========================================") start_of_sentence_id = 1 end_of_sentence_id = 2 - np.random.seed(1) - tf.set_random_seed(1) kernel_initializer_range = 0.02 bias_initializer_range = 0.02 @@ -89,164 +122,242 @@ def restore_model_by_pkl(sess, variables): decoder_hidden_dim = decoder_head_num * decoder_size_per_head tf_datatype = tf.float32 np_datatype = np.float32 - atol_threshold = 2e-5 if args.data_type == "fp16": tf_datatype = tf.float16 np_datatype = np.float16 - atol_threshold = 2e-2 + beam_search_diversity_rate = args.beam_search_diversity_rate + sampling_topk = args.sampling_topk + sampling_topp = args.sampling_topp - initializer_range = 0.02 - - source_inputter = WordEmbedder("source_vocabulary", embedding_size=512) - target_inputter = WordEmbedder("target_vocabulary", embedding_size=512) + source_inputter = WordEmbedder("source_vocabulary", embedding_size=encoder_hidden_dim, dtype=tf_datatype) + target_inputter = WordEmbedder("target_vocabulary", embedding_size=decoder_hidden_dim, dtype=tf_datatype) inputter = ExampleInputter(source_inputter, target_inputter) inputter.initialize({ - "source_vocabulary": "./utils/translation/wmtende.vocab", - "target_vocabulary": "./utils/translation/wmtende.vocab" + "source_vocabulary": args.source_vocabulary, + "target_vocabulary": args.target_vocabulary }) vocab_size = target_inputter.vocabulary_size - source_file = "./utils/translation/test.en" - - decoding_args = DecodingArgument(batch_size=batch_size, - beam_width=beam_width, - head_num=decoder_head_num, - size_per_head=decoder_size_per_head, - num_layer=decoder_num_layer, - max_seq_len=max_seq_len, - vocab_size=vocab_size, - start_id=start_of_sentence_id, - end_id=end_of_sentence_id, - encoder_hidden_dim=encoder_head_num * encoder_size_per_head, - dtype=tf_datatype) + source_file = args.source + + encoder_args = TransformerArgument(beam_width=1, + head_num=encoder_head_num, + size_per_head=encoder_size_per_head, + num_layer=encoder_num_layer, + dtype=tf_datatype, + kernel_init_range=kernel_initializer_range, + bias_init_range=bias_initializer_range) + + decoder_args = TransformerArgument(beam_width=beam_width, + head_num=decoder_head_num, + size_per_head=decoder_size_per_head, + num_layer=decoder_num_layer, + dtype=tf_datatype, + kernel_init_range=kernel_initializer_range, + 
bias_init_range=bias_initializer_range) + + decoder_args_2 = copy.deepcopy(decoder_args) # for beam search + decoder_args_2.__dict__ = copy.deepcopy(decoder_args.__dict__) + decoder_args_2.beam_width = 1 # for sampling + + decoding_beamsearch_args = DecodingBeamsearchArgument(vocab_size, + start_of_sentence_id, + end_of_sentence_id, + max_seq_len, + decoder_args, + beam_search_diversity_rate) + + decoding_sampling_args = DecodingSamplingArgument(vocab_size, + start_of_sentence_id, + end_of_sentence_id, + max_seq_len, + decoder_args_2, + sampling_topk, + sampling_topp) mode = tf.estimator.ModeKeys.PREDICT - with tf.variable_scope("transformer/encoder"): + with tf.variable_scope("transformer/encoder", reuse=tf.AUTO_REUSE): dataset = inputter.make_inference_dataset(source_file, batch_size) iterator = dataset.make_initializable_iterator() source = iterator.get_next() source_embedding = source_inputter.make_inputs(source) + source_embedding = tf.cast(source_embedding, tf_datatype) memory_sequence_length = source["length"] + + tf_encoder_result = tf_encoder_opennmt(source_embedding, encoder_args, sequence_length=memory_sequence_length) + tf_encoder_result = tf.cast(tf_encoder_result, tf_datatype) - encoder = SelfAttentionEncoder( - num_layers=encoder_num_layer, - num_units=512, - num_heads=8, - ffn_inner_dim=2048, - dropout=0.1, - attention_dropout=0.1, - relu_dropout=0.1) - memory, _, _ = encoder.encode(source_embedding, memory_sequence_length, mode=mode) - tf_encoder_result = memory - - tf_encoder_result = tf.reshape( - tf_encoder_result, [batch_size, -1, encoder_hidden_dim]) + tf_encoder_result = tf.reshape(tf_encoder_result, tf.shape(source_embedding)) with tf.variable_scope("transformer/decoder", reuse=tf.AUTO_REUSE): target_inputter.build() + target_vocab_rev = target_inputter.vocabulary_lookup_reverse() - with tf.variable_scope("transformer/decoder", reuse=tf.AUTO_REUSE): - decoder = SelfAttentionDecoder( - num_layers=6, - num_units=512, - num_heads=8, - ffn_inner_dim=2048, - dropout=0.0, - attention_dropout=0.0, - relu_dropout=0.0) + ### TF BeamSearch Decoding ### + tf_beamsearch_target_ids, tf_beamsearch_target_length, _, _, _ = tf_beamsearch_decoding(tf_encoder_result, + memory_sequence_length, + target_inputter.embedding, + decoding_beamsearch_args, + decoder_type=0) + + # tf_beamsearch_target_tokens: [batch_size, beam_width, seq_len] + tf_beamsearch_target_tokens = target_vocab_rev.lookup(tf.cast(tf_beamsearch_target_ids, tf.int64)) + tf_beamsearch_target_length = tf.minimum(tf_beamsearch_target_length + 1, tf.shape(tf_beamsearch_target_ids)[-1]) + ### end of TF BeamSearch Decoding ### + + ### TF Sampling Decoding ### + tf_sampling_target_ids, tf_sampling_target_length = tf_sampling_decoding(tf_encoder_result, + memory_sequence_length, + target_inputter.embedding, + decoding_sampling_args, + decoder_type=0) - start_tokens = tf.fill([batch_size], start_of_sentence_id) - end_token = end_of_sentence_id - - target_ids, _, target_length, _ = decoder.dynamic_decode_and_search( - target_inputter.embedding, - start_tokens, - end_token, - vocab_size=vocab_size, - beam_width=beam_width, - memory=memory, - memory_sequence_length=memory_sequence_length) - target_vocab_rev = target_inputter.vocabulary_lookup_reverse() - target_tokens = target_vocab_rev.lookup(tf.cast(target_ids, tf.int64)) - opennmt_target_length = target_length - opennmt_target_tokens = target_tokens - opennmt_target_ids = target_ids + # tf_sampling_target_tokens: [batch_size, seq_len] + tf_sampling_target_tokens = 
target_vocab_rev.lookup(tf.cast(tf_sampling_target_ids, tf.int64)) + tf_sampling_target_length = tf.minimum(tf_sampling_target_length + 1, tf.shape(tf_sampling_target_ids)[-1]) + ### end of TF BeamSearch Decoding ### + + ### OP BeamSearch Decoder ### + op_decoder_beamsearch_target_ids, op_decoder_beamsearch_target_length, _, _, _ = tf_beamsearch_decoding(tf_encoder_result, + memory_sequence_length, + target_inputter.embedding, + decoding_beamsearch_args, + decoder_type=1) - opennmt_variables = tf.global_variables() + # op_decoder_beamsearch_target_tokens: [batch_size, beam_width, seq_len] + op_decoder_beamsearch_target_tokens = target_vocab_rev.lookup(tf.cast(op_decoder_beamsearch_target_ids, tf.int64)) + op_decoder_beamsearch_target_length = tf.minimum(op_decoder_beamsearch_target_length + 1, tf.shape(op_decoder_beamsearch_target_ids)[-1]) + ### end of OP BeamSearch Decoder ### + + ### OP Sampling Decoder ### + op_decoder_sampling_target_ids, op_decoder_sampling_target_length = tf_sampling_decoding(tf_encoder_result, + memory_sequence_length, + target_inputter.embedding, + decoding_sampling_args, + decoder_type=1) - ## TF Decoding ### - finalized_tf_output_ids, finalized_tf_sequence_lengths, tf_output_ids, \ - tf_parent_ids, tf_sequence_lengths = tf_decoding(tf_encoder_result, - memory_sequence_length, - target_inputter.embedding, - decoding_args, - decoder_type=1, - kernel_initializer_range=kernel_initializer_range, - bias_initializer_range=bias_initializer_range) - - tf_target_ids = finalized_tf_output_ids - tf_target_length = finalized_tf_sequence_lengths - tf_target_tokens = target_vocab_rev.lookup(tf.cast(tf_target_ids, tf.int64)) - ## end of tf decoding ## - - ## op decoding ## + op_decoder_sampling_target_tokens = target_vocab_rev.lookup(tf.cast(op_decoder_sampling_target_ids, tf.int64)) + op_decoder_sampling_target_length = tf.minimum(op_decoder_sampling_target_length + 1, tf.shape(op_decoder_sampling_target_ids)[-1]) + ### end of OP BeamSearch Decoder ### + + ### Prepare Decoding variables for FasterTransformer ### all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) decoder_var_start_id = 0 - - while all_vars[decoder_var_start_id].name.find("transformer/decoding") == -1: + + while all_vars[decoder_var_start_id].name.find("transformer/decoder") == -1: decoder_var_start_id += 1 encoder_variables = all_vars[:decoder_var_start_id] - decoder_variables = all_vars[decoder_var_start_id:] + decoder_variables = all_vars[decoder_var_start_id + 1:] # decoder_var_start_id + 1 means skip the embedding table - finalized_op_output_ids, finalized_op_sequence_lengths, op_output_ids, \ - op_parent_ids, op_sequence_lengths = op_decoding(tf_encoder_result, - memory_sequence_length, - target_inputter.embedding, - decoder_variables, # first one is embedding table - decoding_args) + ### OP BeamSearch Decoding ### + op_beamsearch_target_ids, op_beamsearch_target_length, _, _, _ = op_beamsearch_decoding(tf_encoder_result, + memory_sequence_length, + target_inputter.embedding, + decoder_variables, + decoding_beamsearch_args) - op_target_ids = finalized_op_output_ids - op_target_length = finalized_op_sequence_lengths - op_target_tokens = target_vocab_rev.lookup(tf.cast(op_target_ids, tf.int64)) - - ## end of op decoding + op_beamsearch_target_tokens = target_vocab_rev.lookup(tf.cast(op_beamsearch_target_ids, tf.int64)) + op_beamsearch_target_length = tf.minimum(op_beamsearch_target_length + 1, tf.shape(op_beamsearch_target_ids)[-1]) + ### end of OP BeamSearch Decoding ### - opennmt_target_ids = 
tf.cast(opennmt_target_ids, tf.int32) - tf_target_ids = tf.cast(tf_target_ids, tf.int32) - op_target_ids = tf.cast(op_target_ids, tf.int32) + ### OP Sampling Decoding ### + op_sampling_target_ids, op_sampling_target_length = op_sampling_decoding(tf_encoder_result, + memory_sequence_length, + target_inputter.embedding, + decoder_variables, + decoding_sampling_args) - opennmt_target_length = tf.minimum(opennmt_target_length + 1, tf.shape(opennmt_target_ids)[2]) - tf_target_length = tf.minimum(tf_target_length + 1, tf.shape(tf_target_ids)[2]) - op_target_length = tf.minimum(op_target_length + 1, tf.shape(op_target_ids)[2]) + op_sampling_target_tokens = target_vocab_rev.lookup(tf.cast(op_sampling_target_ids, tf.int64)) + op_sampling_target_length = tf.minimum(op_sampling_target_length + 1, tf.shape(op_sampling_target_ids)[-1]) + ### end of OP Sampling Decoding ### config = tf.ConfigProto() config.gpu_options.allow_growth = True - with tf.Session(config=config) as sess: - saver = tf.train.Saver(opennmt_variables) - sess.run(tf.global_variables_initializer()) - saver.restore(sess, "translation/ckpt/model.ckpt-500000") - sess.run(tf.tables_initializer()) - sess.run(iterator.initializer) - restore_model_by_pkl(sess, decoder_variables) - - iteration = 0 - while iteration < 3: - try: - opennmt_batch_tokens, opennmt_batch_length, \ - tf_batch_tokens, tf_batch_length, \ - op_batch_tokens, op_batch_length, source_result = sess.run([opennmt_target_tokens, opennmt_target_length, - tf_target_tokens, tf_target_length, - op_target_tokens, op_target_length, source]) - print("[INFO] opennmt: ", end='') - for tokens, length in zip(opennmt_batch_tokens, opennmt_batch_length): - misc.print_bytes(b" ".join(tokens[0][:length[0] - 1])) - print("[INFO] tf : ", end='') - for tokens, length in zip(tf_batch_tokens, tf_batch_length): - misc.print_bytes(b" ".join(tokens[0][:length[0] - 1])) - print("[INFO] op : ", end='') - for tokens, length in zip(op_batch_tokens, op_batch_length): - misc.print_bytes(b" ".join(tokens[0][:length[0] - 1])) + + time_args = args.test_time + + class TranslationResult(object): + def __init__(self, token_op, length_op, name): + self.token_op = token_op + self.length_op = length_op + self.name = name + self.file_name = name + ".txt" + + self.token_list = [] + self.length_list = [] + self.batch_num = 0 + self.execution_time = 0.0 # seconds + self.sentence_num = 0 + self.bleu_score = None + + translation_result_list = [] + + if time_args.find("0") != -1: + translation_result_list.append(TranslationResult( + tf_beamsearch_target_tokens, tf_beamsearch_target_length, "tf-decoding-beamsearch")) + if time_args.find("1") != -1: + translation_result_list.append(TranslationResult( + op_decoder_beamsearch_target_tokens, op_decoder_beamsearch_target_length, "op-decoder-beamsearch")) + if time_args.find("2") != -1: + translation_result_list.append(TranslationResult( + op_beamsearch_target_tokens, op_beamsearch_target_length, "op-decoding-beamsearch")) + if time_args.find("3") != -1: + translation_result_list.append(TranslationResult( + tf_sampling_target_tokens, tf_sampling_target_length, "tf-decoding-sampling")) + if time_args.find("4") != -1: + translation_result_list.append(TranslationResult( + op_decoder_sampling_target_tokens, op_decoder_sampling_target_length, "op-decoder-sampling")) + if time_args.find("5") != -1: + translation_result_list.append(TranslationResult( + op_sampling_target_tokens, op_sampling_target_length, "op-decoding-sampling")) + + float_var_list = [] + half_var_list = [] + for var 
in tf.global_variables()[:-1]: + if var.dtype.base_dtype == tf.float32: + float_var_list.append(var) + elif var.dtype.base_dtype == tf.float16: + half_var_list.append(var) + + for i in range(len(translation_result_list)): + with tf.Session(config=config) as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.tables_initializer()) + sess.run(iterator.initializer) + + if(len(float_var_list) > 0): + float_saver = tf.train.Saver(float_var_list) + float_saver.restore(sess, "translation/ckpt/model.ckpt-500000") + if(len(half_var_list) > 0): + half_saver = tf.train.Saver(half_var_list) + half_saver.restore(sess, "translation/ckpt/fp16_model.ckpt-500000") - iteration += 1 - except tf.errors.OutOfRangeError: - break + t1 = datetime.now() + while True: + try: + batch_tokens, batch_length = sess.run([translation_result_list[i].token_op, + translation_result_list[i].length_op]) + for tokens, length in zip(batch_tokens, batch_length): + if translation_result_list[i].name.find("beamsearch") != -1: + translation_result_list[i].token_list.append(b" ".join(tokens[0][:length[0] - 2]).decode("UTF-8")) + else: + translation_result_list[i].token_list.append(b" ".join(tokens[:length - 2]).decode("UTF-8")) + translation_result_list[i].batch_num += 1 + except tf.errors.OutOfRangeError: + break + t2 = datetime.now() + time_sum = (t2 - t1).total_seconds() + translation_result_list[i].execution_time = time_sum + with open(translation_result_list[i].file_name, "w") as file_b: + for s in translation_result_list[i].token_list: + file_b.write(s) + file_b.write("\n") + + ref_file_path = "./.ref_file.txt" + os.system("head -n %d %s > %s" % (len(translation_result_list[i].token_list), args.target, ref_file_path)) + translation_result_list[i].bleu_score = bleu_score(translation_result_list[i].file_name, ref_file_path) + os.system("rm {}".format(ref_file_path)) + + for t in translation_result_list: + print("[INFO] {} translates {} batches taking {:.2f} ms to translate {} tokens, BLEU score: {:.2f}, {:.0f} tokens/sec.".format( + t.name, t.batch_num, t.execution_time, t.bleu_score.sys_len, t.bleu_score.score, t.bleu_score.sys_len / t.execution_time)) diff --git a/sample/tensorflow/unit_test/beam_search_unit_test.py b/sample/tensorflow/unit_test/beam_search_unit_test.py new file mode 100644 index 000000000..a85df4a05 --- /dev/null +++ b/sample/tensorflow/unit_test/beam_search_unit_test.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
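A side note on the --test_time handling in translate_sample.py above: the argument is a string of digits, and each digit enables one decoding pipeline, so "-time '15'" benchmarks op_decoder_beamsearch together with op_decoding_sampling. A minimal sketch of that selection logic follows; the dictionary is only illustrative, while the sample itself uses a chain of time_args.find(...) checks to build the TranslationResult list.

    # Minimal sketch of how the digits of --test_time select decoding pipelines.
    # The digit-to-name mapping mirrors the TranslationResult objects created in
    # translate_sample.py; the dictionary itself is illustrative, not part of the sample.
    PIPELINES = {
        "0": "tf-decoding-beamsearch",
        "1": "op-decoder-beamsearch",
        "2": "op-decoding-beamsearch",
        "3": "tf-decoding-sampling",
        "4": "op-decoder-sampling",
        "5": "op-decoding-sampling",
    }

    def selected_pipelines(time_args):
        # time_args.find(digit) != -1 in the sample is the same membership test.
        return [name for digit, name in PIPELINES.items() if digit in time_args]

    print(selected_pipelines("15"))  # ['op-decoder-beamsearch', 'op-decoding-sampling']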
+ +import tensorflow as tf +import numpy as np +import os,sys +sys.path.append(os.getcwd()) +from utils.beam_search import beam_search, diverse_sibling_search + +sample_number = 100000 + +np.random.seed(1) +tf.set_random_seed(1) + +def beam_search_unit_test(): + batch_size = 1 + beam_width = 2 + vocab_size = 4 + + # case 1 + probs_1 = np.asarray([ [4.0, 3.0, 2.0, 1.0], + [4.1, 3.1, 2.1, 1.1] ]) + diversity_rate_1 = 0.0 + np_result_1 = np.asarray([4, 0]).astype(np.int32) + final_id_1 = diverse_sibling_search(probs_1, beam_width, vocab_size, diversity_rate=diversity_rate_1) + + # case 2 + probs_2 = np.asarray([ [4.0, 3.0, 2.0, 1.0], + [4.1, 3.1, 2.1, 1.1] ]) + diversity_rate_2 = -1.0 + np_result_2 = np.asarray([4, 0]).astype(np.int32) + final_id_2 = diverse_sibling_search(probs_2, beam_width, vocab_size, diversity_rate=diversity_rate_2) + + # case 3 + probs_3 = np.asarray([ [4.0, 3.0, 2.0, 1.0], + [2.1, 1.1, 0.1, 0.01] ]) + diversity_rate_3 = -1.0 + np_result_3 = np.asarray([0, 4]).astype(np.int32) + final_id_3 = diverse_sibling_search(probs_3, beam_width, vocab_size, diversity_rate=diversity_rate_3) + + # case 4 + probs_4 = np.asarray([ [4.0, 3.0, 2.0, 1.0], + [2.1, 1.1, 0.1, 0.01] ]) + diversity_rate_4 = 0.0 + np_result_4 = np.asarray([0, 1]).astype(np.int32) + final_id_4 = diverse_sibling_search(probs_4, beam_width, vocab_size, diversity_rate=diversity_rate_4) + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + print("[INFO] start the diversing beam search unit test.") + with tf.Session(config=config) as sess: + tf_result_1 = sess.run(final_id_1) + tf_result_1 = np.asarray(tf_result_1).astype(np.int32) + for i, j in zip(tf_result_1, np_result_1): + assert(i == j) + print("[INFO] case_1 pass.") + + tf_result_2 = sess.run(final_id_2) + tf_result_2 = np.asarray(tf_result_2).astype(np.int32) + for i, j in zip(tf_result_2, np_result_2): + assert(i == j) + print("[INFO] case_2 pass.") + + tf_result_3 = sess.run(final_id_3) + tf_result_3 = np.asarray(tf_result_3).astype(np.int32) + for i, j in zip(tf_result_3, np_result_3): + assert(i == j) + print("[INFO] case_3 pass.") + + tf_result_4 = sess.run(final_id_4) + tf_result_4 = np.asarray(tf_result_4).astype(np.int32) + for i, j in zip(tf_result_4, np_result_4): + assert(i == j) + print("[INFO] case_4 pass.") + +if __name__ == "__main__": + beam_search_unit_test() + + diff --git a/sample/tensorflow/unit_test/sampling_unit_test.py b/sample/tensorflow/unit_test/sampling_unit_test.py new file mode 100644 index 000000000..9d3d497b5 --- /dev/null +++ b/sample/tensorflow/unit_test/sampling_unit_test.py @@ -0,0 +1,155 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
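The expected ids in the beam_search_unit_test.py cases above follow the convention that diverse_sibling_search returns indices into the flattened [beam_width * vocab_size] score matrix, which search_word in utils/beam_search.py later splits back into beam_id = id // vocab_size and word_id = id % vocab_size. A minimal NumPy sketch of case 1, where diversity_rate = 0 reduces diverse sibling search to plain beam search and the expected result is [4, 0]:

    import numpy as np

    # Case 1 of beam_search_unit_test.py: beam_width = 2, vocab_size = 4 and
    # diversity_rate = 0, so diverse sibling search reduces to ordinary beam
    # search over the flattened [beam_width * vocab_size] scores.
    probs = np.asarray([[4.0, 3.0, 2.0, 1.0],
                        [4.1, 3.1, 2.1, 1.1]])
    beam_width, vocab_size = probs.shape

    flat_scores = probs.reshape(-1)                      # [4.0, 3.0, 2.0, 1.0, 4.1, 3.1, 2.1, 1.1]
    sample_ids = np.argsort(-flat_scores)[:beam_width]   # [4, 0], the expected np_result_1
    beam_ids = sample_ids // vocab_size                  # [1, 0]
    word_ids = sample_ids % vocab_size                   # [0, 0]
    print(sample_ids, beam_ids, word_ids)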
+ +import tensorflow as tf +import numpy as np +import os,sys +sys.path.append(os.getcwd()) +from utils.sampling import Sampling + +sample_number = 100000 + +np.random.seed(1) +tf.set_random_seed(1) + +def top_k_sampling_unit_test(): + top_k_sampling = Sampling("top_k") + + probs_1 = np.asarray([ [4.0, 3.0, 2.0, 1.0] ]) + np_result_1 = [0] + k_1 = 1 + tf_result_1 = top_k_sampling.sample(tf.convert_to_tensor(probs_1), k_1) + + probs_2 = np.asarray(np.log([ [0.6, 0.4, 0.3, 0.1] ])) + np_result_2 = [0.6, 0.4, 0.0, 0.0] + k_2 = 2 + tf_result_2 = top_k_sampling.sample(tf.convert_to_tensor(probs_2), k_2, sample_number) + + np_probs_3 = [0.3, 0.4, 0.25, 0.01, 0.05] + probs_3 = np.asarray(np.log([np_probs_3])) + np_result_3 = [0.3, 0.4, 0.25, 0.0, 0.05] + k_3 = 4 + tf_result_3 = top_k_sampling.sample(tf.convert_to_tensor(probs_3), k_3, sample_number) + + np_probs_4 = [0.3, 0.4, 0.25, 0.01, 0.05] + probs_4 = np.asarray(np.log([np_probs_4])) + np_result_4 = [0.3/0.7, 0.4/0.7, 0.0, 0.00, 0.00] + k_4 = 2 + tf_result_4 = top_k_sampling.sample(tf.convert_to_tensor(probs_4), k_4, sample_number) + + np_probs_5 = np.random.randn(1, 10000) * 1 + np_probs_5 = np.abs(np_probs_5) + np_probs_5[0][0] *= 5 + np_result_5 = np_probs_5 + k_5 = 10 + np_sorted_result_5 = np.sort(np_probs_5) + threshold = np_sorted_result_5[:,-k_5] + mask = np_probs_5 >= threshold + np_result_5 = np_probs_5 * mask + np_result_5[0] = np_result_5[0] / np.sum(np_result_5[0]) + tf_result_5 = top_k_sampling.sample(tf.convert_to_tensor(np.log(np_probs_5)), k_5, sample_number) + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + print("[INFO] start the top k sampling unit test.") + with tf.Session(config=config) as sess: + tf_result_1 = sess.run(tf_result_1) + tf_result_1 = np.asarray(tf_result_1).astype(np.int32) + print("[INFO] case_1.") + for i, j in zip(tf_result_1, np_result_1): + assert(i == j) + + print("[INFO] case_2.") + tf_result_2 = sess.run(tf_result_2) + p0 = 0 + p1 = 0 + for i in tf_result_2: + if i == 0: + p0 += 1 + elif i == 1: + p1 += 1 + print(p0/sample_number, p1/sample_number) + print(np_result_2) + + print("[INFO] case_3.") + tf_result_3 = sess.run(tf_result_3) + p = np.zeros_like(np_result_3) + for i in tf_result_3: + p[i] += 1 + print(p * 1.0 / sample_number) + print(np_result_3) + + print("[INFO] case_4.") + tf_result_4 = sess.run(tf_result_4) + p = np.zeros_like(np_result_4) + for i in tf_result_4: + p[i] += 1 + print(p * 1.0 / sample_number) + print(np_result_4) + + print("[INFO] case_5.") + tf_result_5 = sess.run(tf_result_5) + p = np.zeros_like(np_result_5) + print(tf_result_5) + for i in tf_result_5: + p[0][i] += 1 + for i, j in zip(p[0]/sample_number, np_result_5[0]): + if i != 0 or j != 0: + print(i, j) + + + +def top_p_sampling_unit_test(): + top_p_sampling = Sampling("top_p") + + np_probs_1 = [0.3, 0.01, 0.4, 0.25, 0.05] + np_result_1 = [0, 0, 1, 0, 0] + p_1 = 0.3 + tf_result_1 = top_p_sampling.sample(tf.convert_to_tensor(np.log([np_probs_1])), p_1, sample_number) + + np_probs_2 = [0.3, 0.01, 0.4, 0.25, 0.05] + np_result_2 = [3./7, 0, 4./7, 0, 0] + p_2 = 0.5 + tf_result_2 = top_p_sampling.sample(tf.convert_to_tensor(np.log([np_probs_2])), p_2, sample_number) + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + print("[INFO] start the top p sampling unit test.") + with tf.Session(config=config) as sess: + + print("[INFO] case_1.") + tf_result_1 = sess.run(tf_result_1) + p = np.zeros_like(np_result_1) + for i in tf_result_1: + for j in range(len(p)): + if i == j: + p[j] += 
1 + print(p * 1.0 / sample_number) + print(np_result_1) + + print("[INFO] case_2.") + tf_result_2 = sess.run(tf_result_2) + p = np.zeros_like(np_result_2) + for i in tf_result_2: + for j in range(len(p)): + if i == j: + p[j] += 1 + print(p * 1.0 / sample_number) + print(np_result_2) + +if __name__ == "__main__": + top_k_sampling_unit_test() + top_p_sampling_unit_test() + + diff --git a/sample/tensorflow/utils/__init__.py b/sample/tensorflow/utils/__init__.py index e69de29bb..9e3250071 100644 --- a/sample/tensorflow/utils/__init__.py +++ b/sample/tensorflow/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sample/tensorflow/utils/beam_search.py b/sample/tensorflow/utils/beam_search.py new file mode 100644 index 000000000..36c6ae3a4 --- /dev/null +++ b/sample/tensorflow/utils/beam_search.py @@ -0,0 +1,150 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
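As a side note on the top-k / top-p sampling unit tests earlier in this patch: they validate the sampling ops by comparing empirical token frequencies against a reference distribution obtained by masking the probabilities down to the kept tokens and renormalizing. Below is a minimal NumPy sketch of how such reference distributions can be computed; the helper names are illustrative and not part of the patch.

```python
import numpy as np

def top_k_reference(probs, k):
    """Keep the k most probable tokens, zero out the rest, renormalize."""
    probs = np.asarray(probs, dtype=np.float64)
    threshold = np.sort(probs)[-k]                  # k-th largest probability
    masked = np.where(probs >= threshold, probs, 0.0)
    return masked / masked.sum()

def top_p_reference(probs, p):
    """Keep the smallest set of tokens whose cumulative probability reaches p."""
    probs = np.asarray(probs, dtype=np.float64)
    order = np.argsort(probs)[::-1]                 # token ids by descending probability
    cumulative = np.cumsum(probs[order])
    keep = np.searchsorted(cumulative, p) + 1       # number of tokens to keep
    masked = np.zeros_like(probs)
    masked[order[:keep]] = probs[order[:keep]]
    return masked / masked.sum()

if __name__ == "__main__":
    probs = [0.3, 0.01, 0.4, 0.25, 0.05]
    print(top_k_reference(probs, 2))    # ~[0.43, 0, 0.57, 0, 0]
    print(top_p_reference(probs, 0.5))  # ~[3/7, 0, 4/7, 0, 0], matching case_2 of the top-p test
```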
+ +import tensorflow as tf + +def search_word(beam_width, + vocab_size, + step, + logits, + cum_log_probs, + finished, + cache, + extra_vars, + op_self_cache=None, + search_method=None): + + # [batch_size * beam_width, vocab_size] + batchxbeam = tf.shape(logits)[0] + log_probs = tf.nn.log_softmax(logits) + + parent_ids = extra_vars[0] + sequence_lengths = extra_vars[1] + total_probs = log_probs + tf.expand_dims(cum_log_probs, 1) + # [batch_size * beam_width, vocab_size] + [batch_size * beam_width], has to broadcast + total_probs = tf.reshape(total_probs, [-1, beam_width * vocab_size]) + + if search_method == None: + search_method = BeamSearch() + sample_ids = search_method.process(total_probs, beam_width, vocab_size) + + word_ids = sample_ids % vocab_size # [batch_size * beam_width] + beam_ids = sample_ids // vocab_size # [batch_size * beam_width] + # [batch_size * beam_width] + beam_indices = (tf.range(batchxbeam) // beam_width) * beam_width + beam_ids + + sequence_lengths = tf.where( + finished, x=sequence_lengths, y=sequence_lengths + 1) + + # [batch_size * beam_width] + batch_pos = tf.range(batchxbeam) // beam_width + cum_log_probs = tf.gather_nd(total_probs, tf.stack( + [batch_pos, sample_ids], axis=-1)) # [batch_size * beam_width] + finished = tf.gather(finished, beam_indices) + sequence_lengths = tf.gather(sequence_lengths, beam_indices) + + cache = tf.contrib.framework.nest.map_structure( + lambda s: tf.gather(s, beam_indices), cache) + if op_self_cache != None: + op_self_cache = tf.contrib.framework.nest.map_structure( + lambda s: tf.gather(s, beam_indices, axis=3), op_self_cache) + + parent_ids = parent_ids.write(step, beam_ids) + extra_vars = [parent_ids, sequence_lengths] + + return word_ids, cum_log_probs, finished, cache, tuple(extra_vars), op_self_cache + +class Search(): + + def __init__(self): + pass + + def process(self, total_probs, beam_width, vocab_size): + pass + +class BeamSearch(Search): + + def __init__(self): + pass + + def process(self, total_probs, beam_width, vocab_size): + ''' + inputs: + total_probs: float tensor, [batch_size * beam_width, vocab_size] + beam_width: int scalar + + outputs: + sample_ids: int tensor, [batch_size * beam_width] + ''' + + # [batch_size, beam_width * vocab_size], can skip in cuda + total_probs = tf.reshape(total_probs, [-1, beam_width * vocab_size]) + + _, sample_ids = tf.nn.top_k(total_probs, beam_width) + # [batch_size * beam_width], can skip in cuda + sample_ids = tf.reshape(sample_ids, [-1]) + + return sample_ids + +class DiverseSiblingSearch(Search): + + def __init__(self, diversity_rate): + ''' + inputs: + diversity: int scalar, >= 0 + if diversity_rate == 0, then it is equivalent to beam_search + ''' + self.diversity_rate = diversity_rate + + def process(self, total_probs, beam_width, vocab_size): + ''' + inputs: + total_probs: float tensor, [batch_size * beam_width, vocab_size] + + outputs: + sample_ids: int tensor, [batch_size * beam_width] + beam_ids: int tensor, [batch_size * beam_width] + + 1. calculate hypothese for each beam + 2. Intra-sibling ordering + 3. rewrite scores + 4. 
choose top K hypothese + ''' + + total_probs = tf.reshape(total_probs, [-1, beam_width, vocab_size]) # [batch size, beam width, vocab size] + + sibling_score = tf.cast(tf.range(1, beam_width+1), total_probs.dtype) * self.diversity_rate # [beam_width] + + scores, ids = tf.nn.top_k(total_probs, beam_width) # [batch size, beam width, beam width] + scores = tf.add(scores, sibling_score) # [batch size, beam width, beam width] + + scores = tf.reshape(scores, [-1, beam_width * beam_width]) + ids = ids + tf.expand_dims(tf.expand_dims(tf.range(0, beam_width) * vocab_size, 0), -1) + ids = tf.reshape(ids, [-1, beam_width * beam_width]) + + _, final_ids = tf.nn.top_k(scores, beam_width) # [batch size, beam width] + + batch_size = tf.shape(final_ids)[0] + final_ids = tf.reshape(final_ids, [-1, 1]) + batch_index = tf.range(0, batch_size) + batch_index = tf.reshape(batch_index, [-1, 1]) + batch_index = tf.tile(batch_index, [1, beam_width]) + batch_index = tf.reshape(batch_index, [-1, 1]) + + index = tf.concat([batch_index, final_ids ], axis=1) + sample_ids = tf.gather_nd(ids, index) + + return sample_ids + + diff --git a/sample/tensorflow/utils/bleu_score.py b/sample/tensorflow/utils/bleu_score.py new file mode 100644 index 000000000..e60c50377 --- /dev/null +++ b/sample/tensorflow/utils/bleu_score.py @@ -0,0 +1,33 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import tensorflow as tf +from sacrebleu import corpus_bleu + +def bleu_score(pred_file, ref_file): + with tf.io.gfile.GFile(pred_file) as pred_stream, tf.io.gfile.GFile(ref_file) as ref_stream: + bleu = corpus_bleu(pred_stream, [ref_stream], force=True) + print(" bleu score: {:6.2f}".format(bleu.score)) + print(" bleu counts: {}".format(bleu.counts)) + print(" bleu totals: {}".format(bleu.totals)) + print(" bleu precisions: {}".format(bleu.precisions)) + print(" bleu sys_len: {}; ref_len: {}".format(bleu.sys_len, bleu.ref_len)) + return bleu + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("[ERROR] bleu_score.py needs a result file and a solution file. \n e.g. python bleu_score.py f1.txt f2.txt") + sys.exit(0) + bleu_score(sys.argv[1], sys.argv[2]) diff --git a/sample/tensorflow/utils/common.py b/sample/tensorflow/utils/common.py index 857143587..6ee9759c8 100644 --- a/sample/tensorflow/utils/common.py +++ b/sample/tensorflow/utils/common.py @@ -15,51 +15,154 @@ from datetime import datetime import tensorflow as tf import numpy as np +import ctypes +from utils.beam_search import BeamSearch +from utils.beam_search import DiverseSiblingSearch class TransformerArgument: def __init__( self, - batch_size, beam_width, head_num, size_per_head, num_layer, - max_seq_len, - dtype): + dtype=tf.float32, + kernel_init_range=0.02, + bias_init_range=0.02, + fuse_qkv=True, + remove_padding=False): + ''' + The arguments of Transformer layer (for both encoder and decoder). + + Args: + beam_width: The beam_width size for beam search. This argument is always one for encoder. 
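The DiverseSiblingSearch above performs its re-ranking in four steps (per-beam top-k, rank-scaled sibling score, flattening, global top-k). Below is a small NumPy sketch of one such step for a single sentence; it follows the same sign convention as the code (the rank-scaled term is added, so a negative diversity rate acts as a penalty on lower-ranked siblings). Names and shapes here are illustrative only.

```python
import numpy as np

def diverse_sibling_step(total_log_probs, beam_width, diversity_rate):
    """One diverse-sibling re-ranking step for a single sentence.

    total_log_probs: [beam_width, vocab_size] cumulative log-probabilities.
    Returns beam_width flat candidate ids into [beam_width * vocab_size].
    """
    vocab_size = total_log_probs.shape[1]
    # 1. top beam_width hypotheses within each beam
    ids = np.argsort(-total_log_probs, axis=1)[:, :beam_width]          # [beam, beam]
    scores = np.take_along_axis(total_log_probs, ids, axis=1)           # [beam, beam]
    # 2./3. intra-sibling ordering: the r-th sibling gets r * diversity_rate added
    scores = scores + np.arange(1, beam_width + 1) * diversity_rate
    # 4. choose the overall top beam_width hypotheses
    flat_ids = (ids + np.arange(beam_width)[:, None] * vocab_size).reshape(-1)
    best = np.argsort(-scores.reshape(-1))[:beam_width]
    return flat_ids[best]

if __name__ == "__main__":
    rng = np.random.RandomState(0)
    log_probs = np.log(rng.dirichlet(np.ones(8), size=4))   # beam_width=4, vocab_size=8
    print(diverse_sibling_step(log_probs, 4, diversity_rate=-0.5))
```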
+ head_num: The number of heads of the self attention in the transformer layer. + size_per_head: The hidden dimension of each head of the self attention in the transformer layer. + num_layer: The number of transformer layers. For example, BERT-base uses 12 layers. + dtype: The data type of the weight initializers and inputs. + kernel_init_range: The initializer range of the kernels of all convolution and fully-connected layers. + bias_init_range: The initializer range of the biases of all convolution and fully-connected layers. + fuse_qkv: bool. Whether to fuse the q, k, v GEMMs or not. + remove_padding: bool. Whether to remove the padding of the sentences in the encoder. + ''' - self.batch_size = batch_size self.beam_width = beam_width self.head_num = head_num self.size_per_head = size_per_head self.num_layer = num_layer - self.max_seq_len = max_seq_len self.dtype = dtype self.hidden_dim = self.head_num * self.size_per_head + self.kernel_init_range = kernel_init_range + self.bias_init_range = bias_init_range + if self.dtype == tf.float32: + self.check_threshold = 2e-5 + elif self.dtype == tf.float16: + self.check_threshold = 2e-2 + self.fuse_qkv = fuse_qkv + self.remove_padding = remove_padding -class DecodingArgument: +class DecodingArgument(object): def __init__( self, - batch_size, - beam_width, - head_num, - size_per_head, - num_layer, - max_seq_len, vocab_size, start_id, end_id, - encoder_hidden_dim, - dtype): + max_seq_len, + decoder_args): + ''' + The arguments of Decoding. + Decoding is the function which contains the whole translation process. + For example, the embedding lookup, position encoding, decoder, and + beam search or sampling to choose the token. + + Args: + vocab_size: The size of the vocabulary of Decoding. + start_id: The id of the start token in the vocabulary. + end_id: The id of the end token in the vocabulary. + max_seq_len: The maximum sentence length in translation. + decoder_args: The arguments of the decoder layer. + ''' - self.decoder_args = TransformerArgument(batch_size, - beam_width, - head_num, - size_per_head, - num_layer, - max_seq_len, - dtype) self.vocab_size = vocab_size self.start_id = start_id self.end_id = end_id - self.encoder_hidden_dim = encoder_hidden_dim + self.max_seq_len = max_seq_len + self.decoder_args = decoder_args + +class DecodingBeamsearchArgument(DecodingArgument): + def __init__( self, + vocab_size, + start_id, + end_id, + max_seq_len, + decoder_args, + beam_search_diversity_rate=-0.0): + ''' + The arguments of Decoding with beam search. + Most arguments are similar to DecodingArgument except beam_search_diversity_rate. + + Args: + vocab_size: The size of the vocabulary of Decoding. + start_id: The id of the start token in the vocabulary. + end_id: The id of the end token in the vocabulary. + max_seq_len: The maximum sentence length in translation. + decoder_args: The arguments of the decoder layer. + beam_search_diversity_rate: The diversity rate of beam search. When it is 0, + it is equivalent to naive beam search. + ''' + + super(DecodingBeamsearchArgument, self).__init__(vocab_size, + start_id, + end_id, + max_seq_len, + decoder_args) + + self.beam_search_diversity_rate = beam_search_diversity_rate + if abs(self.beam_search_diversity_rate) == 0.0: + self.search_method = BeamSearch() + else: + self.search_method = DiverseSiblingSearch(beam_search_diversity_rate) + +class DecodingSamplingArgument(DecodingArgument): + def __init__( self, + vocab_size, + start_id, + end_id, + max_seq_len, + decoder_args, + top_k=0, + top_p=0.0): + ''' + The arguments of Decoding with sampling.
+ Most arguments are similar to DecodingArgument except top_k and top_p. + + Args: + vocab_size: The size of the vocabulary of Decoding. + start_id: The id of the start token in the vocabulary. + end_id: The id of the end token in the vocabulary. + max_seq_len: The maximum sentence length in translation. + decoder_args: The arguments of the decoder layer. + top_k: An int value. The value of k for top k sampling. + top_p: A float value. The value of p for top p sampling. + + Note that it is invalid for top_k and top_p to both be 0, + and it is also invalid for top_k and top_p to both be non-zero. + If top_k is non-zero, the Decoding function will use top k sampling. + If top_p is non-zero, the Decoding function will use top p sampling. + ''' + + super(DecodingSamplingArgument, self).__init__(vocab_size, + start_id, + end_id, + max_seq_len, + decoder_args) + + self.top_k = top_k + self.top_p = top_p + if self.top_k == 0 and self.top_p == 0.0: + print("[ERROR] top_k and top_p cannot both be 0.") + exit(-1) + elif self.top_k != 0 and self.top_p != 0.0: + print("[ERROR] top_k and top_p cannot both be non-zero.") + exit(-1) def create_initializer(initializer_range=0.02, data_type=tf.float32): return tf.truncated_normal_initializer(stddev=initializer_range, dtype=data_type) @@ -108,4 +211,20 @@ def int_result_cross_check(name, tf_result, op_result, shape): print(" Cross-Check on step-{} {}".format(i, is_true)) if is_true == False: print("TF result: {}".format(tf_reshaped_result[i])) - print("OP result: {}".format(op_reshaped_result[i])) \ No newline at end of file + print("OP result: {}".format(op_reshaped_result[i])) + +class cudaProfiler: + + def __init__(self): + self.profiler = ctypes.CDLL("libcudart.so") + + def start(self): + ret = self.profiler.cudaProfilerStart() + if ret != 0: + raise Exception("cudaProfilerStart() return %d " %ret) + + def stop(self): + ret = self.profiler.cudaProfilerStop() + if ret != 0: + raise Exception("cudaProfilerStop() return %d " %ret) + \ No newline at end of file diff --git a/sample/tensorflow/utils/decoder.py b/sample/tensorflow/utils/decoder.py index ad9fbdd3f..34ebc3d96 100644 --- a/sample/tensorflow/utils/decoder.py +++ b/sample/tensorflow/utils/decoder.py @@ -14,7 +14,7 @@ import os import tensorflow as tf -from common import create_initializer +from utils.common import create_initializer def norm(inputs): """Layer normalizes :obj:`inputs`.""" @@ -69,9 +69,34 @@ def tf_decoder(decoder_args, memory, memory_sequence_length, step, - cache=None, - kernel_initializer_range=0.02, - bias_initializer_range=0): + cache=None): + ''' + Run the decoder transformer layer by TensorFlow. + + Args: + decoder_args: The arguments for the decoder. The details are in the class "TransformerArgument" of common.py. + inputs: A tf.Tensor with shape [batch_size * beam_width, 1, hidden_dimension]. + The input tensor of the decoder. The rank must be 3. + memory: A tf.Tensor with shape [batch_size * beam_width, max(memory_sequence_length), encoder_hidden_dimension]. + The results of the encoder transformer layer. The rank must be 3. + Note that it must be extended by beam_width times. + memory_sequence_length: A tf.Tensor with shape [batch_size * beam_width], type tf.int. + The length of each sentence of the encoder results. + Note that it must be extended by beam_width times. + step: A tf.Tensor with tf.int type. The current step in the translation process. + cache: A dict. The cache space to store the keys and values of the attention layers.
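For reference, the intended mutual exclusion of top_k and top_p in DecodingSamplingArgument above can be exercised as in the hypothetical sketch below. It assumes the script is run from sample/tensorflow so that the utils package is importable (as the sample scripts in this patch do); the model dimensions and token ids are placeholders.

```python
import tensorflow as tf
from utils.common import TransformerArgument, DecodingSamplingArgument

# Placeholder decoder configuration (values are illustrative only).
decoder_args = TransformerArgument(beam_width=1, head_num=8, size_per_head=64,
                                   num_layer=6, dtype=tf.float32)

# Valid: exactly one of top_k / top_p is non-zero.
top_k_args = DecodingSamplingArgument(vocab_size=32001, start_id=2, end_id=3,
                                      max_seq_len=128, decoder_args=decoder_args,
                                      top_k=4)
top_p_args = DecodingSamplingArgument(vocab_size=32001, start_id=2, end_id=3,
                                      max_seq_len=128, decoder_args=decoder_args,
                                      top_p=0.9)
# Invalid: leaving both at 0, or setting both non-zero, makes the constructor
# print an error and exit.
```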
+ + Outputs: + outputs: A tf.Tensor with shape [batch_size * beam_width, 1, hidden_dimension]. + The results of decoder. + ''' + + k_init_range = decoder_args.kernel_init_range + b_init_range = decoder_args.bias_init_range + data_type = decoder_args.dtype + fuse_qkv = decoder_args.fuse_qkv + hidden_dim = decoder_args.hidden_dim + memory_mask = None # has something if memory is not None and not tf.contrib.framework.nest.is_sequence(memory): @@ -81,84 +106,62 @@ def tf_decoder(decoder_args, memory_sequence_length = (memory_sequence_length,) memory_mask = [ build_sequence_mask( - length, num_heads=decoder_args.head_num, maximum_length=tf.shape(m)[1], data_type=decoder_args.dtype) + length, num_heads=decoder_args.head_num, maximum_length=tf.shape(m)[1], data_type=data_type) for m, length in zip(memory, memory_sequence_length)] for l in range(decoder_args.num_layer): layer_name = "layer_{}".format(l) layer_cache = cache[layer_name] if cache is not None else None - + with tf.variable_scope(layer_name): with tf.variable_scope("masked_multi_head"): norm_inputs = norm(inputs) - queries = tf.layers.conv1d( - norm_inputs, - decoder_args.hidden_dim, - 1, - activation=None, - name="query", - use_bias=True, - bias_initializer=create_initializer( - bias_initializer_range, decoder_args.dtype), - kernel_initializer=create_initializer(kernel_initializer_range, decoder_args.dtype)) - - keys = tf.layers.conv1d( - norm_inputs, - decoder_args.hidden_dim, - 1, - activation=None, - name="key", - use_bias=True, - bias_initializer=create_initializer( - bias_initializer_range, decoder_args.dtype), - kernel_initializer=create_initializer(kernel_initializer_range, decoder_args.dtype)) - - values = tf.layers.conv1d( - norm_inputs, - decoder_args.hidden_dim, - 1, - activation=None, - name="value", - use_bias=True, - bias_initializer=create_initializer( - bias_initializer_range, decoder_args.dtype), - kernel_initializer=create_initializer(kernel_initializer_range, decoder_args.dtype)) - - keys = tf.reshape(keys, [decoder_args.batch_size * decoder_args.beam_width, - 1, decoder_args.head_num, decoder_args.size_per_head]) + if fuse_qkv == True: + queries, keys, values = tf.split( tf.layers.conv1d(norm_inputs, decoder_args.hidden_dim * 3, 1, + bias_initializer=create_initializer(b_init_range, data_type), + kernel_initializer=create_initializer(k_init_range, data_type)), 3, axis=2) + else: + ''' + This progress wants to prevent a addictional tf.concat to concat the q, k, v kernels for decoder op + becuase the concat bring large overhead for small batch size. 
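To make the fuse_qkv trade-off described above concrete: a fused projection whose kernel is the concatenation of the Q, K and V kernels yields the same q, k, v as three separate projections, which is why the fused path only needs a one-time split of the weights instead of an extra concat in every decoding step. A minimal NumPy sketch (shapes are illustrative):

```python
import numpy as np

hidden = 512
x = np.random.randn(4, 1, hidden)                     # [batch*beam, 1, hidden]
wq = np.random.randn(hidden, hidden)
wk = np.random.randn(hidden, hidden)
wv = np.random.randn(hidden, hidden)

# Three separate projections.
q, k, v = x @ wq, x @ wk, x @ wv

# One fused projection with the concatenated kernel, then a split of the output.
w_qkv = np.concatenate([wq, wk, wv], axis=-1)         # [hidden, 3*hidden]
q2, k2, v2 = np.split(x @ w_qkv, 3, axis=-1)

assert np.allclose(q, q2) and np.allclose(k, k2) and np.allclose(v, v2)
```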
+ ''' + queries = tf.layers.conv1d(norm_inputs, decoder_args.hidden_dim, 1, + bias_initializer=create_initializer(b_init_range, data_type), + kernel_initializer=create_initializer(k_init_range, data_type)) + keys = tf.layers.conv1d(norm_inputs, decoder_args.hidden_dim, 1, + bias_initializer=create_initializer(b_init_range, data_type), + kernel_initializer=create_initializer(k_init_range, data_type), + name="key") + values = tf.layers.conv1d(norm_inputs, decoder_args.hidden_dim, 1, + bias_initializer=create_initializer(b_init_range, data_type), + kernel_initializer=create_initializer(k_init_range, data_type), + name="value") + + keys = tf.reshape(keys, [tf.shape(keys)[0], 1, decoder_args.head_num, decoder_args.size_per_head]) keys = tf.transpose(keys, [0, 2, 1, 3]) - values = tf.reshape(values, [ - decoder_args.batch_size * decoder_args.beam_width, 1, decoder_args.head_num, decoder_args.size_per_head]) + values = tf.reshape(values, [tf.shape(values)[0], 1, decoder_args.head_num, decoder_args.size_per_head]) values = tf.transpose(values, [0, 2, 1, 3]) - keys = tf.concat([layer_cache["self_keys"], keys], axis=2) - values = tf.concat( - [layer_cache["self_values"], values], axis=2) + values = tf.concat([layer_cache["self_values"], values], axis=2) layer_cache["self_keys"] = keys layer_cache["self_values"] = values - queries = tf.reshape(queries, [ - decoder_args.batch_size * decoder_args.beam_width, 1, decoder_args.head_num, decoder_args.size_per_head]) + queries = tf.reshape(queries, [tf.shape(queries)[0], 1, decoder_args.head_num, decoder_args.size_per_head]) queries = tf.transpose(queries, [0, 2, 1, 3]) queries *= (decoder_args.size_per_head)**-0.5 dot = tf.matmul(queries, keys, transpose_b=True) - attn = tf.cast(tf.nn.softmax( - tf.cast(dot, decoder_args.dtype)), dot.dtype) + attn = tf.cast(tf.nn.softmax(tf.cast(dot, data_type)), dot.dtype) context = tf.matmul(attn, values) context = tf.transpose(context, [0, 2, 1, 3]) - context = tf.reshape(context, [ - decoder_args.batch_size * decoder_args.beam_width, 1, decoder_args.head_num * decoder_args.size_per_head]) + context = tf.reshape(context, [tf.shape(context)[0], 1, decoder_args.head_num * decoder_args.size_per_head]) outputs = tf.layers.conv1d(context, - decoder_args.hidden_dim, - 1, - activation=None, - use_bias=True, - bias_initializer=create_initializer( - bias_initializer_range, decoder_args.dtype), - kernel_initializer=create_initializer(kernel_initializer_range, decoder_args.dtype)) + decoder_args.hidden_dim, + 1, + bias_initializer=create_initializer(b_init_range, data_type), + kernel_initializer=create_initializer(k_init_range, data_type)) # drop_and_add input_dim = inputs.get_shape().as_list()[-1] @@ -176,41 +179,29 @@ def tf_decoder(decoder_args, norm(last_context), decoder_args.hidden_dim, 1, - activation=None, - name="query", - use_bias=True, - bias_initializer=create_initializer( - bias_initializer_range, decoder_args.dtype), - kernel_initializer=create_initializer(kernel_initializer_range, decoder_args.dtype)) + bias_initializer=create_initializer(b_init_range, data_type), + kernel_initializer=create_initializer(k_init_range, data_type)) def _project_and_split(): - keys = tf.layers.conv1d( - mem, - decoder_args.hidden_dim, - 1, - activation=None, - name="key", - use_bias=True, - bias_initializer=create_initializer( - bias_initializer_range, decoder_args.dtype), - kernel_initializer=create_initializer(kernel_initializer_range, decoder_args.dtype)) - - values = tf.layers.conv1d( - mem, - decoder_args.hidden_dim, - 1, - 
activation=None, - name="value", - use_bias=True, - bias_initializer=create_initializer( - bias_initializer_range, decoder_args.dtype), - kernel_initializer=create_initializer(kernel_initializer_range, decoder_args.dtype)) - - keys = tf.reshape(keys, [decoder_args.batch_size * decoder_args.beam_width, tf.shape(keys)[1], - decoder_args.head_num, decoder_args.size_per_head]) + if fuse_qkv == True: + keys, values = tf.split( tf.layers.conv1d(mem, decoder_args.hidden_dim * 2, 1, + bias_initializer=create_initializer(b_init_range, data_type), + kernel_initializer=create_initializer(k_init_range, data_type)), 2, axis=2) + else: + keys = tf.layers.conv1d(mem, decoder_args.hidden_dim, 1, + bias_initializer=create_initializer(b_init_range, data_type), + kernel_initializer=create_initializer(k_init_range, data_type)) + values = tf.layers.conv1d(mem, decoder_args.hidden_dim, 1, + bias_initializer=create_initializer(b_init_range, data_type), + kernel_initializer=create_initializer(k_init_range, data_type), + name="value") + + + keys = tf.reshape(keys, [tf.shape(keys)[0], tf.shape(keys)[1], + decoder_args.head_num, decoder_args.size_per_head]) keys = tf.transpose(keys, [0, 2, 1, 3]) - values = tf.reshape(values, [decoder_args.batch_size * decoder_args.beam_width, tf.shape(values)[1], - decoder_args.head_num, decoder_args.size_per_head]) + values = tf.reshape(values, [tf.shape(values)[0], tf.shape(values)[1], + decoder_args.head_num, decoder_args.size_per_head]) values = tf.transpose(values, [0, 2, 1, 3]) return keys, values @@ -224,30 +215,25 @@ def _project_and_split(): memory_cache["memory_keys"] = keys memory_cache["memory_values"] = values - queries = tf.reshape(queries, [decoder_args.batch_size * decoder_args.beam_width, 1, - decoder_args.head_num, decoder_args.size_per_head]) + queries = tf.reshape(queries, [tf.shape(queries)[0], 1,decoder_args.head_num, decoder_args.size_per_head]) queries = tf.transpose(queries, [0, 2, 1, 3]) queries *= (decoder_args.size_per_head)**-0.5 - + dot = tf.matmul(queries, keys, transpose_b=True) - - dot = tf.cast(tf.cast(dot, decoder_args.dtype) * mask + - ((1.0 - mask) * decoder_args.dtype.min), dot.dtype) + dot = tf.cast(tf.cast(dot, data_type) * mask + + ((1.0 - mask) * data_type.min), dot.dtype) attn = tf.cast(tf.nn.softmax( - tf.cast(dot, decoder_args.dtype)), dot.dtype) + tf.cast(dot, data_type)), dot.dtype) context = tf.matmul(attn, values) context = tf.transpose(context, [0, 2, 1, 3]) - context = tf.reshape(context, [decoder_args.batch_size * decoder_args.beam_width, 1, + context = tf.reshape(context, [tf.shape(context)[0], 1, decoder_args.head_num * decoder_args.size_per_head]) context = tf.layers.conv1d(context, - decoder_args.hidden_dim, - 1, - activation=None, - use_bias=True, - bias_initializer=create_initializer( - bias_initializer_range, decoder_args.dtype), - kernel_initializer=create_initializer(kernel_initializer_range, decoder_args.dtype)) + decoder_args.hidden_dim, + 1, + bias_initializer=create_initializer(b_init_range, data_type), + kernel_initializer=create_initializer(k_init_range, data_type)) # drop_and_add input_dim = last_context.get_shape().as_list()[-1] @@ -260,27 +246,24 @@ def _project_and_split(): normed_last_context = norm(context) input_dim = normed_last_context.get_shape().as_list()[-1] inner = tf.layers.conv1d(normed_last_context, - decoder_args.hidden_dim * 4, - 1, - activation=tf.nn.relu, - use_bias=True, - bias_initializer=create_initializer( - bias_initializer_range, decoder_args.dtype), - 
kernel_initializer=create_initializer(kernel_initializer_range, decoder_args.dtype)) + decoder_args.hidden_dim * 4, + 1, + activation=tf.nn.relu, + use_bias=True, + bias_initializer=create_initializer(b_init_range, data_type), + kernel_initializer=create_initializer(k_init_range, data_type)) transformed = tf.layers.conv1d(inner, - input_dim, - 1, - use_bias=True, - bias_initializer=create_initializer( - bias_initializer_range, decoder_args.dtype), - kernel_initializer=create_initializer(kernel_initializer_range, decoder_args.dtype)) + input_dim, + 1, + use_bias=True, + bias_initializer=create_initializer(b_init_range, data_type), + kernel_initializer=create_initializer(k_init_range, data_type)) # drop_and_add input_dim = context.get_shape().as_list()[-1] output_dim = transformed.get_shape().as_list()[-1] if input_dim == output_dim: transformed += context - inputs = transformed outputs = inputs return outputs @@ -309,53 +292,160 @@ def init_tf_cache(batch_size, return cache -def init_op_cache(decoder_args): - self_cache = tf.zeros([decoder_args.num_layer, 2, 0, decoder_args.batch_size * decoder_args.beam_width, +def init_op_cache(decoder_args, batchxbeam, memory_max_seq_len): + self_cache = tf.zeros([decoder_args.num_layer, 2, 0, batchxbeam, decoder_args.hidden_dim], dtype=decoder_args.dtype, name="op_self_caches") - mem_cache = tf.zeros([decoder_args.num_layer, 2, decoder_args.batch_size * decoder_args.beam_width, - decoder_args.max_seq_len, decoder_args.hidden_dim], dtype=decoder_args.dtype, name="op_memory_caches") + mem_cache = tf.zeros([decoder_args.num_layer, 2, batchxbeam, + memory_max_seq_len, decoder_args.hidden_dim], dtype=decoder_args.dtype, name="op_memory_caches") return self_cache, mem_cache - def op_decoder(inputs, - step, memory_tensor, memory_sequence_length, op_self_cache, op_mem_cache, psuedo_input, - decoder_vars, - decoder_args, - memory_hidden_dim): + var_dict, + decoder_args): + ''' + Run the decoder transformer layer by FasterTransformer. + + Args: + inputs: A tf.Tensor with shape [batch_size * beam_width, 1, hidden_dimension]. + The inputs tensor of encoder. The rank must be 3. + memory_tensor: A tf.tensor with shape [batch_size * beam_width, max(memory_sequence_length), encoder_hidden_dimension]. + The results of encoder transformer layer. The rank must be 3. + Note that it must be extended by beam_width times + memory_sequence_length: A tf.Tensor with shape [batch_size * beam_width], type tf.int. + The lenght of each sentence of results of encoder. + Note that it must be extended by beam_width times + op_self_cache: A tf.Tensor with shape [num_layer, 2, None, batch_size * beam_width, hidden_dimension]. + The cache space to store the keys and values of first attention layer in each step. + op_mem_cache: A tf.Tensor with shape [num_layer, 2, batch_size * beam_width, max(memory_sequence_length) hidden_dimension]. + The cache space to store the keys and values of second attention layer. + Since they are same in each step, it is only need to compute them in first time. + psuedo_input: A tf.Tensor or null list. + Put the decoder results of TensorFlow when running the TensorFlow decoder and FasterTransformer + decoder in one model. This prevents the race condition. + It is useless when only run the FasterTransformer decoder. + decoder_args: The arguments for decoder. The details are in the class "TransformerArgument" of common.py + var_dict: A dict of tf.Tensor or numpy array. The variables for decoder. + They can be either some tensor or some numpy array. 
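As a note on the var_dict described above: it is keyed by TensorFlow variable names. In this patch it is assembled in decoding.py by filtering the global variables whose names contain "transformer/decoder/layer"; a caller that loads checkpoint weights into numpy arrays could build an equivalent dict with the same keys. The sketch below is a simplified illustration, not the exact code.

```python
import tensorflow as tf

def build_decoder_var_dict(scope_keyword="transformer/decoder/layer"):
    """Collect decoder variables into the name-keyed dict op_decoder expects (sketch)."""
    decoder_vars = [v for v in tf.global_variables() if scope_keyword in v.name]
    return {v.name: v for v in decoder_vars}
```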
+ + Outputs: + outputs: A tf.Tensor with shape [batch_size * beam_width, 1, hidden_dimension]. + The results of decoder. + ''' + + ''' + If fuse_qkv == Ture, this means that the computation of q, k, v in decoder are fused in one convolution. + + Therefore, we need to split them and then passing into the decoder op. The split will bring additional overhead, + especially when the batch size is small because the computation time is short. + + However, because most of the pretrained model on network fuse the qkv, so we fuse them as default. + ''' decoder_op_module = tf.load_op_library( os.path.join('./lib/libtf_decoder.so')) - + op_self_cache = tf.concat([op_self_cache, tf.zeros([decoder_args.num_layer, 2, 1, - decoder_args.batch_size * decoder_args.beam_width, + tf.shape(memory_tensor)[0], decoder_args.hidden_dim], dtype=decoder_args.dtype)], axis=2) - + fuse_qkv = decoder_args.fuse_qkv + for i in range(decoder_args.num_layer): + ''' + Handling the names of q, k, v kernel and bias because their names + are different for fusing the qkv or not. + ''' + + layer_prefix_name = "transformer/decoder/layer_%d/" % i + if fuse_qkv == True: + var_dict[layer_prefix_name + 'masked_multi_head/query/kernel:0'], \ + var_dict[layer_prefix_name + 'masked_multi_head/key/kernel:0'], \ + var_dict[layer_prefix_name + 'masked_multi_head/value/kernel:0'] = \ + tf.split(var_dict[layer_prefix_name + 'masked_multi_head/conv1d/kernel:0'], 3, axis=-1) + + var_dict[layer_prefix_name + 'masked_multi_head/query/bias:0'], \ + var_dict[layer_prefix_name + 'masked_multi_head/key/bias:0'], \ + var_dict[layer_prefix_name + 'masked_multi_head/value/bias:0'] = \ + tf.split(var_dict[layer_prefix_name + 'masked_multi_head/conv1d/bias:0'], 3, axis=-1) + + var_dict[layer_prefix_name + 'multi_head/query/kernel:0'] = \ + var_dict[layer_prefix_name + 'multi_head/conv1d/kernel:0'] + var_dict[layer_prefix_name + 'multi_head/query/bias:0'] = \ + var_dict[layer_prefix_name + 'multi_head/conv1d/bias:0'] + var_dict[layer_prefix_name + 'multi_head/key/kernel:0'], \ + var_dict[layer_prefix_name + 'multi_head/value/kernel:0'] = \ + tf.split(var_dict[layer_prefix_name + 'multi_head/conv1d_1/kernel:0'], 2, axis=-1) + var_dict[layer_prefix_name + 'multi_head/key/bias:0'], \ + var_dict[layer_prefix_name + 'multi_head/value/bias:0'] = \ + tf.split(var_dict[layer_prefix_name + 'multi_head/conv1d_1/bias:0'], 2, axis=-1) + else: + var_dict[layer_prefix_name + 'masked_multi_head/query/kernel:0'] = \ + var_dict[layer_prefix_name + 'masked_multi_head/conv1d/kernel:0'] + var_dict[layer_prefix_name + 'masked_multi_head/key/kernel:0'] = \ + var_dict[layer_prefix_name + 'masked_multi_head/key/kernel:0'] + var_dict[layer_prefix_name + 'masked_multi_head/value/kernel:0'] = \ + var_dict[layer_prefix_name + 'masked_multi_head/value/kernel:0'] + + var_dict[layer_prefix_name + 'masked_multi_head/query/bias:0'] = \ + var_dict[layer_prefix_name + 'masked_multi_head/conv1d/bias:0'] + var_dict[layer_prefix_name + 'masked_multi_head/key/bias:0'] = \ + var_dict[layer_prefix_name + 'masked_multi_head/key/bias:0'] + var_dict[layer_prefix_name + 'masked_multi_head/value/bias:0'] = \ + var_dict[layer_prefix_name + 'masked_multi_head/value/bias:0'] + + var_dict[layer_prefix_name + 'multi_head/query/kernel:0'] = \ + var_dict[layer_prefix_name + 'multi_head/conv1d/kernel:0'] + var_dict[layer_prefix_name + 'multi_head/query/bias:0'] = \ + var_dict[layer_prefix_name + 'multi_head/conv1d/bias:0'] + var_dict[layer_prefix_name + 'multi_head/key/kernel:0'] = \ + 
var_dict[layer_prefix_name + 'multi_head/conv1d_1/kernel:0'] + var_dict[layer_prefix_name + 'multi_head/key/bias:0'] = \ + var_dict[layer_prefix_name + 'multi_head/conv1d_1/bias:0'] + var_dict[layer_prefix_name + 'multi_head/value/kernel:0'] = \ + var_dict[layer_prefix_name + 'multi_head/value/kernel:0'] + var_dict[layer_prefix_name + 'multi_head/value/bias:0'] = \ + var_dict[layer_prefix_name + 'multi_head/value/bias:0'] + op_result, _, _ = decoder_op_module.decoder( - inputs, memory_tensor, memory_sequence_length, - decoder_vars[0 + 26 * i], decoder_vars[1 + 26 * i], - decoder_vars[2 + 26 * i], decoder_vars[3 + 26 * i], - decoder_vars[4 + 26 * i], decoder_vars[5 + 26 * i], - decoder_vars[6 + 26 * i], decoder_vars[7 + 26 * i], - decoder_vars[8 + 26 * i], decoder_vars[9 + 26 * i], - decoder_vars[10 + 26 * i], decoder_vars[11 + 26 * i], - decoder_vars[12 + 26 * i], decoder_vars[13 + 26 * i], - decoder_vars[14 + 26 * i], decoder_vars[15 + 26 * i], - decoder_vars[16 + 26 * i], decoder_vars[17 + 26 * i], - decoder_vars[18 + 26 * i], decoder_vars[19 + 26 * i], - decoder_vars[20 + 26 * i], decoder_vars[21 + 26 * i], - decoder_vars[22 + 26 * i], decoder_vars[23 + 26 * i], - decoder_vars[24 + 26 * i], decoder_vars[25 + 26 * i], - op_self_cache[i], op_mem_cache[i], - psuedo_input, # add tf_result as input to prevent the OP and TF from parallel execution and lead to error result + inputs, # 0 + memory_tensor, # 1 + memory_sequence_length, # 2 + var_dict[layer_prefix_name + 'masked_multi_head/LayerNorm/beta:0'], # 3 + var_dict[layer_prefix_name + 'masked_multi_head/LayerNorm/gamma:0'], # 4 + var_dict[layer_prefix_name + 'masked_multi_head/query/kernel:0'], # 5 + var_dict[layer_prefix_name + 'masked_multi_head/query/bias:0'], # 6 + var_dict[layer_prefix_name + 'masked_multi_head/key/kernel:0'], # 7 + var_dict[layer_prefix_name + 'masked_multi_head/key/bias:0'], # 8 + var_dict[layer_prefix_name + 'masked_multi_head/value/kernel:0'], # 9 + var_dict[layer_prefix_name + 'masked_multi_head/value/bias:0'], # 10 + var_dict[layer_prefix_name + 'masked_multi_head/conv1d_1/kernel:0'], # 11 + var_dict[layer_prefix_name + 'masked_multi_head/conv1d_1/bias:0'], # 12 + var_dict[layer_prefix_name + 'multi_head/LayerNorm/beta:0'], # 13 + var_dict[layer_prefix_name + 'multi_head/LayerNorm/gamma:0'], # 14 + var_dict[layer_prefix_name + 'multi_head/query/kernel:0'], # 15 + var_dict[layer_prefix_name + 'multi_head/query/bias:0'], # 16 + var_dict[layer_prefix_name + 'multi_head/key/kernel:0'], # 17 + var_dict[layer_prefix_name + 'multi_head/key/bias:0'], # 18 + var_dict[layer_prefix_name + 'multi_head/value/kernel:0'], # 19 + var_dict[layer_prefix_name + 'multi_head/value/bias:0'], # 20 + var_dict[layer_prefix_name + 'multi_head/conv1d_2/kernel:0'], # 21 + var_dict[layer_prefix_name + 'multi_head/conv1d_2/bias:0'], # 22 + var_dict[layer_prefix_name + 'ffn/LayerNorm/beta:0'], # 23 + var_dict[layer_prefix_name + 'ffn/LayerNorm/gamma:0'], # 24 + var_dict[layer_prefix_name + 'ffn/conv1d/kernel:0'], # 25 + var_dict[layer_prefix_name + 'ffn/conv1d/bias:0'], # 26 + var_dict[layer_prefix_name + 'ffn/conv1d_1/kernel:0'], # 27 + var_dict[layer_prefix_name + 'ffn/conv1d_1/bias:0'], # 28 + op_self_cache[i], # 29 + op_mem_cache[i], # 30 + psuedo_input, # 31, add tf_result as input to prevent the OP and TF from parallel execution and lead to error result head_num=decoder_args.head_num, size_per_head=decoder_args.size_per_head) inputs = op_result - + return op_result, op_self_cache, op_mem_cache diff --git 
a/sample/tensorflow/utils/decoding.py b/sample/tensorflow/utils/decoding.py index 4b2aa6892..1ae4cf1ae 100644 --- a/sample/tensorflow/utils/decoding.py +++ b/sample/tensorflow/utils/decoding.py @@ -15,37 +15,41 @@ import numpy as np import tensorflow as tf import os -from decoder import tf_decoder, op_decoder, init_op_cache, init_tf_cache -from common import create_initializer, _get_shape_invariants +import pickle +import sys +from utils.decoder import tf_decoder +from utils.decoder import op_decoder +from utils.decoder import init_op_cache +from utils.decoder import init_tf_cache +from utils.common import create_initializer +from utils.common import _get_shape_invariants from utils.position import SinusoidalPositionEncoder +from utils.beam_search import search_word +from utils.sampling import Sampling -def initialize_decoding_variables(decoding_args): +def initialize_decoding_variables(decoding_args, batchxbeam): - start_ids = tf.fill([decoding_args.decoder_args.batch_size * decoding_args.decoder_args.beam_width], - decoding_args.start_id) # [batch_size * beam_width] + start_ids = tf.fill([batchxbeam], decoding_args.start_id) # [batch_size * beam_width] step = tf.constant(0, dtype=tf.int32) # save the output ids for each step outputs = tf.TensorArray(tf.int32, size=0, dynamic_size=True) - cache = init_tf_cache(decoding_args.decoder_args.batch_size * decoding_args.decoder_args.beam_width, + cache = init_tf_cache(batchxbeam, decoding_args.decoder_args.head_num, decoding_args.decoder_args.size_per_head, decoding_args.decoder_args.num_layer, dtype=decoding_args.decoder_args.dtype, num_sources=1) - finished = tf.zeros([decoding_args.decoder_args.batch_size * decoding_args.decoder_args.beam_width], - dtype=tf.bool) # [batch_size * beam_width], record that a sentence is finished or not + finished = tf.zeros([batchxbeam], dtype=tf.bool) # [batch_size * beam_width], record that a sentence is finished or not initial_log_probs = tf.cast(tf.tile([0.] 
+ [-float("inf")] * (decoding_args.decoder_args.beam_width - 1), - [decoding_args.decoder_args.batch_size]), dtype=tf.float32) # [batch_size * beam_width] + [batchxbeam / decoding_args.decoder_args.beam_width]), dtype=tf.float32) # [batch_size * beam_width] # [batch_size * beam_width], record the lengths of all sentences - sequence_lengths = tf.zeros( - [decoding_args.decoder_args.batch_size * decoding_args.decoder_args.beam_width], dtype=tf.int32) + sequence_lengths = tf.zeros([batchxbeam], dtype=tf.int32) # record the beam search indices, used for rebuild the whole sentence in the final parent_ids = tf.TensorArray(tf.int32, size=0, dynamic_size=True) extra_vars = tuple([parent_ids, sequence_lengths]) return start_ids, step, outputs, cache, finished, initial_log_probs, sequence_lengths, extra_vars - def generate_encoder_result(batch_size, max_seq_len, memory_hidden_dim, @@ -53,12 +57,14 @@ def generate_encoder_result(batch_size, memory_sequence_length = np.random.randint( 1, max_seq_len + 1, size=batch_size).astype(np.int32) + memory_sequence_length[np.random.randint(0, batch_size)] = max_seq_len outter_embbeding = np.random.randn(memory_hidden_dim) * 0.01 memory = [] + mem_max_seq_len = np.max(memory_sequence_length) for i in range(batch_size): - data = np.random.randn(max_seq_len, memory_hidden_dim) * 0.01 - for j in range(memory_sequence_length[i], max_seq_len): + data = np.random.randn(mem_max_seq_len, memory_hidden_dim) * 0.01 + for j in range(memory_sequence_length[i], mem_max_seq_len): data[j] = outter_embbeding memory.append(data) memory = np.asarray(memory) @@ -66,65 +72,15 @@ def generate_encoder_result(batch_size, return memory, memory_sequence_length - -def beam_search(beam_width, - vocab_size, - step, - log_probs, - cum_log_probs, - finished, - cache, - extra_vars, - op_self_cache=None): - - parent_ids = extra_vars[0] - sequence_lengths = extra_vars[1] - - # [batch_size * beam_width, vocab_size] + [batch_size * beam_width], has to broadcast - total_probs = log_probs + tf.expand_dims(cum_log_probs, 1) - # [batch_size, beam_width * vocab_size], can skip in cuda - total_probs = tf.reshape(total_probs, [-1, beam_width * vocab_size]) - - # both shapes are: [batch_size, beam_width] - _, sample_ids = tf.nn.top_k(total_probs, beam_width) - # [batch_size * beam_width], can skip in cuda - sample_ids = tf.reshape(sample_ids, [-1]) - word_ids = sample_ids % vocab_size # [batch_size * beam_width] - beam_ids = sample_ids // vocab_size # [batch_size * beam_width] - # [batch_size * beam_width] - beam_indices = ( - tf.range(sample_ids.shape[0]) // beam_width) * beam_width + beam_ids - - sequence_lengths = tf.where( - finished, x=sequence_lengths, y=sequence_lengths + 1) - - # [batch_size * beam_width] - batch_pos = tf.range(sample_ids.shape[0]) // beam_width - cum_log_probs = tf.gather_nd(total_probs, tf.stack( - [batch_pos, sample_ids], axis=-1)) # [batch_size * beam_width] - finished = tf.gather(finished, beam_indices) - sequence_lengths = tf.gather(sequence_lengths, beam_indices) - - cache = tf.contrib.framework.nest.map_structure( - lambda s: tf.gather(s, beam_indices), cache) - if op_self_cache != None: - op_self_cache = tf.contrib.framework.nest.map_structure( - lambda s: tf.gather(s, beam_indices, axis=3), op_self_cache) - - parent_ids = parent_ids.write(step, beam_ids) - extra_vars = [parent_ids, sequence_lengths] - - return word_ids, cum_log_probs, finished, cache, tuple(extra_vars), op_self_cache - - def finalize(beam_width, parent_ids, sequence_lengths, outputs, end_id, 
max_seq_len=None): maximum_lengths = tf.reduce_max(tf.reshape( sequence_lengths, [-1, beam_width]), axis=-1) + if max_seq_len != None: array_shape = [max_seq_len, -1, beam_width] else: - array_shape = [maximum_lengths[0], -1, beam_width] - + array_shape = [tf.reduce_max(maximum_lengths), -1, beam_width] + step_ids = tf.reshape(outputs, array_shape) parent_ids = tf.reshape(parent_ids, array_shape) @@ -137,210 +93,201 @@ def finalize(beam_width, parent_ids, sequence_lengths, outputs, end_id, max_seq_ lengths = tf.reduce_sum(lengths, axis=-1) return ids, lengths - -def op_decoding(memory_tensor, - memory_sequence_length, - embedding_table, - decoding_vars, - decoding_args): - - decoding_op_module = tf.load_op_library( - os.path.join('./lib/libtf_decoding.so')) - - val_off = 26 - decoding_vars_in_differ_layers = [] - for i in range(val_off): - par = [] - for j in range(decoding_args.decoder_args.num_layer): - par.append(decoding_vars[i + j * val_off]) - decoding_vars_in_differ_layers.append(par) - - extended_memory = tf.contrib.seq2seq.tile_batch( - memory_tensor, multiplier=decoding_args.decoder_args.beam_width) - extended_memory_sequence_length = tf.contrib.seq2seq.tile_batch( - memory_sequence_length, multiplier=decoding_args.decoder_args.beam_width) - - output_ids, parent_ids, sequence_lengths = decoding_op_module.decoding( - extended_memory, extended_memory_sequence_length, - decoding_vars_in_differ_layers[0], decoding_vars_in_differ_layers[1], - decoding_vars_in_differ_layers[2], decoding_vars_in_differ_layers[3], - decoding_vars_in_differ_layers[4], decoding_vars_in_differ_layers[5], - decoding_vars_in_differ_layers[6], decoding_vars_in_differ_layers[7], - decoding_vars_in_differ_layers[8], decoding_vars_in_differ_layers[9], - decoding_vars_in_differ_layers[10], decoding_vars_in_differ_layers[11], - decoding_vars_in_differ_layers[12], decoding_vars_in_differ_layers[13], - decoding_vars_in_differ_layers[14], decoding_vars_in_differ_layers[15], - decoding_vars_in_differ_layers[16], decoding_vars_in_differ_layers[17], - decoding_vars_in_differ_layers[18], decoding_vars_in_differ_layers[19], - decoding_vars_in_differ_layers[20], decoding_vars_in_differ_layers[21], - decoding_vars_in_differ_layers[22], decoding_vars_in_differ_layers[23], - decoding_vars_in_differ_layers[24], decoding_vars_in_differ_layers[25], - decoding_vars[-4], decoding_vars[-3], embedding_table, - decoding_vars[-2], tf.cast(decoding_vars[-1], dtype=tf.float32), - batch_size=decoding_args.decoder_args.batch_size, - beam_width=decoding_args.decoder_args.beam_width, - max_seq_len=decoding_args.decoder_args.max_seq_len, - head_num=decoding_args.decoder_args.head_num, - size_per_head=decoding_args.decoder_args.size_per_head, - num_layer=decoding_args.decoder_args.num_layer, - memory_hidden_dim=decoding_args.encoder_hidden_dim, - vocab_size=decoding_args.vocab_size, - start_id=decoding_args.start_id, end_id=decoding_args.end_id - ) - parent_ids = parent_ids % decoding_args.decoder_args.beam_width - - finalized_output_ids, finalized_sequence_lengths = finalize(decoding_args.decoder_args.beam_width, - parent_ids, - sequence_lengths, - output_ids, - decoding_args.end_id, - decoding_args.decoder_args.max_seq_len) - - finalized_sequence_lengths = tf.minimum( - finalized_sequence_lengths + 1, tf.shape(finalized_output_ids)[2]) +def decoding_body(word_ids, + step, + memory, + memory_sequence_length, + my_cache, + op_self_cache, + op_mem_cache, + embedding_table, + decoding_args, + decoder_type): - return finalized_output_ids, 
finalized_sequence_lengths, output_ids, parent_ids, sequence_lengths - - -def tf_decoding(memory_tensor, - memory_sequence_length, - embedding_table, - decoding_args, - decoder_type, - kernel_initializer_range, - bias_initializer_range, - atol_threshold=1e-6): - - with tf.variable_scope("transformer/decoding", reuse=tf.AUTO_REUSE): + decoder_args = decoding_args.decoder_args + hidden_dim = decoder_args.hidden_dim + k_init_range = decoder_args.kernel_init_range + data_type = decoder_args.dtype + + batchxbeam = tf.shape(word_ids)[0] + # [batch_size * beam_width, hidden_dim] + inputs = tf.nn.embedding_lookup(embedding_table, word_ids) + # [batch_size * beam_width, 1, hidden_dim] + inputs = tf.expand_dims(inputs, 1) + + inputs *= hidden_dim**0.5 + position_encoder = SinusoidalPositionEncoder() + if position_encoder is not None: + position_encoding_table = position_encoder._create_position_encoding_table(decoding_args.max_seq_len, hidden_dim, data_type) + position_encoding_val = position_encoding_table[step] + position_encoding_val = tf.reshape(position_encoding_val, [1, 1, -1]) + position_encoding_val = tf.tile(position_encoding_val, [batchxbeam, 1, 1]) + inputs = inputs + position_encoding_val + + with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE): + tf_result = tf_decoder(decoder_args=decoder_args, + inputs=inputs, + memory=memory, + memory_sequence_length=memory_sequence_length, + step=step, + cache=my_cache) + + if decoder_type != 0: + decoder_vars = tf.global_variables() + decoder_vars_start_id = 0 + while decoder_vars_start_id < len(decoder_vars): + if decoder_vars[decoder_vars_start_id].name.find("transformer/decoder/layer") != -1: + break + decoder_vars_start_id += 1 + decoder_vars = decoder_vars[decoder_vars_start_id:] + decoder_var_dict = {} + for v in decoder_vars: + decoder_var_dict[v.name] = v + + psuedo_input = [] + if decoder_type == 2: + psuedo_input = tf_result + + op_result, op_self_cache, op_mem_cache = op_decoder(inputs, + memory, + memory_sequence_length, + op_self_cache, + op_mem_cache, + psuedo_input, + decoder_var_dict, + decoder_args) + + result = None + if decoder_type == 0: + result = tf_result + elif decoder_type == 1: + result = op_result + elif decoder_type == 2: + result = tf_result + result_2 = op_result + + flatten_result = tf.reshape(result, [-1]) + flatten_result_2 = tf.reshape(result_2, [-1]) + abs_diff = tf.math.abs(flatten_result - flatten_result_2) + abs_argmax = tf.math.argmax(abs_diff) + result = tf.Print(result, ["[INFO][PYTHON] step:", step, + tf.cond(abs_diff[abs_argmax] / (tf.math.abs(flatten_result[abs_argmax]) + 1e-6) < decoder_args.check_threshold, + lambda: "True", lambda: "False"), + "max abs diff: ", abs_diff[abs_argmax], + " op val: ", flatten_result_2[abs_argmax], + " tf val: ", flatten_result[abs_argmax] ]) + else: + print("[TF][ERROR] decoder type is only 0 or 1 or 2.") + exit(-1) + + result = tf.contrib.layers.layer_norm(result, begin_norm_axis=-1) + + # [batch_size * beam_width, hidden_dim] + result = tf.squeeze(result, axis=1) + logits = tf.layers.dense(result, + decoding_args.vocab_size, + use_bias=True, + bias_initializer=create_initializer(0.0, data_type), + kernel_initializer=create_initializer(k_init_range, data_type), + activation=None) + + return logits, my_cache, op_self_cache, op_mem_cache + +def tf_beamsearch_decoding(memory_tensor, + memory_sequence_length, + embedding_table, + decoding_args, + decoder_type): + ''' + Run the decoding with beam search by TensorFlow. 
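The decoder_type == 2 path in decoding_body above cross-checks the TensorFlow and FasterTransformer results by taking the element with the largest absolute difference and testing its relative error against check_threshold (2e-5 for FP32, 2e-2 for FP16). The same check in plain NumPy, for reference:

```python
import numpy as np

def cross_check(tf_val, op_val, threshold):
    """True if the worst element-wise mismatch stays within the relative threshold."""
    tf_val = np.asarray(tf_val, dtype=np.float32).reshape(-1)
    op_val = np.asarray(op_val, dtype=np.float32).reshape(-1)
    abs_diff = np.abs(tf_val - op_val)
    worst = np.argmax(abs_diff)
    rel_err = abs_diff[worst] / (np.abs(tf_val[worst]) + 1e-6)
    return rel_err < threshold

if __name__ == "__main__":
    a = np.ones(1024, dtype=np.float32)
    print(cross_check(a, a * (1 + 1e-6), threshold=2e-5))  # True: 1e-6 relative error passes FP32
```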
+ + Args: + memory_tensor: A tf.tensor with shape [batch_size * beam_width, max(memory_sequence_length), encoder_hidden_dimension]. + The results of encoder transformer layer. The rank must be 3. + Note that it must be extended by beam_width times. + memory_sequence_length: A tf.Tensor with shape [batch_size * beam_width], type tf.int. + The lenght of each sentence of results of encoder. + Note that it must be extended by beam_width times. + embedding_table: A tf.Tensor with shape [vocab_size, hidden_dimension]. + The embedding table of embedding lookup for each step. + decoder_args: The arguments for decoding. The details are in the class "DecodingBeamsearchArgument" of common.py + decoder_type: A int value. Choose to using TensorFlow decoder, FasterTransformer decoder, or both. + If it is 0, then using the TensorFlow decoder only. + If it is 1, then using the FasterTransformer decoder only. + If it is 2, then using both decoder and compare their result. + Outputs: + finalized_tf_output_ids: A tf.Tensor with shape [batch_size, beam_width, max(tf_sequence_lengths)], with tf.int type. + Finalized tf_output_ids by beam search algorithm and tf_parent_ids. + finalized_tf_sequence_lengths: A tf.Tensor with shape [batch_size * beam_width], with int type. + Finalized tf_sequence_lengths by beam search algorithm and tf_parent_ids. + tf_output_ids: A tf.Tensor with shape [batch_size, beam_width, max(tf_sequence_lengths)], with tf.int type. + The results of decoding. It contains the id of token of vocabulary. + tf_parent_ids: A tf.Tensor with shape [batch_size, beam_width, max(tf_sequence_lengths)], with tf.int type. + The beam index of output ids for each step. + tf_sequence_lengths: A tf.Tensor with shape [batch_size * beam_width], with int type. + ''' + + decoder_args = decoding_args.decoder_args + beam_width = decoder_args.beam_width + search_method = decoding_args.search_method + with tf.variable_scope("transformer", reuse=tf.AUTO_REUSE): # copy memory and memory_sequence_length by beam_width times # if memory is [a, b, c], beam_width = 3, then the result is: [a a a b b b c c c ] - extended_memory = tf.contrib.seq2seq.tile_batch( - memory_tensor, multiplier=decoding_args.decoder_args.beam_width) + extended_memory = tf.contrib.seq2seq.tile_batch(memory_tensor, multiplier=beam_width) extended_memory_sequence_length = tf.contrib.seq2seq.tile_batch( - memory_sequence_length, multiplier=decoding_args.decoder_args.beam_width) + memory_sequence_length, multiplier=beam_width) def _cond(word_ids, cum_log_probs, finished, step, outputs, my_cache, extra_vars, op_self_cache, op_mem_cache): return tf.reduce_any(tf.logical_not(finished)) def _body(word_ids, cum_log_probs, finished, step, outputs, my_cache, extra_vars, op_self_cache, op_mem_cache): - # [batch_size * beam_width, hidden_dim] - inputs = tf.nn.embedding_lookup(embedding_table, word_ids) - # [batch_size * beam_width, 1, hidden_dim] - inputs = tf.expand_dims(inputs, 1) - - inputs *= decoding_args.decoder_args.hidden_dim**0.5 - position_encoder = SinusoidalPositionEncoder() - if position_encoder is not None: - inputs = position_encoder( - inputs, position=step + 1 if step is not None else None) - - with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE): - tf_result = tf_decoder(decoder_args=decoding_args.decoder_args, - inputs=inputs, - memory=extended_memory, - memory_sequence_length=extended_memory_sequence_length, - step=step, - cache=my_cache, - kernel_initializer_range=kernel_initializer_range, - bias_initializer_range=bias_initializer_range) - 
- - if decoder_type != 0: - decoder_vars = tf.global_variables() - decoder_vars_start_id = 0 - while decoder_vars_start_id < len(decoder_vars): - if decoder_vars[decoder_vars_start_id].name.find("transformer/decoding/decoder") != -1: - break - decoder_vars_start_id += 1 - decoder_vars = decoder_vars[decoder_vars_start_id:] - - psuedo_input = [] - if decoder_type == 2: - psuedo_input = tf_result - - op_result, op_self_cache, op_mem_cache = op_decoder(inputs, - step, - extended_memory, - extended_memory_sequence_length, - op_self_cache, - op_mem_cache, - psuedo_input, - decoder_vars, - decoding_args.decoder_args, - decoding_args.encoder_hidden_dim) - - result = None - if decoder_type == 0: - result = tf_result - elif decoder_type == 1: - result = op_result - elif decoder_type == 2: - result = tf_result - result_2 = op_result - - flatten_result = tf.reshape(result, [-1]) - flatten_result_2 = tf.reshape(result_2, [-1]) - abs_diff = tf.math.abs(flatten_result - flatten_result_2) - argmax = tf.math.argmax(abs_diff) - result = tf.Print(result, ["[INFO][PYTHON] step:", step, "max diff: ", abs_diff[argmax], - " op val: ", flatten_result_2[argmax], - " tf val: ", flatten_result[argmax], - tf.cond(abs_diff[argmax] < atol_threshold, lambda: "True", lambda: "False")]) - else: - print("[TF][ERROR] decoder type is only 0 or 1 or 2.") - exit(-1) - - result = tf.contrib.layers.layer_norm(result, begin_norm_axis=-1) - # [batch_size * beam_width, hidden_dim] - result = tf.squeeze(result, axis=1) - logits = tf.layers.dense(result, - decoding_args.vocab_size, - use_bias=True, - bias_initializer=create_initializer( - bias_initializer_range, decoding_args.decoder_args.dtype), - kernel_initializer=create_initializer( - kernel_initializer_range, decoding_args.decoder_args.dtype), - activation=None) - - end_ids = tf.fill([decoding_args.decoder_args.batch_size * decoding_args.decoder_args.beam_width], - decoding_args.end_id) # [batch_size * beam_width] + logits, my_cache, op_self_cache, op_mem_cache = decoding_body(word_ids, + step, + extended_memory, + extended_memory_sequence_length, + my_cache, + op_self_cache, + op_mem_cache, + embedding_table, + decoding_args, + decoder_type) + + end_ids = tf.fill([tf.shape(logits)[0]], decoding_args.end_id) # [batch_size * beam_width] eos_max_prob = tf.one_hot(end_ids, decoding_args.vocab_size, - on_value=decoding_args.decoder_args.dtype.max, - off_value=decoding_args.decoder_args.dtype.min) # [batch_size * beam_width, vocab_size] + on_value=decoder_args.dtype.max, + off_value=decoder_args.dtype.min) # [batch_size * beam_width, vocab_size] + # [batch_size * beam_width, vocab_size] logits = tf.where(finished, x=eos_max_prob, y=logits) logits = tf.cast(logits, tf.float32) - # [batch_size * beam_width, vocab_size] - log_probs = tf.nn.log_softmax(logits) - + output_id, next_cum_log_probs, finished, my_cache, \ - extra_vars, op_self_cache = beam_search(decoding_args.decoder_args.beam_width, + extra_vars, op_self_cache = search_word(beam_width, decoding_args.vocab_size, step, - log_probs, + logits, cum_log_probs, finished, my_cache, extra_vars, - op_self_cache) - + op_self_cache, + search_method=search_method) + cum_log_probs = tf.where(finished, x=cum_log_probs, y=next_cum_log_probs) + outputs = outputs.write(step, output_id) - cum_log_probs = tf.where( - finished, x=cum_log_probs, y=next_cum_log_probs) - finished = tf.logical_or(finished, tf.equal( - output_id, decoding_args.end_id)) + finished = tf.logical_or(finished, tf.equal(output_id, decoding_args.end_id)) return output_id, 
cum_log_probs, finished, step + 1, outputs, my_cache, extra_vars, op_self_cache, op_mem_cache # initialization + batchxbeam = tf.shape(extended_memory)[0] start_ids, step, outputs, tf_decoder_cache, finished, initial_log_probs, \ - tf_sequence_lengths, extra_vars = initialize_decoding_variables( - decoding_args) + tf_sequence_lengths, extra_vars = initialize_decoding_variables(decoding_args, batchxbeam) word_ids = tf.identity(start_ids, name="word_ids") cum_log_probs = tf.identity(initial_log_probs, name="cum_log_probs") # if use_op == False, these two caches are useless - op_self_cache, op_mem_cache = init_op_cache(decoding_args.decoder_args) + op_self_cache, op_mem_cache = init_op_cache(decoder_args, batchxbeam, tf.reduce_max(memory_sequence_length)) _, _, _, _, outputs, _, extra_vars, _, _ = tf.while_loop( _cond, @@ -357,35 +304,532 @@ def _body(word_ids, cum_log_probs, finished, step, outputs, my_cache, extra_vars op_mem_cache ), back_prop=False, - maximum_iterations=decoding_args.decoder_args.max_seq_len, + maximum_iterations=decoding_args.max_seq_len, shape_invariants=( start_ids.shape, initial_log_probs.shape, finished.shape, step.shape, tf.TensorShape(None), - tf.contrib.framework.nest.map_structure( - _get_shape_invariants, tf_decoder_cache), - tf.contrib.framework.nest.map_structure( - _get_shape_invariants, extra_vars), - tf.contrib.framework.nest.map_structure( - _get_shape_invariants, op_self_cache), + tf.contrib.framework.nest.map_structure(_get_shape_invariants, tf_decoder_cache), + tf.contrib.framework.nest.map_structure(_get_shape_invariants, extra_vars), + tf.contrib.framework.nest.map_structure(_get_shape_invariants, op_self_cache), tf.contrib.framework.nest.map_structure(_get_shape_invariants, op_mem_cache)) ) tf_parent_ids = extra_vars[0].stack() tf_sequence_lengths = extra_vars[1] tf_output_ids = outputs.stack() - - finalized_tf_output_ids, finalized_tf_sequence_lengths = finalize(decoding_args.decoder_args.beam_width, + + finalized_tf_output_ids, finalized_tf_sequence_lengths = finalize(beam_width, tf_parent_ids, tf_sequence_lengths, tf_output_ids, decoding_args.end_id) - finalized_tf_output_ids = tf.cast( - finalized_tf_output_ids, start_ids.dtype) + finalized_tf_output_ids = tf.cast(finalized_tf_output_ids, start_ids.dtype) finalized_tf_sequence_lengths = tf.minimum( finalized_tf_sequence_lengths + 1, tf.shape(finalized_tf_output_ids)[2]) return finalized_tf_output_ids, finalized_tf_sequence_lengths, tf_output_ids, tf_parent_ids, tf_sequence_lengths + +def tf_sampling_decoding(memory_tensor, + memory_sequence_length, + embedding_table, + decoding_args, + decoder_type): + ''' + Run the decoding with sampling by TensorFlow. + + Args: + memory_tensor: A tf.tensor with shape [batch_size, max(memory_sequence_length), encoder_hidden_dimension]. + The results of encoder transformer layer. The rank must be 3. + memory_sequence_length: A tf.Tensor with shape [batch_size], type tf.int. + The lenght of each sentence of results of encoder. + embedding_table: A tf.Tensor with shape [vocab_size, hidden_dimension]. + The embedding table of embedding lookup for each step. + decoder_args: The arguments for decoding. The details are in the class "DecodingSamplingArgument" of common.py + decoder_type: A int value. Choose to using TensorFlow decoder, FasterTransformer decoder, or both. + If it is 0, then using the TensorFlow decoder only. + If it is 1, then using the FasterTransformer decoder only. + If it is 2, then using both decoder and compare their result. 
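As background for the finalize step used above: outputs only stores the token each beam slot emitted at every step, and parent_ids stores which beam each slot came from, so complete hypotheses are reconstructed by walking parent_ids backwards from the final step. The NumPy sketch below shows that backtracking for a single sentence; it is a simplification (the real finalize also handles end_id and per-beam sequence lengths).

```python
import numpy as np

def backtrack_beams(step_ids, parent_ids):
    """step_ids, parent_ids: [max_step, beam_width] for one sentence.
    Returns [max_step, beam_width] token ids with each column forming one hypothesis."""
    max_step, beam_width = step_ids.shape
    result = np.zeros_like(step_ids)
    beam = np.arange(beam_width)               # start from the final beam ordering
    for t in range(max_step - 1, -1, -1):
        result[t] = step_ids[t, beam]
        beam = parent_ids[t, beam]             # follow each slot back to its parent beam
    return result

if __name__ == "__main__":
    step_ids = np.array([[11, 12], [21, 22], [31, 32]])
    parent_ids = np.array([[0, 0], [1, 0], [0, 1]])
    print(backtrack_beams(step_ids, parent_ids))
    # column 0 reads 12 -> 21 -> 31, column 1 reads 11 -> 22 -> 32
```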
+ Outputs: + tf_output_ids: A tf.Tensor with shape [batch_size, max(sequence_lengths)], with int type. + The results of decoding. It contains the id of token of vocabulary. + sequence_lengths: A tf.Tensor with shape [batch_size], with int type. + ''' + + decoder_args = decoding_args.decoder_args + + with tf.variable_scope("transformer", reuse=tf.AUTO_REUSE): + batch_size = tf.shape(memory_tensor)[0] + + def _cond(word_ids, finished, step, outputs, my_cache, sequence_lengths, op_self_cache, op_mem_cache): + return tf.reduce_any(tf.logical_not(finished)) + + def _body(word_ids, finished, step, outputs, my_cache, sequence_lengths, op_self_cache, op_mem_cache): + logits, my_cache, op_self_cache, op_mem_cache = decoding_body(word_ids, + step, + memory_tensor, + memory_sequence_length, + my_cache, + op_self_cache, + op_mem_cache, + embedding_table, + decoding_args, + decoder_type) + + end_ids = tf.fill([batch_size],decoding_args.end_id) # [batch_size * beam_width] + eos_max_prob = tf.one_hot(end_ids, decoding_args.vocab_size, + on_value=decoder_args.dtype.max, + off_value=decoder_args.dtype.min) # [batch_size * beam_width, vocab_size] + # [batch_size, vocab_size] + logits = tf.where(finished, x=eos_max_prob, y=logits) + logits = tf.cast(logits, tf.float32) + + # sampling + if decoding_args.top_k != 0: + sampling_method = Sampling("top_k") + output_id = sampling_method.sample(logits, threshold=decoding_args.top_k) + elif decoding_args.top_p != 0.0: + sampling_method = Sampling("top_p") + output_id = sampling_method.sample(logits, threshold=decoding_args.top_p) + sequence_lengths = tf.where(finished, x=sequence_lengths, y=sequence_lengths + 1) + + outputs = outputs.write(step, output_id) + finished = tf.logical_or(finished, tf.equal(output_id, decoding_args.end_id)) + + # return output_id, cum_log_probs, finished, step + 1, outputs, my_cache, extra_vars, op_self_cache, op_mem_cache + return output_id, finished, step + 1, outputs, my_cache, sequence_lengths, op_self_cache, op_mem_cache + + # initialization + start_ids, step, outputs, tf_decoder_cache, finished, _, \ + _, extra_vars = initialize_decoding_variables(decoding_args, batch_size) + + sequence_lengths = extra_vars[1] + word_ids = tf.identity(start_ids, name="word_ids") + # if use_op == False, these two caches are useless + op_self_cache, op_mem_cache = init_op_cache(decoder_args, batch_size, tf.reduce_max(memory_sequence_length)) + + _, _, _, outputs, _, sequence_lengths, _, _ = tf.while_loop( + _cond, + _body, + loop_vars=( + word_ids, + finished, + step, + outputs, + tf_decoder_cache, + sequence_lengths, + op_self_cache, + op_mem_cache + ), + back_prop=False, + maximum_iterations=decoding_args.max_seq_len, + shape_invariants=( + start_ids.shape, + finished.shape, + step.shape, + tf.TensorShape(None), + tf.contrib.framework.nest.map_structure( + _get_shape_invariants, tf_decoder_cache), + tf.contrib.framework.nest.map_structure( + _get_shape_invariants, sequence_lengths), + tf.contrib.framework.nest.map_structure( + _get_shape_invariants, op_self_cache), + tf.contrib.framework.nest.map_structure(_get_shape_invariants, op_mem_cache)) + ) + + tf_output_ids = outputs.stack() + tf_sequence_lengths = sequence_lengths + tf_output_ids = tf.reshape(tf_output_ids, [-1, batch_size]) + tf_output_ids = tf.transpose(tf_output_ids, [1, 0]) + tf_output_ids = tf.cast(tf_output_ids, start_ids.dtype) + + return tf_output_ids, sequence_lengths + +def preprocess_decoder_var(decoding_vars, + num_layer, + using_model_var, + checkpoint_filename, + data_type, + 
fuse_qkv=True):
+    '''
+    Args:
+        decoding_vars: A list of tf.Tensor. The variables of decoding.
+        num_layer: An int value. The number of transformer layers of the decoder in decoding.
+        using_model_var: A bool value. Whether to use the model variables of TensorFlow directly.
+                         If True, the model variables of the TensorFlow decoding model are put into the decoding op directly.
+                         The data type is TensorFlow tensor in this case.
+
+                         If False, the values of the variables are restored from checkpoint_filename and put
+                         into the decoding op.
+                         The data type is numpy array in this case.
+        checkpoint_filename: A string. The name of the checkpoint file storing the values of the model. The checkpoint should be stored as
+                             a pickle file, and the name of the checkpoint should be xxx.pkl.
+                             The model is saved as a dict.
+                             The keys of the dict are the names of the variables.
+                             The values of the dict are the values of the variables.
+                             For example, for decoding_vars[0]
+                             the key is 'transformer/decoder/layer_0/masked_multi_head/LayerNorm/beta:0'; the value is sess.run(decoding_vars[0]).
+        data_type: tf.float32 or tf.float16.
+                   Only used when using_model_var is False. Converts the numpy data to the data type of the model.
+
+    Outputs:
+        vars_in_diff_layers_dict: A dict storing the variables by their names.
+
+            For decoder variables, the key is like 'transformer/decoder/layer/masked_multi_head/LayerNorm/beta:0',
+            which is similar to the variable name, except that 'layer' is used instead of 'layer_x'. The value is built from
+            'transformer/decoder/layer_%d/masked_multi_head/LayerNorm/beta:0' % i for i in range(num_layer).
+
+            For other variables, the key is the name of the variable, and the value is the corresponding weight.
+
+            Note that we return the concatenated weights. The concat operation brings extra overhead, and it should be optimized in
+            a real application. The recommended method is to pre-process the weights into numpy format, because TensorFlow repeats these operations
+            for every inference if TensorFlow is used to pre-process the weights.
+    '''
+
+    var_dict = {}
+    if using_model_var == False:
+        # restore the model from the checkpoint file
+        if(checkpoint_filename == None):
+            print("[ERROR] checkpoint_filename cannot be None when using_model_var is False.")
+            exit(-1)
+
+        with open(checkpoint_filename, 'rb') as f:
+            ckpt = pickle.load(f)
+
+        for var in decoding_vars:
+            var_dict[var.name] = ckpt[var.name]
+    else:
+        for var in decoding_vars:
+            var_dict[var.name] = var
+
+    vars_in_diff_layers_dict = {}
+    vars_in_diff_layers_dict["transformer/decoder/LayerNorm/beta:0"] = var_dict["transformer/decoder/LayerNorm/beta:0"]
+    vars_in_diff_layers_dict["transformer/decoder/LayerNorm/gamma:0"] = var_dict["transformer/decoder/LayerNorm/gamma:0"]
+    vars_in_diff_layers_dict["transformer/decoder/dense/kernel:0"] = var_dict["transformer/decoder/dense/kernel:0"]
+    vars_in_diff_layers_dict["transformer/decoder/dense/bias:0"] = tf.cast(var_dict["transformer/decoder/dense/bias:0"], dtype=tf.float32)
+
+    for i in range(num_layer):
+        '''
+        Handle the names of the q, k, v kernels and biases, because their names
+        differ depending on whether the qkv is fused or not.
+ ''' + + layer_prefix_name = "transformer/decoder/layer_%d/" % i + if fuse_qkv == True: + var_dict[layer_prefix_name + 'masked_multi_head/query/kernel:0'], \ + var_dict[layer_prefix_name + 'masked_multi_head/key/kernel:0'], \ + var_dict[layer_prefix_name + 'masked_multi_head/value/kernel:0'] = \ + tf.split(var_dict[layer_prefix_name + 'masked_multi_head/conv1d/kernel:0'], 3, axis=-1) + + var_dict[layer_prefix_name + 'masked_multi_head/query/bias:0'], \ + var_dict[layer_prefix_name + 'masked_multi_head/key/bias:0'], \ + var_dict[layer_prefix_name + 'masked_multi_head/value/bias:0'] = \ + tf.split(var_dict[layer_prefix_name + 'masked_multi_head/conv1d/bias:0'], 3, axis=-1) + + var_dict[layer_prefix_name + 'multi_head/query/kernel:0'] = \ + var_dict[layer_prefix_name + 'multi_head/conv1d/kernel:0'] + var_dict[layer_prefix_name + 'multi_head/query/bias:0'] = \ + var_dict[layer_prefix_name + 'multi_head/conv1d/bias:0'] + var_dict[layer_prefix_name + 'multi_head/key/kernel:0'], \ + var_dict[layer_prefix_name + 'multi_head/value/kernel:0'] = \ + tf.split(var_dict[layer_prefix_name + 'multi_head/conv1d_1/kernel:0'], 2, axis=-1) + var_dict[layer_prefix_name + 'multi_head/key/bias:0'], \ + var_dict[layer_prefix_name + 'multi_head/value/bias:0'] = \ + tf.split(var_dict[layer_prefix_name + 'multi_head/conv1d_1/bias:0'], 2, axis=-1) + else: + var_dict[layer_prefix_name + 'masked_multi_head/query/kernel:0'] = \ + var_dict[layer_prefix_name + 'masked_multi_head/conv1d/kernel:0'] + var_dict[layer_prefix_name + 'masked_multi_head/key/kernel:0'] = \ + var_dict[layer_prefix_name + 'masked_multi_head/key/kernel:0'] + var_dict[layer_prefix_name + 'masked_multi_head/value/kernel:0'] = \ + var_dict[layer_prefix_name + 'masked_multi_head/value/kernel:0'] + + var_dict[layer_prefix_name + 'masked_multi_head/query/bias:0'] = \ + var_dict[layer_prefix_name + 'masked_multi_head/conv1d/bias:0'] + var_dict[layer_prefix_name + 'masked_multi_head/key/bias:0'] = \ + var_dict[layer_prefix_name + 'masked_multi_head/key/bias:0'] + var_dict[layer_prefix_name + 'masked_multi_head/value/bias:0'] = \ + var_dict[layer_prefix_name + 'masked_multi_head/value/bias:0'] + + var_dict[layer_prefix_name + 'multi_head/query/kernel:0'] = \ + var_dict[layer_prefix_name + 'multi_head/conv1d/kernel:0'] + var_dict[layer_prefix_name + 'multi_head/query/bias:0'] = \ + var_dict[layer_prefix_name + 'multi_head/conv1d/bias:0'] + var_dict[layer_prefix_name + 'multi_head/key/kernel:0'] = \ + var_dict[layer_prefix_name + 'multi_head/conv1d_1/kernel:0'] + var_dict[layer_prefix_name + 'multi_head/key/bias:0'] = \ + var_dict[layer_prefix_name + 'multi_head/conv1d_1/bias:0'] + var_dict[layer_prefix_name + 'multi_head/value/kernel:0'] = \ + var_dict[layer_prefix_name + 'multi_head/value/kernel:0'] + var_dict[layer_prefix_name + 'multi_head/value/bias:0'] = \ + var_dict[layer_prefix_name + 'multi_head/value/bias:0'] + + layer_prefix_name = 'transformer/decoder/layer' + vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/LayerNorm/beta:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/LayerNorm/beta:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/LayerNorm/gamma:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/LayerNorm/gamma:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/query/kernel:0'] = \ + tf.cast(tf.concat([ 
var_dict[layer_prefix_name + '_%d/masked_multi_head/query/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/query/bias:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/query/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/key/kernel:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/key/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/key/bias:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/key/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/value/kernel:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/value/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/value/bias:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/value/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/conv1d_1/kernel:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/conv1d_1/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/conv1d_1/bias:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/conv1d_1/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + + vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/LayerNorm/beta:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/LayerNorm/beta:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/LayerNorm/gamma:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/LayerNorm/gamma:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/query/kernel:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/query/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/query/bias:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/query/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/key/kernel:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/key/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/key/bias:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/key/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/value/kernel:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/value/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/value/bias:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/value/bias:0' % i] for i in range(num_layer) ], axis=0), 
dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/conv1d_2/kernel:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/conv1d_2/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/conv1d_2/bias:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/conv1d_2/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type) + + vars_in_diff_layers_dict[layer_prefix_name + '/ffn/LayerNorm/beta:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/ffn/LayerNorm/beta:0' % i] for i in range(num_layer)], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/ffn/LayerNorm/gamma:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/ffn/LayerNorm/gamma:0' % i] for i in range(num_layer)], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/ffn/conv1d/kernel:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/ffn/conv1d/kernel:0' % i] for i in range(num_layer)], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/ffn/conv1d/bias:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/ffn/conv1d/bias:0' % i] for i in range(num_layer)], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/ffn/conv1d_1/kernel:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/ffn/conv1d_1/kernel:0' % i] for i in range(num_layer)], axis=0), dtype=data_type) + vars_in_diff_layers_dict[layer_prefix_name + '/ffn/conv1d_1/bias:0'] = \ + tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/ffn/conv1d_1/bias:0' % i] for i in range(num_layer)], axis=0), dtype=data_type) + + return vars_in_diff_layers_dict + +def op_beamsearch_decoding(memory_tensor, + memory_sequence_length, + embedding_table, + decoding_vars, + decoding_args, + using_model_var=True, + checkpoint_filename=None): + ''' + Run the decoding with beam search by TensorFlow. + + Args: + memory_tensor: A tf.tensor with shape [batch_size * beam_width, max(memory_sequence_length), encoder_hidden_dimension]. + The results of encoder transformer layer. The rank must be 3. + Note that it must be extended by beam_width times. + memory_sequence_length: A tf.Tensor with shape [batch_size * beam_width], type tf.int. + The lenght of each sentence of results of encoder. + Note that it must be extended by beam_width times. + embedding_table: A tf.Tensor with shape [vocab_size, hidden_dimension]. + The embedding table of embedding lookup for each step. + decoder_vars: A list of tf.Tensor. The variables for decoding. A list of model variables of TensorFlow model. + decoder_args: The arguments for decoding. The details are in the class "DecodingBeamsearchArgument" of common.py + using_model_var: A bool value. Using the model variables of TensorFlow or not. + The details are described in 'preprocess_decoder_var' function in the following. + checkpoint_filename: A string. The checkpoint file name of storing the values of model. + The details are described in 'preprocess_decoder_var' function in the following. + Outputs: + finalized_output_ids: A tf.Tensor with shape [batch_size, beam_width, max(sequence_lengths)], with tf.int type. + Finalized output_ids by beam search algorithm and parent_ids. + finalized_sequence_lengths: A tf.Tensor with shape [batch_size * beam_width], with int type. + Finalized sequence_lengths by beam search algorithm and parent_ids. 
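finalize recovers per-beam hypotheses from the flat per-step output_ids by walking the parent_ids backwards. A rough numpy sketch of that backtracking is given below; the time-major [max_step, batch_size * beam_width] layout and the helper name are assumptions made only for illustration, not the actual implementation.

import numpy as np

def backtrack_beams(output_ids, parent_ids, beam_width):
    # output_ids, parent_ids: [max_step, batch_size * beam_width]
    max_step, batchxbeam = output_ids.shape
    batch_size = batchxbeam // beam_width
    hypotheses = np.zeros((batch_size, beam_width, max_step), dtype=output_ids.dtype)
    for b in range(batch_size):
        for k in range(beam_width):
            beam = k
            for t in range(max_step - 1, -1, -1):
                col = b * beam_width + beam
                hypotheses[b, k, t] = output_ids[t, col]
                beam = parent_ids[t, col]  # follow the parent pointer one step back
    return hypotheses  # [batch_size, beam_width, max_step]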
+ output_ids: A tf.Tensor with shape [batch_size, beam_width, max(sequence_lengths)], with tf.int type. + The results of decoding. It contains the id of token of vocabulary. + parent_ids: A tf.Tensor with shape [batch_size, beam_width, max(sequence_lengths)], with tf.int type. + The beam index of output ids for each step. + sequence_lengths: A tf.Tensor with shape [batch_size * beam_width], with int type. + ''' + + decoder_args = decoding_args.decoder_args + decoding_op_module = tf.load_op_library(os.path.join('./lib/libtf_decoding_beamsearch.so')) + + vars_dict_in_differ_layers = preprocess_decoder_var(decoding_vars, + decoder_args.num_layer, + using_model_var, + checkpoint_filename, + decoder_args.dtype, + decoder_args.fuse_qkv) + + extended_memory = tf.contrib.seq2seq.tile_batch( + memory_tensor, multiplier=decoder_args.beam_width) + extended_memory_sequence_length = tf.contrib.seq2seq.tile_batch( + memory_sequence_length, multiplier=decoder_args.beam_width) + + position_encoder = SinusoidalPositionEncoder() + position_encoding_table = position_encoder._create_position_encoding_table( + decoding_args.max_seq_len, decoder_args.head_num * decoder_args.size_per_head, decoder_args.dtype) + # shape of position_encoding_table: [max_seq_len, hidden_dim] + + output_ids, parent_ids, sequence_lengths = decoding_op_module.decoding( + extended_memory, # 0 + extended_memory_sequence_length, # 1 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/LayerNorm/beta:0'], # 2 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/LayerNorm/gamma:0'], # 3 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/query/kernel:0'], # 4 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/query/bias:0'], # 5 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/key/kernel:0'], # 6 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/key/bias:0'], # 7 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/value/kernel:0'], # 8 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/value/bias:0'], # 9 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/conv1d_1/kernel:0'], # 10 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/conv1d_1/bias:0'], # 11 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/LayerNorm/beta:0'], # 12 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/LayerNorm/gamma:0'], # 13 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/query/kernel:0'], # 14 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/query/bias:0'], # 15 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/key/kernel:0'], # 16 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/key/bias:0'], # 17 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/value/kernel:0'], # 18 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/value/bias:0'], # 19 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/conv1d_2/kernel:0'], # 20 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/conv1d_2/bias:0'], # 21 + vars_dict_in_differ_layers['transformer/decoder/layer/ffn/LayerNorm/beta:0'], # 22 + vars_dict_in_differ_layers['transformer/decoder/layer/ffn/LayerNorm/gamma:0'], # 23 + vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d/kernel:0'], # 24 + 
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d/bias:0'], # 25 + vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d_1/kernel:0'], # 26 + vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d_1/bias:0'], # 27 + vars_dict_in_differ_layers['transformer/decoder/LayerNorm/beta:0'], # 28 + vars_dict_in_differ_layers['transformer/decoder/LayerNorm/gamma:0'], # 29 + embedding_table, # 30 + vars_dict_in_differ_layers['transformer/decoder/dense/kernel:0'], # 31 + vars_dict_in_differ_layers['transformer/decoder/dense/bias:0'], # 32 + position_encoding_table, # 33 + beam_width=decoder_args.beam_width, + max_seq_len=decoding_args.max_seq_len, + head_num=decoder_args.head_num, + size_per_head=decoder_args.size_per_head, + num_layer=decoder_args.num_layer, + start_id=decoding_args.start_id, + end_id=decoding_args.end_id, + beam_search_diversity_rate=decoding_args.beam_search_diversity_rate + ) + parent_ids = parent_ids % decoder_args.beam_width + + finalized_output_ids, finalized_sequence_lengths = finalize(decoder_args.beam_width, + parent_ids, + sequence_lengths, + output_ids, + decoding_args.end_id, + decoding_args.max_seq_len) + + finalized_sequence_lengths = tf.minimum( + finalized_sequence_lengths + 1, tf.shape(finalized_output_ids)[2]) + + return finalized_output_ids, finalized_sequence_lengths, output_ids, parent_ids, sequence_lengths + +def op_sampling_decoding(memory_tensor, + memory_sequence_length, + embedding_table, + decoding_vars, + decoding_args, + using_model_var=True, + checkpoint_filename=None): + ''' + Run the decoding with sampling by FasterTransformer. + + Args: + memory_tensor: A tf.tensor with shape [batch_size, max(memory_sequence_length), encoder_hidden_dimension]. + The results of encoder transformer layer. The rank must be 3. + memory_sequence_length: A tf.Tensor with shape [batch_size], type tf.int. + The lenght of each sentence of results of encoder. + embedding_table: A tf.Tensor with shape [vocab_size, hidden_dimension]. + The embedding table of embedding lookup for each step. + decoder_vars: A list of tf.Tensor. The variables for decoding. A list of model variables of TensorFlow model. + decoder_args: The arguments for decoding. The details are in the class "DecodingSamplingArgument" of common.py + using_model_var: A bool value. Using the model variables of TensorFlow or not. + The details are described in 'preprocess_decoder_var' function in the following. + checkpoint_filename: A string. The checkpoint file name of storing the values of model. + The details are described in 'preprocess_decoder_var' function in the following. + Outputs: + output_ids: A tf.Tensor with shape [batch_size, max(sequence_lengths)], with int type. + The results of decoding. It contains the id of token of vocabulary. + sequence_lengths: A tf.Tensor with shape [batch_size], with int type. 
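The sampling op below receives decoding_args.top_k and decoding_args.top_p as candidate_num and probability_threshold, which correspond to top-k and top-p (nucleus) filtering before sampling, mirroring the Sampling class added in sampling.py. The numpy sketch below shows the filtering they imply; the function name and the single-sequence shape are illustrative assumptions.

import numpy as np

def top_k_top_p_filter(logits, k=0, p=0.0):
    # logits: [vocab_size] for one sequence; k == 0 and p == 0.0 disable the filters.
    filtered = logits.astype(np.float64)
    if k > 0:
        kth_best = np.sort(filtered)[-k]
        filtered[filtered < kth_best] = -np.inf            # keep only the k best tokens
    if p > 0.0:
        order = np.argsort(filtered)[::-1]
        probs = np.exp(filtered[order] - filtered[order[0]])
        probs /= probs.sum()
        cutoff = np.searchsorted(np.cumsum(probs), p) + 1  # smallest set covering mass p
        filtered[order[cutoff:]] = -np.inf
    return filtered  # sample from softmax(filtered)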
+ ''' + + decoder_args = decoding_args.decoder_args + decoding_op_module = tf.load_op_library(os.path.join('./lib/libtf_decoding_sampling.so')) + + vars_dict_in_differ_layers = preprocess_decoder_var(decoding_vars, + decoding_args.decoder_args.num_layer, + using_model_var, + checkpoint_filename, + decoder_args.dtype, + decoder_args.fuse_qkv) + + position_encoder = SinusoidalPositionEncoder() + position_encoding_table = position_encoder._create_position_encoding_table( + decoding_args.max_seq_len, decoder_args.head_num * decoder_args.size_per_head, decoder_args.dtype) + # shape of position_encoding_table: [max_seq_len, hidden_dim] + + output_ids, sequence_lengths = decoding_op_module.decoding_sampling( + memory_tensor, # 0 + memory_sequence_length, # 1 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/LayerNorm/beta:0'], # 2 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/LayerNorm/gamma:0'], # 3 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/query/kernel:0'], # 4 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/query/bias:0'], # 5 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/key/kernel:0'], # 6 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/key/bias:0'], # 7 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/value/kernel:0'], # 8 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/value/bias:0'], # 9 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/conv1d_1/kernel:0'], # 10 + vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/conv1d_1/bias:0'], # 11 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/LayerNorm/beta:0'], # 12 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/LayerNorm/gamma:0'], # 13 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/query/kernel:0'], # 14 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/query/bias:0'], # 15 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/key/kernel:0'], # 16 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/key/bias:0'], # 17 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/value/kernel:0'], # 18 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/value/bias:0'], # 19 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/conv1d_2/kernel:0'], # 20 + vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/conv1d_2/bias:0'], # 21 + vars_dict_in_differ_layers['transformer/decoder/layer/ffn/LayerNorm/beta:0'], # 22 + vars_dict_in_differ_layers['transformer/decoder/layer/ffn/LayerNorm/gamma:0'], # 23 + vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d/kernel:0'], # 24 + vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d/bias:0'], # 25 + vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d_1/kernel:0'], # 26 + vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d_1/bias:0'], # 27 + vars_dict_in_differ_layers['transformer/decoder/LayerNorm/beta:0'], # 28 + vars_dict_in_differ_layers['transformer/decoder/LayerNorm/gamma:0'], # 29 + embedding_table, # 30 + vars_dict_in_differ_layers['transformer/decoder/dense/kernel:0'], # 31 + vars_dict_in_differ_layers['transformer/decoder/dense/bias:0'], # 32 + position_encoding_table, # 33 + max_seq_len=decoding_args.max_seq_len, + 
candidate_num=decoding_args.top_k, + probability_threshold=decoding_args.top_p, + head_num=decoder_args.head_num, + size_per_head=decoder_args.size_per_head, + num_layer=decoder_args.num_layer, + start_id=decoding_args.start_id, + end_id=decoding_args.end_id + ) + batch_size = tf.shape(memory_tensor)[0] + output_ids = tf.reshape(output_ids, [-1, batch_size]) + output_ids = tf.transpose(output_ids, [1, 0]) + + return output_ids, sequence_lengths \ No newline at end of file diff --git a/sample/tensorflow/utils/dump_model.py b/sample/tensorflow/utils/dump_model.py index 16d3ff691..e2a046744 100644 --- a/sample/tensorflow/utils/dump_model.py +++ b/sample/tensorflow/utils/dump_model.py @@ -22,135 +22,30 @@ print("[ERROR] dump_pruned_model.py needs a ckpt file as input. \n e.g. python dump_pruned_model.py model.ckpt") sys.exit(0) -ckpt_name = sys.argv[1] +# Get the values of all variables in the checkpoint file, and then save the values of all variables in a pickle file by dict +# The key of the dict is the name of variables +# The value of the dict is the values of variables +# For example, all_variables[0]=, +# then the key is 'transformer/decoder/layer_0/masked_multi_head/LayerNorm/beta:0'; the value is sess.run(all_variables[0]) + +# If you need to dump the model which has same structure but different variable name, you can convert the name of your model into opennmt's name one by one. +# For example, the name of beta variable of first layer normalization in first layer of decoder is 'transformer/decoder/layer_0/masked_multi_head/LayerNorm/beta:0', +# and in your model, you use other name like 'body/decoder/layer_0/self_attention/LayerNorm/beta:0' +# then the key is: 'transformer/decoder/layer_0/masked_multi_head/LayerNorm/beta:0' (the model name of opennmt) +# and the value is sess.run() (your variable value) +ckpt_name = sys.argv[1] + with tf.Session() as sess: saver = tf.train.import_meta_graph(ckpt_name + ".meta") saver.restore(sess, (ckpt_name)) - - def dumpModel_new(): - all_variables = tf.trainable_variables() - ckpt = {} - - for i, var in enumerate(all_variables): - print("[INFO] %d/%d" %(i, len(all_variables)), end='\r') - sys.stdout.flush() - if var in tf.trainable_variables(): - val = sess.run(var) - name = None - if var.name.find("Adam") != -1: - continue - elif var.name.find("encoder") != -1: - # transformer/encoder/layer_x/multi_head/conv1d/kernel:0 -> transformer/encoder/layer_x/attention/self/query, key, value/kernel:0 - # transformer/encoder/layer_x/multi_head/conv1d_1/kernel:0 -> transformer/encoder/layer_x/attention/output/kernel:0 - # transformer/encoder/layer_x/multi_head/LayerNorm/gamma:0 -> transformer/encoder/layer_x/attention/output/LayerNorm/gamma:0 - if var.name.find("multi_head/conv1d/") != -1: - dim = val.shape[-1] / 3 - Q, K, V = np.split(val, [dim, dim * 2], axis=-1) - ckpt[var.name.replace("multi_head/conv1d/", "attention/self/query/")] = Q - ckpt[var.name.replace("multi_head/conv1d/", "attention/self/key/")] = K - ckpt[var.name.replace("multi_head/conv1d/", "attention/self/value/")] = V - - elif var.name.find("multi_head/conv1d_1/") != -1: - name = var.name.replace("multi_head/conv1d_1/", "attention/output/") - ckpt[name] = val - - elif var.name.find("multi_head/LayerNorm/") != -1: - name = var.name.replace("multi_head/LayerNorm/", "attention/output/LayerNorm/") - ckpt[name] = val - - # transformer/encoder/layer_x/ffn/conv1d/kernel:0 -> transformer/encoder/layer_x/intermediate/dense/kernel:0 - # transformer/encoder/layer_x/ffn/LayerNorm/beta:0 -> 
transformer/encoder/layer_x/output/LayerNorm/beta:0 - # transformer/encoder/layer_x/ffn/conv1d_1/kernel:0 -> transformer/encoder/layer_x/output/dense/kernel:0 - elif var.name.find("ffn/conv1d/") != -1: - name = var.name.replace("ffn/conv1d/", "intermediate/dense/") - ckpt[name] = val - - elif var.name.find("ffn/LayerNorm/") != -1: - name = var.name.replace("ffn/LayerNorm/", "output/LayerNorm/") - ckpt[name] = val - - elif var.name.find("ffn/conv1d_1/") != -1: - name = var.name.replace("ffn/conv1d_1/", "output/dense/") - ckpt[name] = val - - elif var.name.find("transformer/encoder/w_embs") != -1: - name = var.name - ckpt[name] = val - - elif var.name.find("decoder") != -1: - pre_name = var.name.replace("decoder", "decoding/decoder") - - # transformer/decoder/layer_x/masked_multi_head/conv1d/kernel:0 -> transformer/decoder/layer_x/masked_multi_head/query, key, value/kernel:0 - # transformer/decoder/layer_x/masked_multi_head/conv1d_1/kernel:0 -> transformer/decoder/layer_x/masked_multi_head/conv1d/kernel:0 - # transformer/decoder/layer_x/masked_multi_head/LayerNorm/gamma:0 -> transformer/decoder/layer_x/masked_multi_head/LayerNorm/gamma:0 - if var.name.find("masked_multi_head/conv1d/") != -1: - dim = val.shape[-1] / 3 - Q, K, V = np.split(val, [dim, dim * 2], axis=-1) - ckpt[pre_name.replace("masked_multi_head/conv1d/", "masked_multi_head/query/")] = Q - ckpt[pre_name.replace("masked_multi_head/conv1d/", "masked_multi_head/key/")] = K - ckpt[pre_name.replace("masked_multi_head/conv1d/", "masked_multi_head/value/")] = V - elif var.name.find("masked_multi_head/conv1d_1/") != -1: - name = pre_name.replace("masked_multi_head/conv1d_1/", "masked_multi_head/conv1d/") - ckpt[name] = val - elif var.name.find("masked_multi_head/LayerNorm/") != -1: - name = pre_name - ckpt[name] = val - - # transformer/decoder/layer_x/multi_head/conv1d/kernel:0 -> transformer/decoder/layer_x/multi_head/query/kernel:0 - # transformer/decoder/layer_x/multi_head/conv1d_1/kernel:0 -> transformer/decoder/layer_x/multi_head/key, value/kernel:0 - # transformer/decoder/layer_x/multi_head/conv1d_2/kernel:0 -> transformer/decoder/layer_x/multi_head/conv1d/kernel - # transformer/decoder/layer_x/multi_head/LayerNorm/gamma:0 -> transformer/decoder/layer_x/multi_head/LayerNorm/gamma:0 - elif var.name.find("multi_head/conv1d/") != -1: - name = pre_name.replace("multi_head/conv1d/", "multi_head/query/") - ckpt[name] = val - elif var.name.find("multi_head/conv1d_1/") != -1: - dim = val.shape[-1] / 2 - K, V = np.split(val, [dim], axis=-1) - ckpt[pre_name.replace("multi_head/conv1d_1/", "multi_head/key/")] = K - ckpt[pre_name.replace("multi_head/conv1d_1/", "multi_head/value/")] = V - elif var.name.find("multi_head/conv1d_2/") != -1: - name = pre_name.replace("multi_head/conv1d_2/", "multi_head/conv1d/") - ckpt[name] = val - elif var.name.find("multi_head/LayerNorm/") != -1: - name = pre_name - ckpt[name] = val - - # transformer/decoder/layer_x/ffn/conv1d/kernel:0 -> transformer/decoder/layer_x/intermediate/dense/kernel:0 - # transformer/decoder/layer_x/ffn/LayerNorm/beta:0 -> transformer/decoder/layer_x/output/LayerNorm/beta:0 - # transformer/decoder/layer_x/ffn/conv1d_1/kernel:0 -> transformer/decoder/layer_x/output/dense/kernel:0 - elif var.name.find("ffn/conv1d/") != -1: - # name = var.name.replace("ffn/conv1d/", "intermediate/dense/") - name = pre_name - ckpt[name] = val - elif var.name.find("ffn/LayerNorm/") != -1: - # name = var.name.replace("ffn/LayerNorm/", "output/LayerNorm/") - name = pre_name - ckpt[name] = val - elif 
var.name.find("ffn/conv1d_1/") != -1: - # name = var.name.replace("ffn/conv1d_1/", "output/dense/") - name = pre_name - ckpt[name] = val - - elif var.name.find("transformer/decoder/w_embs") != -1: - name = var.name.replace("decoder", "decoding") - ckpt[name] = val - - elif var.name.find("transformer/decoder/dense/") != -1: - name = var.name.replace("decoder", "decoding") - ckpt[name] = val - - elif var.name.find("transformer/decoder/LayerNorm/") != -1: - name = var.name.replace("decoder", "decoding") - ckpt[name] = val - - if name != None: - print("[INFO] {} -> {} ".format(var.name, name)) - - for key in ckpt: - print(key) - with open('model.pkl', 'wb') as f: - pickle.dump(ckpt, f, 0) - - dumpModel_new() - + all_variables = tf.trainable_variables() + ckpt = {} + + all_val = sess.run(all_variables) + for var, val in zip(all_variables, all_val): + if var.name.find("Adam") == -1: + ckpt[var.name] = val + + with open('model.pkl', 'wb') as f: + pickle.dump(ckpt, f, pickle.HIGHEST_PROTOCOL) diff --git a/sample/tensorflow/utils/encoder.py b/sample/tensorflow/utils/encoder.py index 2448d0253..64c345ba3 100644 --- a/sample/tensorflow/utils/encoder.py +++ b/sample/tensorflow/utils/encoder.py @@ -17,7 +17,8 @@ import math import six import os -from common import create_initializer +from utils.common import create_initializer +from utils.position import SinusoidalPositionEncoder def gelu(x): cdf = 0.5 * (1.0 + tf.tanh( @@ -45,7 +46,7 @@ def attention_layer(from_tensor, from_seq_length=None, to_seq_length=None, tf_datatype=tf.float32): - + def transpose_for_scores(input_tensor, batch_size, num_attention_heads, seq_length, width): output_tensor = tf.reshape( @@ -120,8 +121,9 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, if attention_mask is not None: # `attention_mask` = [B, 1, F, T] - attention_mask = tf.expand_dims(attention_mask, axis=[1]) - + if tf.rank(attention_mask) == 3: + attention_mask = tf.expand_dims(attention_mask, axis=[1]) + adder = (1.0 - tf.cast(attention_mask, tf_datatype)) * -10000.0 attention_scores += adder @@ -155,7 +157,24 @@ def tf_encoder(input_tensor, attention_mask=None, intermediate_act_fn=gelu, initializer_range=0.02): + ''' + Run the bert transformer layer by TensorFlow. + Args: + inputs: A tf.Tensor with shape [batch_size, seq_len, hidden_dimension]. + The inputs tensor of encoder. The rank must be 3. + encoder_args: The arguments for encoder. The details are in the class + "TransformerArgument" of common.py + attention_mask: A tf.Tensor. The attention mask for self attention. + intermediate_act_fn: A callable function. + The activation function in the FFN. It is gelu in BERT. + initializer_range: A float value. + The range of initializer for all weights. + + Outputs: + outputs: A tf.Tensor with shape [batch_size, seq_len, hidden_dimension]. + The results of encoder. 
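The encoder turns the 0/1 attention mask into an additive bias so that padded positions receive a large negative score before the softmax, as the adder line above does. A minimal, self-contained sketch of that step, assuming a mask broadcastable to the score shape [batch, heads, from_len, to_len]:

import tensorflow as tf

def add_attention_bias(attention_scores, attention_mask, dtype=tf.float32):
    # attention_mask: 1.0 for valid positions, 0.0 for padding.
    adder = (1.0 - tf.cast(attention_mask, dtype)) * -10000.0
    return attention_scores + adder  # padded positions vanish after the softmax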
+ ''' intermediate_size = encoder_args.hidden_dim * 4 if encoder_args.hidden_dim % encoder_args.head_num != 0: @@ -183,9 +202,9 @@ def tf_encoder(input_tensor, size_per_head=encoder_args.size_per_head, initializer_range=initializer_range, do_return_2d_tensor=True, - batch_size=encoder_args.batch_size, - from_seq_length=encoder_args.max_seq_len, - to_seq_length=encoder_args.max_seq_len, + batch_size=batch_size, + from_seq_length=seq_length, + to_seq_length=seq_length, tf_datatype=encoder_args.dtype) attention_output = attention_head @@ -223,8 +242,118 @@ def tf_encoder(input_tensor, layer_output = layer_norm(layer_output + attention_output) prev_output = layer_output + prev_output = tf.reshape(prev_output, shape=tf.shape(input_tensor)) return prev_output +def build_sequence_mask(sequence_length, + num_heads=None, + maximum_length=None, + dtype=tf.float32): + """Builds the dot product mask. + Args: + sequence_length: The sequence length. + num_heads: The number of heads. + maximum_length: Optional size of the returned time dimension. Otherwise + it is the maximum of :obj:`sequence_length`. + dtype: The type of the mask tensor. + Returns: + A broadcastable ``tf.Tensor`` of type :obj:`dtype` and shape + ``[batch_size, 1, max_length, max_length]``. + """ + mask = tf.sequence_mask(sequence_length, maxlen=maximum_length, dtype=dtype) # [batch_size, maximum_length] + mask = tf.reshape(mask, [-1, 1, 1, maximum_length]) + m_2 = tf.transpose(mask, [0, 1, 3, 2]) + mask = mask * m_2 + + return mask + +def tf_encoder_opennmt(input_tensor, + encoder_args, + initializer_range=0.02, + sequence_length=None): + ''' + Run the bert transformer layer by TensorFlow. + + Args: + input_tensor: A tf.Tensor with shape [batch_size, seq_len, hidden_dimension]. + The inputs tensor of encoder. The rank must be 3. + encoder_args: The arguments for encoder. The details are in the class + "TransformerArgument" of common.py + initializer_range: A float value. + The range of initializer for all weights. + sequence_length: A tf.Tensor with shape [batch_size], with tf.int type. + The sequence length of each sentence in input_tensor. + + Outputs: + output: A tf.Tensor with shape [batch_size, max(sequence_length), hidden_dimension]. + The results of encoder. 
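tf_encoder_opennmt implements self-attention with the usual "split heads" reshape/transpose pattern followed by a scaled dot product, as the layer body below shows. A compact sketch of that pattern is given here; the tensor and function names are placeholders.

import tensorflow as tf

def split_heads(x, batch_size, seq_len, head_num, size_per_head):
    # [batch, seq_len, hidden] -> [batch, head_num, seq_len, size_per_head]
    x = tf.reshape(x, [batch_size, seq_len, head_num, size_per_head])
    return tf.transpose(x, [0, 2, 1, 3])

def scaled_dot_product_attention(q, k, v, size_per_head):
    scores = tf.matmul(q * size_per_head ** -0.5, k, transpose_b=True)
    weights = tf.nn.softmax(scores)
    return tf.matmul(weights, v)  # [batch, head_num, seq_len, size_per_head]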
+ ''' + + data_type = encoder_args.dtype + input_shape = get_shape_list(input_tensor, expected_rank=3) + batch_size = input_shape[0] + seq_length = input_shape[1] + + input_tensor *= encoder_args.hidden_dim**0.5 + position_encoder = SinusoidalPositionEncoder() + input_tensor = position_encoder(input_tensor, position=tf.range(seq_length)) + + mask = build_sequence_mask( + sequence_length, + encoder_args.head_num, + maximum_length=tf.shape(input_tensor)[1], + dtype=data_type) + + intermediate_size = encoder_args.hidden_dim * 4 + if encoder_args.hidden_dim % encoder_args.head_num != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (encoder_args.hidden_dim, encoder_args.head_num)) + + layer_input = input_tensor + for layer_idx in range(encoder_args.num_layer): + with tf.variable_scope("layer_%d" % layer_idx, reuse=tf.AUTO_REUSE): + with tf.variable_scope("multi_head"): + normed_input = tf.cast(layer_norm(tf.cast(layer_input, tf.float32)), data_type) + + queries, keys, values = tf.split(tf.layers.conv1d(normed_input, encoder_args.hidden_dim * 3, 1), 3, axis=2) + + # split head + queries = tf.reshape(queries, [batch_size, seq_length, encoder_args.head_num, encoder_args.size_per_head]) + queries = tf.transpose(queries, [0, 2, 1, 3]) + + keys = tf.reshape(keys, [batch_size, seq_length, encoder_args.head_num, encoder_args.size_per_head]) + keys = tf.transpose(keys, [0, 2, 1, 3]) + + values = tf.reshape(values, [batch_size, seq_length, encoder_args.head_num, encoder_args.size_per_head]) + values = tf.transpose(values, [0, 2, 1, 3]) + + queries *= (encoder_args.size_per_head)**-0.5 + + dot = tf.matmul(queries, keys, transpose_b=True) + + if mask is not None: + dot = tf.cast(tf.cast(dot, data_type) * mask + ((1.0 - mask) * data_type.min), dot.dtype) + + attn = tf.cast(tf.nn.softmax(tf.cast(dot, data_type)), dot.dtype) + + context_1 = tf.matmul(attn, values) + context_1 = tf.transpose(context_1, [0, 2, 1, 3]) + context_1 = tf.reshape(context_1, [batch_size, seq_length, encoder_args.hidden_dim]) + attention_output = tf.layers.conv1d(context_1, encoder_args.hidden_dim, 1) + context_2 = attention_output + layer_input + + with tf.variable_scope("ffn"): + normed_context_2 = tf.cast(layer_norm(tf.cast(context_2, tf.float32)), data_type) + intermediate_output = tf.layers.conv1d(normed_context_2, intermediate_size, 1, activation=tf.nn.relu) + layer_output_1 = tf.layers.conv1d(intermediate_output, encoder_args.hidden_dim, 1) + layer_output_2 = layer_output_1 + context_2 + layer_input = layer_output_2 + + layer_input = tf.cast(layer_input, tf.float32) + output = layer_norm(layer_input, name="LayerNorm") + return output + def get_shape_list(tensor, expected_rank=None, name=None): if name is None: @@ -297,28 +426,60 @@ def assert_rank(tensor, expected_rank, name=None): def op_encoder(inputs, encoder_args, - encoder_vars, - attention_mask): - transformer_op_module = tf.load_op_library( - os.path.join('./lib/libtf_fastertransformer.so')) + attention_mask, + encoder_vars_dict, + sequence_length): + ''' + Run the bert transformer layer by FasterTransformer. + + Args: + inputs: A tf.Tensor with shape [batch_size, seq_len, hidden_dimension]. + The inputs tensor of encoder. The rank must be 3. + encoder_args: The arguments for encoder. The details are in the class "TransformerArgument" of common.py + attention_mask: A tf.Tensor. The attention mask for self attention. + encoder_vars_dict: A dict of tf.Tensor or numpy array. + The variables for encoder. 
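op_encoder looks its weights up by short names such as 'layer_0/attention/self/query/kernel:0'. One possible way to build such a dict from a trained graph is to strip the encoder scope prefix from each variable name, as sketched below; the 'bert/encoder/' prefix and the helper name are assumptions, not taken from this patch.

import tensorflow as tf

def build_encoder_vars_dict(scope_prefix="bert/encoder/"):
    # Strip the scope prefix so the keys match what op_encoder expects.
    encoder_vars_dict = {}
    for var in tf.global_variables():
        if scope_prefix in var.name:
            encoder_vars_dict[var.name.split(scope_prefix)[-1]] = var
    return encoder_vars_dict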
They can be either some tensor or some numpy array. + The key is the name of the tensor, like 'layer_0/attention/self/query/kernel:0'. + Teh value is the corresponding tensor or numpy array + sequence_length: A tf.Tensor or numpy array with shape [batch_size]. + The sequence length of the sentences + Outputs: + outputs: A tensor with shape [batch_size, seq_len, hidden_dimension]. + The results of encoder. + ''' + remove_padding = encoder_args.remove_padding + transformer_op_module = tf.load_op_library(os.path.join('./lib/libtf_fastertransformer.so')) + if remove_padding == True: + inputs, sequence_id_offset = transformer_op_module.build_mask_remove_padding(inputs, sequence_length) + else: + sequence_id_offset = [] for layer_idx in range(encoder_args.num_layer): - val_off = layer_idx * 16 outputs = transformer_op_module.bert_transformer( inputs, inputs, - encoder_vars[val_off + 0], encoder_vars[val_off + - 2], encoder_vars[val_off + 4], - encoder_vars[val_off + 1], encoder_vars[val_off + - 3], encoder_vars[val_off + 5], + encoder_vars_dict['layer_%d/attention/self/query/kernel:0' % layer_idx], + encoder_vars_dict['layer_%d/attention/self/query/bias:0' % layer_idx], + encoder_vars_dict['layer_%d/attention/self/key/kernel:0' % layer_idx], + encoder_vars_dict['layer_%d/attention/self/key/bias:0' % layer_idx], + encoder_vars_dict['layer_%d/attention/self/value/kernel:0' % layer_idx], + encoder_vars_dict['layer_%d/attention/self/value/bias:0' % layer_idx], attention_mask, - encoder_vars[val_off + 6], encoder_vars[val_off + - 7], encoder_vars[val_off + 8], - encoder_vars[val_off + 9], encoder_vars[val_off + - 10], encoder_vars[val_off + 11], - encoder_vars[val_off + 12], encoder_vars[val_off + - 13], encoder_vars[val_off + 14], - encoder_vars[val_off + 15], - from_seq_len=encoder_args.max_seq_len, to_seq_len=encoder_args.max_seq_len, - head_num=encoder_args.head_num, size_per_head=encoder_args.size_per_head) + encoder_vars_dict['layer_%d/attention/output/dense/kernel:0' % layer_idx], + encoder_vars_dict['layer_%d/attention/output/dense/bias:0' % layer_idx], + encoder_vars_dict['layer_%d/attention/output/LayerNorm/beta:0' % layer_idx], + encoder_vars_dict['layer_%d/attention/output/LayerNorm/gamma:0' % layer_idx], + encoder_vars_dict['layer_%d/intermediate/dense/kernel:0' % layer_idx], + encoder_vars_dict['layer_%d/intermediate/dense/bias:0' % layer_idx], + encoder_vars_dict['layer_%d/output/dense/kernel:0' % layer_idx], + encoder_vars_dict['layer_%d/output/dense/bias:0' % layer_idx], + encoder_vars_dict['layer_%d/output/LayerNorm/beta:0' % layer_idx], + encoder_vars_dict['layer_%d/output/LayerNorm/gamma:0' % layer_idx], + sequence_id_offset, + head_num=encoder_args.head_num, size_per_head=encoder_args.size_per_head, + remove_padding=remove_padding) inputs = outputs + + if remove_padding == True: + outputs = transformer_op_module.rebuild_padding(outputs, sequence_id_offset, attention_mask) + return outputs diff --git a/sample/tensorflow/utils/position.py b/sample/tensorflow/utils/position.py index 89d9394ce..3a4004407 100644 --- a/sample/tensorflow/utils/position.py +++ b/sample/tensorflow/utils/position.py @@ -15,7 +15,7 @@ import math import abc import tensorflow as tf -from reducer import SumReducer +from utils.reducer import SumReducer class PositionEncoder(tf.keras.layers.Layer): """Base class for position encoders.""" @@ -43,7 +43,7 @@ def call(self, inputs, position=None): # pylint: disable=arguments-differ batch_size = tf.shape(inputs)[0] timesteps = tf.shape(inputs)[1] input_dim = 
inputs.get_shape().as_list()[-1] # return int - positions = tf.range(timesteps) + 1 if position is None else [position] + positions = tf.range(timesteps) + 1 if position is None else position position_encoding = self._encode([positions], input_dim, dtype=inputs.dtype) position_encoding = tf.tile(position_encoding, [batch_size, 1, 1]) return self.reducer([inputs, position_encoding]) @@ -58,6 +58,12 @@ def _encode(self, positions, depth, dtype): A ``tf.Tensor`` of shape :math:`[B, ..., D]`. """ raise NotImplementedError() + + def _create_position_encoding_table(self, max_seq_len, input_dim, dtype): + positions = tf.range(max_seq_len) + 1 + self.position_encoding_table = self._encode([positions], input_dim, dtype=dtype) + self.position_encoding_table = tf.squeeze(self.position_encoding_table) + return self.position_encoding_table class SinusoidalPositionEncoder(PositionEncoder): diff --git a/sample/tensorflow/utils/sampling.py b/sample/tensorflow/utils/sampling.py new file mode 100644 index 000000000..889ba900c --- /dev/null +++ b/sample/tensorflow/utils/sampling.py @@ -0,0 +1,74 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + +class Sampling(): + + def __init__(self, sample_method): + if sample_method == "top_k": + self.sample_method = self.top_k_logits + elif sample_method == "top_p": + self.sample_method = self.top_p_logits + else: + print("[ERROR] the sample method should be one of top_k and top_p") + exit(-1) + + pass + + def sample(self, logits, threshold, num_samples=1): + ''' + inputs: + logits: [batch_size, vocab_size], the values of log logits + threshold: int when using top_k, and a probability (0~1) when using top_p + + outputs: + samples: [batch_size] + ''' + + logits = self.sample_method(logits, threshold) + samples = tf.multinomial(logits, num_samples=num_samples, output_dtype=tf.int32) + samples = tf.reshape(samples, [-1]) + return samples + + def top_k_logits(self, logits, k): + if k == 0: + return logits + else: + values, _ = tf.nn.top_k(logits, k=k) # [batch size, k] + min_values = values[:, -1, tf.newaxis] #[batch size, 1] + return tf.where( + logits < min_values, + tf.ones_like(logits, dtype=logits.dtype) * logits.dtype.min, + logits + ) + + def top_p_logits(self, logits, p): + sorted_logits = tf.sort(logits, direction='DESCENDING') + sorted_probs = tf.nn.softmax(sorted_logits) + probs_sums = tf.cumsum(sorted_probs, axis=1, exclusive=True) + logits_masked = tf.where( + probs_sums < p, + sorted_logits, + tf.ones_like(sorted_logits) * 1000 + ) # [batchsize, vocab] + min_logits = tf.reduce_min(logits_masked, axis=1, keepdims=True) # [batch size, 1] + return tf.where( + logits < min_logits, + tf.ones_like(logits, dtype=logits.dtype) * logits.dtype.min, + logits + ) + + + \ No newline at end of file diff --git a/sample/tensorflow/utils/translation/download_model_data.sh b/sample/tensorflow/utils/translation/download_model_data.sh index c736f2e4c..7453f672c 100644 --- 
a/sample/tensorflow/utils/translation/download_model_data.sh +++ b/sample/tensorflow/utils/translation/download_model_data.sh @@ -1,3 +1,17 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Install the OpenNMT-tf v1 pip install opennmt-tf==1.25.1 @@ -9,13 +23,8 @@ wget https://s3.amazonaws.com/opennmt-models/averaged-ende-ckpt500k.tar.gz mkdir translation mkdir translation/ckpt -# mkdir translation/data -# tar xf wmt_ende_sp.tar.gz -C translation/data tar xf averaged-ende-ckpt500k.tar.gz -C translation/ckpt -# rm wmt_ende_sp.tar.gz rm averaged-ende-ckpt500k.tar.gz -# head -n 5 translation/data/test.en > test.en -# head -n 5 translation/data/test.de > test.de # convert the pretrained model to fit our model structure -python utils/dump_model.py translation/ckpt/model.ckpt-500000 +python tensorflow/utils/dump_model.py translation/ckpt/model.ckpt-500000 diff --git a/sample/tensorflow_bert/fast_infer_util.py b/sample/tensorflow_bert/fast_infer_util.py index 684ab3b52..975685f17 100644 --- a/sample/tensorflow_bert/fast_infer_util.py +++ b/sample/tensorflow_bert/fast_infer_util.py @@ -23,7 +23,7 @@ import sys from my_modeling import * -build_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../build/lib') +build_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../lib') transformer_op_module = tf.load_op_library( os.path.join(build_path, 'libtf_fastertransformer.so')) @@ -122,6 +122,48 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, return (loss, per_example_loss, logits, probabilities) +def create_model_squad(bert_config, is_training, input_ids, input_mask, segment_ids, + use_one_hot_embeddings): + """Creates a classification model.""" + model = BertModel( + config=bert_config, + is_training=is_training, + input_ids=input_ids, + input_mask=input_mask, + token_type_ids=segment_ids, + use_one_hot_embeddings=use_one_hot_embeddings) + + final_hidden = model.get_sequence_output() + + final_hidden_shape = get_shape_list(final_hidden, expected_rank=3) + batch_size = final_hidden_shape[0] + seq_length = final_hidden_shape[1] + hidden_size = final_hidden_shape[2] + + output_weights = tf.get_variable( + "cls/squad/output_weights", [2, hidden_size], + dtype=tf.flags.FLAGS.floatx, + initializer=tf.truncated_normal_initializer(stddev=0.02)) + + output_bias = tf.get_variable( + "cls/squad/output_bias", [2], + dtype=tf.flags.FLAGS.floatx, + initializer=tf.zeros_initializer()) + + final_hidden_matrix = tf.reshape(final_hidden, + [batch_size * seq_length, hidden_size]) + logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True) + logits = tf.nn.bias_add(logits, output_bias) + + logits = tf.reshape(logits, [batch_size, seq_length, 2]) + logits = tf.transpose(logits, [2, 0, 1]) + + unstacked_logits = tf.unstack(logits, axis=0) + + (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1]) + + return (start_logits, end_logits) + def get_available_gpus(): 
local_device_protos = device_lib.list_local_devices() @@ -138,7 +180,8 @@ def fast_transformer_model_trans(input_tensor, hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, initializer_range=0.02, - do_return_all_layers=False): + do_return_all_layers=False, + sequence_length=None): """ Re-implementation of transformer_model function from modeling.py from Google's BERT repository https://github.com/google-research/bert using FasterTransformer Tensorflow op. @@ -260,21 +303,50 @@ def fast_transformer_model_trans(input_tensor, layer_output = dropout(layer_output, hidden_dropout_prob) layer_output = layer_norm(layer_output + attention_output) - # FASTINFER: fast transformer encoder inference - trainable_vars = tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES, scope=tf.get_variable_scope().name) - layer_output = transformer_op_module.bert_transformer( - layer_input, - layer_input, - trainable_vars[0], trainable_vars[2], trainable_vars[4], trainable_vars[1], trainable_vars[3], trainable_vars[5], - attention_mask, - trainable_vars[6], trainable_vars[7], trainable_vars[8], trainable_vars[9], trainable_vars[10], trainable_vars[11], - trainable_vars[12], trainable_vars[13], trainable_vars[14], trainable_vars[15], - from_seq_len=seq_length, to_seq_len=seq_length, head_num=num_attention_heads, size_per_head=attention_head_size) - - prev_output = layer_output + + # FASTINFER: fast transformer encoder inference + inputs = input_tensor + remove_padding = tf.flags.FLAGS.remove_padding + if remove_padding == True: + inputs, sequence_id_offset = transformer_op_module.build_mask_remove_padding(inputs, sequence_length) + else: + sequence_id_offset = [] + graph = tf.get_default_graph() + for layer_idx in range(num_hidden_layers): + layer_output = transformer_op_module.bert_transformer( + inputs, + inputs, + graph.get_tensor_by_name('bert/encoder/layer_%d/attention/self/query/kernel:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/attention/self/query/bias:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/attention/self/key/kernel:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/attention/self/key/bias:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/attention/self/value/kernel:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/attention/self/value/bias:0' % layer_idx), + tf.expand_dims(attention_mask, 1), + graph.get_tensor_by_name('bert/encoder/layer_%d/attention/output/dense/kernel:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/attention/output/dense/bias:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/attention/output/LayerNorm/beta:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/attention/output/LayerNorm/gamma:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/intermediate/dense/kernel:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/intermediate/dense/bias:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/output/dense/kernel:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/output/dense/bias:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/output/LayerNorm/beta:0' % layer_idx), + graph.get_tensor_by_name('bert/encoder/layer_%d/output/LayerNorm/gamma:0' % layer_idx), + sequence_id_offset, + head_num=num_attention_heads, size_per_head=attention_head_size, + remove_padding=remove_padding) + + if remove_padding == True: + 
all_layer_outputs.append(transformer_op_module.rebuild_padding(layer_output, sequence_id_offset, tf.expand_dims(attention_mask, 1))) + else: all_layer_outputs.append(layer_output) - + inputs = layer_output + + if remove_padding == True: + layer_output = transformer_op_module.rebuild_padding(layer_output, sequence_id_offset, tf.expand_dims(attention_mask, 1)) + + if do_return_all_layers: final_outputs = [] for layer_output in all_layer_outputs: diff --git a/sample/tensorflow_bert/my_modeling.py b/sample/tensorflow_bert/my_modeling.py index 37289b70e..0eacde253 100644 --- a/sample/tensorflow_bert/my_modeling.py +++ b/sample/tensorflow_bert/my_modeling.py @@ -201,6 +201,8 @@ def __init__(self, # for the attention scores. attention_mask = create_attention_mask_from_input_mask( input_ids, input_mask) + + sequence_length = tf.reduce_sum(input_mask, axis=1) # Run the stacked transformer. # `sequence_output` shape = [batch_size, seq_length, hidden_size]. @@ -215,7 +217,8 @@ def __init__(self, hidden_dropout_prob=config.hidden_dropout_prob, attention_probs_dropout_prob=config.attention_probs_dropout_prob, initializer_range=config.initializer_range, - do_return_all_layers=True) + do_return_all_layers=True, + sequence_length=sequence_length) self.sequence_output = self.all_encoder_layers[-1] # The "pooler" converts the encoded sequence tensor of shape @@ -766,7 +769,8 @@ def transformer_model(input_tensor, hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, initializer_range=0.02, - do_return_all_layers=False): + do_return_all_layers=False, + sequence_length=None): """Multi-headed, multi-layer Transformer from "Attention is All You Need". This is almost an exact implementation of the original Transformer encoder. diff --git a/sample/tensorflow_bert/profile_bert_inference.py b/sample/tensorflow_bert/profile_bert_inference.py index c61b6b431..4ba2ecf14 100644 --- a/sample/tensorflow_bert/profile_bert_inference.py +++ b/sample/tensorflow_bert/profile_bert_inference.py @@ -161,8 +161,8 @@ def graph_fn(): t2, r2 = profile_util.run_profile( graph_fn, jit_xla, num_iter, check_result=False, init_checkpoint=FLAGS.init_checkpoint) - print('average time (seconds) elasped original tensorflow:', t1) - print('average time (seconds) elasped fast transformer:', t2) + print('average time (seconds) elapsed original tensorflow:', t1) + print('average time (seconds) elapsed fast transformer:', t2) if len(r1) + len(r2) > 0: check_res = np.asarray([np.allclose( r1[i], r2[i], atol=1e-4, rtol=0) for i in range(num_iter)]) diff --git a/sample/tensorflow_bert/profile_transformer_inference.py b/sample/tensorflow_bert/profile_transformer_inference.py index 8104288db..0c9a0f19e 100644 --- a/sample/tensorflow_bert/profile_transformer_inference.py +++ b/sample/tensorflow_bert/profile_transformer_inference.py @@ -48,7 +48,8 @@ def __init__(self, input_tensor, attention_mask, transformer_model_fn, - scope=None): + scope=None, + sequence_length=None): config = my_modeling.copy.deepcopy(config) if not is_training: config.hidden_dropout_prob = 0.0 @@ -74,7 +75,8 @@ def __init__(self, hidden_dropout_prob=config.hidden_dropout_prob, attention_probs_dropout_prob=config.attention_probs_dropout_prob, initializer_range=config.initializer_range, - do_return_all_layers=True) + do_return_all_layers=True, + sequence_length=sequence_length) self.sequence_output = self.all_encoder_layers[-1] with tf.variable_scope("pooler"): @@ -94,13 +96,14 @@ def get_sequence_output(self): def model_fn_builder(bert_config, transformer_model_fn): - def 
model_fn(input_tensor, attention_mask): # pylint: disable=unused-argument + def model_fn(input_tensor, attention_mask, sequence_length=None): # pylint: disable=unused-argument model = TransformerModel( config=bert_config, is_training=False, input_tensor=input_tensor, attention_mask=attention_mask, - transformer_model_fn=transformer_model_fn) + transformer_model_fn=transformer_model_fn, + sequence_length=sequence_length) seq_output = model.get_sequence_output() return seq_output @@ -111,11 +114,13 @@ def profile_model(config, jit_xla, num_iter): # initialize data input_data = np.random.randn( FLAGS.predict_batch_size, FLAGS.max_seq_length, config.hidden_size) - attention_mask = np.random.randint(2, size=( - FLAGS.predict_batch_size, FLAGS.max_seq_length)) + sequence_length = np.random.randint(0, FLAGS.max_seq_length + 1, size=FLAGS.predict_batch_size).astype(np.int32) + attention_mask = np.zeros((FLAGS.predict_batch_size, FLAGS.max_seq_length)) + for i in range(len(sequence_length)): + attention_mask[i, 0:sequence_length[i]] = 1 attention_mask = np.repeat( attention_mask[:, np.newaxis, :], FLAGS.max_seq_length, axis=1) - + model_fn_tf = model_fn_builder(config, my_modeling.transformer_model) model_fn_ft = model_fn_builder(config, fiu.fast_transformer_model_trans) @@ -124,7 +129,7 @@ def graph_fn(): input_tensor = tf.constant(input_data, dtype=FLAGS.floatx) mask_tensor = tf.constant(attention_mask, dtype=FLAGS.floatx) - output_var = model_fn(input_tensor, mask_tensor) + output_var = model_fn(input_tensor, mask_tensor, sequence_length) # for saving memcopy time return tf.reduce_mean(output_var) return graph_fn @@ -152,8 +157,8 @@ def graph_fn(): model_fn_ft), jit_xla, num_iter, check_result=False, init_checkpoint=FLAGS.init_checkpoint) # check errors - print('average time (seconds) elasped original tensorflow:', t1) - print('average time (seconds) elasped fast transformer:', t2) + print('average time (seconds) elapsed original tensorflow:', t1) + print('average time (seconds) elapsed fast transformer:', t2) if len(r1) + len(r2) > 0: check_res = np.asarray([np.allclose( @@ -220,4 +225,5 @@ def main(_): flags.mark_flag_as_required("xla") flags.DEFINE_bool("tf_profile", False, "whether to use tensorflow profiling") + flags.DEFINE_bool("remove_padding", False, "Whether remove the padding of sentences") tf.app.run() diff --git a/sample/tensorflow_bert/run_classifier_wrap.py b/sample/tensorflow_bert/run_classifier_wrap.py index a234c240b..7bcfb0ddb 100644 --- a/sample/tensorflow_bert/run_classifier_wrap.py +++ b/sample/tensorflow_bert/run_classifier_wrap.py @@ -73,4 +73,5 @@ flags.mark_flag_as_required("output_dir") flags.DEFINE_string("floatx", None, "float32 or float16") flags.mark_flag_as_required("floatx") + flags.DEFINE_bool("remove_padding", False, "Whether remove the padding of sentences") tf.app.run() diff --git a/sample/tensorflow_bert/run_squad_wrap.py b/sample/tensorflow_bert/run_squad_wrap.py new file mode 100644 index 000000000..052b1da7f --- /dev/null +++ b/sample/tensorflow_bert/run_squad_wrap.py @@ -0,0 +1,78 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# usage example +# export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 +# export GLUE_DIR=/path/to/glue +# python run_classifier_wrap.py --floatx=float16 --task_name=MRPC --do_eval=true --data_dir=$GLUE_DIR/MRPC --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config.json --init_checkpoint=mrpc_output/fp16_model.ckpt --max_seq_length=128 --eval_batch_size=8 --output_dir=mrpc_output + +# FP32 Tensorflow Transformer MRPC result +# INFO:tensorflow: eval_accuracy = 0.877451 +# INFO:tensorflow: eval_loss = 0.44744828 +# INFO:tensorflow: global_step = 0 +# INFO:tensorflow: loss = 0.44744828 + +# FP32 Faster Transformer MRPC result +# INFO:tensorflow: eval_accuracy = 0.877451 +# INFO:tensorflow: eval_loss = 0.4474482 +# INFO:tensorflow: global_step = 0 +# INFO:tensorflow: loss = 0.4474482 + +# FP16 Tensorflow Transformer MRPC result +# INFO:tensorflow: eval_accuracy = 0.875 +# INFO:tensorflow: eval_loss = 0.44760832 +# INFO:tensorflow: global_step = 0 +# INFO:tensorflow: loss = 0.44760215 + +# FP16 Faster Transformer MRPC result +# INFO:tensorflow: eval_accuracy = 0.875 +# INFO:tensorflow: eval_loss = 0.44731623 +# INFO:tensorflow: global_step = 0 +# INFO:tensorflow: loss = 0.44728807 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import os +bert_submodule = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert') +sys.path.insert(0, bert_submodule) +import tensorflow as tf +# import run_classifier as rc +import run_squad as rs +import fast_infer_util as fiu +import my_modeling + +flags = tf.flags +FLAGS = flags.FLAGS + +# replace transformer implementation +my_modeling.transformer_model = fiu.fast_transformer_model_trans +# replace the model to support fp16 data type +rs.create_model = fiu.create_model_squad +# replace the input function to drop remainder +rs.file_based_input_fn_builder = fiu.file_based_input_fn_builder_drop +main = rs.main + +if __name__ == "__main__": + # flags.mark_flag_as_required("data_dir") + # flags.mark_flag_as_required("task_name") + flags.mark_flag_as_required("vocab_file") + flags.mark_flag_as_required("bert_config_file") + flags.mark_flag_as_required("output_dir") + flags.DEFINE_string("floatx", None, "float32 or float16") + flags.mark_flag_as_required("floatx") + flags.DEFINE_bool("remove_padding", False, "Remove padding or Not") + tf.app.run() diff --git a/sample/tensorflow_bert/squad_evaluate-v1.1.py b/sample/tensorflow_bert/squad_evaluate-v1.1.py new file mode 100644 index 000000000..342238b36 --- /dev/null +++ b/sample/tensorflow_bert/squad_evaluate-v1.1.py @@ -0,0 +1,108 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Official evaluation script for v1.1 of the SQuAD dataset. """ +from __future__ import print_function +from collections import Counter +import string +import re +import argparse +import json +import sys + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth): + return (normalize_answer(prediction) == normalize_answer(ground_truth)) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def evaluate(dataset, predictions): + f1 = exact_match = total = 0 + for article in dataset: + for paragraph in article['paragraphs']: + for qa in paragraph['qas']: + total += 1 + if qa['id'] not in predictions: + message = 'Unanswered question ' + qa['id'] + \ + ' will receive score 0.' 
+ print(message, file=sys.stderr) + continue + ground_truths = list(map(lambda x: x['text'], qa['answers'])) + prediction = predictions[qa['id']] + exact_match += metric_max_over_ground_truths( + exact_match_score, prediction, ground_truths) + f1 += metric_max_over_ground_truths( + f1_score, prediction, ground_truths) + + exact_match = 100.0 * exact_match / total + f1 = 100.0 * f1 / total + + return {'exact_match': exact_match, 'f1': f1} + + +if __name__ == '__main__': + expected_version = '1.1' + parser = argparse.ArgumentParser( + description='Evaluation for SQuAD ' + expected_version) + parser.add_argument('dataset_file', help='Dataset file') + parser.add_argument('prediction_file', help='Prediction File') + args = parser.parse_args() + with open(args.dataset_file) as dataset_file: + dataset_json = json.load(dataset_file) + if (dataset_json['version'] != expected_version): + print('Evaluation expects v-' + expected_version + + ', but got dataset with v-' + dataset_json['version'], + file=sys.stderr) + dataset = dataset_json['data'] + with open(args.prediction_file) as prediction_file: + predictions = json.load(prediction_file) + print(json.dumps(evaluate(dataset, predictions))) diff --git a/tools/gemm_test/decoding_gemm.h b/tools/gemm_test/decoding_gemm.h index db249840a..dcda98958 100644 --- a/tools/gemm_test/decoding_gemm.h +++ b/tools/gemm_test/decoding_gemm.h @@ -41,7 +41,7 @@ void generate_decoding_gemm_config(int batch_size, } const int hidden_units = head_number * size_per_head; - const int gemm_num = 5; + const int gemm_num = 6; int M[gemm_num]; int N[gemm_num]; int K[gemm_num]; @@ -75,6 +75,11 @@ void generate_decoding_gemm_config(int batch_size, N[4] = hidden_units; strcpy(mess[4], "ffn gemm2"); + M[5] = batch_size * beam_width; + K[5] = hidden_units; + N[5] = hidden_units; + strcpy(mess[5], "from_tensor * QKV (batchstridedgemm) in masked attention"); + cublasHandle_t cublas_handle; check_cuda_error(cublasCreate(&cublas_handle)); @@ -114,9 +119,19 @@ void generate_decoding_gemm_config(int batch_size, T* d_A; T* d_B; T* d_C; - check_cuda_error(cudaMalloc((void**)&d_A, sizeof(T) * m * k)); - check_cuda_error(cudaMalloc((void**)&d_B, sizeof(T) * k * n)); - check_cuda_error(cudaMalloc((void**)&d_C, sizeof(T) * m * n)); + + if(i == 5) + { + check_cuda_error(cudaMalloc((void**)&d_A, sizeof(T) * m * k)); + check_cuda_error(cudaMalloc((void**)&d_B, sizeof(T) * k * n * 3)); + check_cuda_error(cudaMalloc((void**)&d_C, sizeof(T) * m * n * 3)); + } + else + { + check_cuda_error(cudaMalloc((void**)&d_A, sizeof(T) * m * k)); + check_cuda_error(cudaMalloc((void**)&d_B, sizeof(T) * k * n)); + check_cuda_error(cudaMalloc((void**)&d_C, sizeof(T) * m * n)); + } float exec_time = 99999.0f; int fast_algo = 0; @@ -127,16 +142,33 @@ void generate_decoding_gemm_config(int batch_size, gettimeofday(&start, NULL); for(int ite = 0; ite < ites; ++ite) { - status = cublasGemmEx(cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - d_B, BType, n, - d_A, AType, k, - &beta, - d_C, CType, n, - computeType, - static_cast(algo)); + if(i == 5) + { + status = cublasGemmStridedBatchedEx(cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + d_B, BType, n, k*n, + d_A, AType, k, 0, + &beta, + d_C, CType, n, m*n, + 3, + computeType, + static_cast(algo)); + } + else + { + status = cublasGemmEx(cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + d_B, BType, n, + d_A, AType, k, + &beta, + d_C, CType, n, + computeType, + static_cast(algo)); + } } cudaDeviceSynchronize(); 
gettimeofday(&end, NULL); @@ -151,7 +183,7 @@ void generate_decoding_gemm_config(int batch_size, } } printf("fast_algo %d costs %.3f ms\n", fast_algo, exec_time); - fprintf(fd, "%d\n", fast_algo); + fprintf(fd, "%d %f\n", fast_algo, exec_time); cudaFree(d_A); cudaFree(d_B); cudaFree(d_C); diff --git a/tools/gemm_test/encoder_gemm.cc b/tools/gemm_test/encoder_gemm.cc index d3dc532cd..8c351e55c 100644 --- a/tools/gemm_test/encoder_gemm.cc +++ b/tools/gemm_test/encoder_gemm.cc @@ -30,10 +30,6 @@ int main(int argc, char* argv[]) const int head_num = atoi(argv[3]); const int size_per_head = atoi(argv[4]); - struct cudaDeviceProp prop; - check_cuda_error(cudaGetDeviceProperties(&prop, 0)); - printf("Device %s\n", prop.name); - if(atoi(argv[5]) == 0) generate_encoder_gemm_config(batch_size, seq_len, head_num, size_per_head); else if(atoi(argv[5]) == 1) diff --git a/tools/gemm_test/encoder_gemm.h b/tools/gemm_test/encoder_gemm.h index c5a2312f5..3015e3ab2 100644 --- a/tools/gemm_test/encoder_gemm.h +++ b/tools/gemm_test/encoder_gemm.h @@ -25,6 +25,12 @@ using namespace std; +template +void device_malloc(T **ptr, int size) +{ + check_cuda_error(cudaMalloc((void **)ptr, sizeof(T) * size)); +} + template void generate_encoder_gemm_config(int batch_size, int seq_len, @@ -42,11 +48,11 @@ void generate_encoder_gemm_config(int batch_size, check_cuda_error(cudaGetDeviceProperties(&prop, 0)); printf("Device %s\n", prop.name); - const int gemm_num = 5; + const int gemm_num = 6; int M[gemm_num]; int N[gemm_num]; int K[gemm_num]; - int batchCount[gemm_num] = {1,1,1,1,1}; + int batchCount[gemm_num] = {1,1,1,1,1,1}; char mess[gemm_num][256]; //gemm1 @@ -79,6 +85,12 @@ void generate_encoder_gemm_config(int batch_size, batchCount[4] = batch_size * head_num; strcpy(mess[4], "attention batched Gemm2"); + M[5] = batch_size * seq_len; + N[5] = head_num * size_per_head; + K[5] = N[5]; + batchCount[5] = 3; + strcpy(mess[5], "from_tensor * weight_QKV in BatchGemm"); + cublasHandle_t cublas_handle; check_cuda_error(cublasCreate(&cublas_handle)); @@ -110,18 +122,35 @@ void generate_encoder_gemm_config(int batch_size, T alpha = (T)1.0f; T beta = (T)0.0f; - printf("***FP32 Gemm Testing***\n"); + printf("***Encoder Gemm Testing***\n"); for(int i = 0; i < gemm_num; ++i) { + // if(i != 0 && i != 5) continue; + int m = M[i], n = N[i], k = K[i]; printf("\n-----------------------------\n"); printf("GEMM test %d: [M: %d, K: %d, N: %d] %s\n", i, m, k, n, mess[i]); T* d_A; T* d_B; T* d_C; - check_cuda_error(cudaMalloc((void**)&d_A, sizeof(T) * m * k * batchCount[i])); - check_cuda_error(cudaMalloc((void**)&d_B, sizeof(T) * k * n * batchCount[i])); - check_cuda_error(cudaMalloc((void**)&d_C, sizeof(T) * m * n * batchCount[i])); + device_malloc(&d_A, sizeof(T) * m * k * batchCount[i]); + device_malloc(&d_B, sizeof(T) * k * n * batchCount[i]); + device_malloc(&d_C, sizeof(T) * m * n * batchCount[i]); + + // array of pointer for batchedGemm + T* harray[9]; + for(int i = 0; i < 9; i++) + { + if( i >= 0 && i < 3) device_malloc(&harray[i], sizeof(T) * m * k); + else if(i >= 3 && i < 6) device_malloc(&harray[i], sizeof(T) * k * n); + else if(i >= 6 && i < 9) device_malloc(&harray[i], sizeof(T) * m * n); + } + T** darray = 0; + check_cuda_error(cudaMalloc((void**)&darray, sizeof(T*) * 9)); + cudaMemcpy((void*)darray, (void*)harray, sizeof(T*) * 9, cudaMemcpyHostToDevice); + T** dAarray = darray; + T** dBarray = darray + 3; + T** dCarray = darray + 6; float exec_time = 99999.0f; int fast_algo = 0; @@ -159,7 +188,7 @@ void 
generate_encoder_gemm_config(int batch_size, computeType, static_cast(algo)); } - else + else if(i == 4) { status = cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, @@ -173,6 +202,21 @@ void generate_encoder_gemm_config(int batch_size, computeType, static_cast(algo)); } + else if(i == 5) + { + status = cublasGemmBatchedEx(cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + (const void* const*) dBarray, BType, n, + (const void* const*) dAarray, AType, k, + &beta, + (void* const*)dCarray, CType, n, + 3, + computeType, + static_cast(algo)); + } + if(status != CUBLAS_STATUS_SUCCESS) break; } cudaDeviceSynchronize(); gettimeofday(&end, NULL); @@ -187,7 +231,13 @@ void generate_encoder_gemm_config(int batch_size, } } printf("fast_algo %d costs %.3f ms\n", fast_algo, exec_time); - fprintf(fd, "%d\n", fast_algo); + fprintf(fd, "%d %f\n", fast_algo, exec_time); + + cudaFree(d_A); + cudaFree(d_B); + cudaFree(d_C); + for(int i = 0; i < 9; i++) cudaFree(harray[i]); + cudaFree(darray); } return; }
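Note on the padding-removal path added in fast_infer_util.py above: when --remove_padding is set, build_mask_remove_padding packs the valid tokens of every sentence into one contiguous tensor plus an offset index (sequence_id_offset), each bert_transformer layer then runs on the packed tensor, and rebuild_padding scatters the hidden states back to the padded [batch, seq_len, hidden] layout. The real ops are CUDA kernels; the NumPy sketch below only illustrates the pack/unpack bookkeeping, and the flat-index layout of the offsets is an assumption made for illustration, not the op's documented format.

import numpy as np

def remove_padding(x, sequence_length):
    # Pack the valid tokens of a padded [batch, seq, hidden] tensor and
    # return the flat indices of those tokens (the role sequence_id_offset
    # plays in the op above). Illustrative only.
    batch, seq, hidden = x.shape
    offsets = np.concatenate(
        [b * seq + np.arange(sequence_length[b]) for b in range(batch)])
    packed = x.reshape(batch * seq, hidden)[offsets]
    return packed, offsets

def rebuild_padding(packed, offsets, batch, seq):
    # Scatter the packed hidden states back to the padded layout;
    # padded positions come back as zeros.
    hidden = packed.shape[-1]
    out = np.zeros((batch * seq, hidden), dtype=packed.dtype)
    out[offsets] = packed
    return out.reshape(batch, seq, hidden)

# Toy round trip: pack, pretend a layer ran, unpack.
batch, seq, hidden = 2, 4, 3
seq_len = np.array([2, 3])
x = np.random.randn(batch, seq, hidden)
packed, offsets = remove_padding(x, seq_len)      # packed shape: [5, 3]
y = rebuild_padding(packed, offsets, batch, seq)
assert np.allclose(y[0, :2], x[0, :2]) and np.allclose(y[1, :3], x[1, :3])

This is also why profile_transformer_inference.py now draws random sequence_length values and derives attention_mask from them: the benefit of removing padded tokens depends on how much padding the batch actually contains.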
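Note on the new GEMM case 5 in both gemm tools: it times the fused Q/K/V projection as a single batched GEMM with batch count 3. encoder_gemm.h measures cublasGemmBatchedEx over three pointer pairs, while decoding_gemm.h measures cublasGemmStridedBatchedEx with a zero stride on the A operand, i.e. the same from_tensor is reused against the three stacked weight matrices; both tools now also write the measured time next to the chosen algorithm into the config file. A minimal NumPy sketch of the shapes involved (names and sizes here are illustrative, not the tools' API):

import numpy as np

m, k, n = 8 * 32, 64, 64            # batch_size*seq_len, hidden, hidden (small sizes for a quick check)
x = np.random.randn(m, k)           # from_tensor, shared by Q, K and V
w_qkv = np.random.randn(3, k, n)    # stacked Q/K/V weight matrices

# Three independent projections, as the unfused path would compute them.
q, k_, v = x @ w_qkv[0], x @ w_qkv[1], x @ w_qkv[2]

# One batched matmul with batch count 3; broadcasting x across the batch
# is what the zero stride on the A operand expresses in the strided call.
qkv = np.einsum('ik,bkn->bin', x, w_qkv)   # shape [3, m, n]

assert np.allclose(qkv[0], q) and np.allclose(qkv[1], k_) and np.allclose(qkv[2], v)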