Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*~
*.o
build/
*.pyc
build*/
*.pyc
.vscode/
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
path = sample/tensorflow_bert/bert
url = https://github.com/google-research/bert.git

[submodule "OpenNMT-tf"]
path = OpenNMT-tf
url = https://github.com/OpenNMT/OpenNMT-tf
79 changes: 65 additions & 14 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
cmake_minimum_required(VERSION 3.8 FATAL_ERROR) # for PyTorch extensions, version should be greater than 3.13
project(FasterTransformer LANGUAGES CXX CUDA)

find_package(CUDA 10.0 REQUIRED)

option(BUILD_TRT "Build in TensorRT mode" OFF)
option(BUILD_TF "Build in TensorFlow mode" OFF)
option(BUILD_THE "Build in PyTorch eager mode" OFF)
option(BUILD_THS "Build in TorchScript class mode" OFF)
option(BUILD_THSOP "Build in TorchScript OP mode" OFF)

set(CXX_STD "11" CACHE STRING "C++ standard")

set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})

Expand Down Expand Up @@ -53,6 +58,11 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM},code=\\\"s
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA")
endif()
if(BUILD_THE OR BUILD_THS OR BUILD_THSOP)
string(SUBSTRING ${SM} 0 1 SM_MAJOR)
string(SUBSTRING ${SM} 1 1 SM_MINOR)
set(ENV{TORCH_CUDA_ARCH_LIST} "${SM_MAJOR}.${SM_MINOR}")
endif()
message("-- Assign GPU architecture (sm=${SM})")

else()
Expand All @@ -65,22 +75,21 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA")
if(BUILD_THE OR BUILD_THS OR BUILD_THSOP)
set(ENV{TORCH_CUDA_ARCH_LIST} "6.0;6.1;7.0;7.5")
endif()
message("-- Assign GPU architecture (sm=60,61,70,75)")
endif()

set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall")

set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall --ptxas-options=-v --resource-usage")

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD "${CXX_STD}")
set(CMAKE_CXX_STANDARD_REQUIRED ON)

if(CMAKE_CXX_STANDARD STREQUAL "11")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++11")
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}")

set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3")
Expand Down Expand Up @@ -108,6 +117,41 @@ if(BUILD_TRT)
list(APPEND COMMON_LIB_DIRS ${TRT_PATH}/lib)
endif()

set(PYTHON_PATH "python" CACHE STRING "Python path")
if(BUILD_THE OR BUILD_THS OR BUILD_THSOP)
execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import os; import torch;
print(os.path.dirname(torch.__file__),end='');"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE TORCH_DIR)
if (NOT _PYTHON_SUCCESS MATCHES 0)
message(FATAL_ERROR "Torch config Error.")
endif()
list(APPEND CMAKE_PREFIX_PATH ${TORCH_DIR})
find_package(Torch REQUIRED)

execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; from distutils import sysconfig;
print(sysconfig.get_python_inc());
print(sysconfig.get_config_var('SO'));"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE _PYTHON_VALUES)
if (NOT _PYTHON_SUCCESS MATCHES 0)
message(FATAL_ERROR "Python config Error.")
endif()
string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES})
string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES})
list(GET _PYTHON_VALUES 0 PY_INCLUDE_DIR)
list(GET _PYTHON_VALUES 1 PY_SUFFIX)
list(APPEND COMMON_HEADER_DIRS ${PY_INCLUDE_DIR})

execute_process(COMMAND ${PYTHON_PATH} "-c" "from torch.utils import cpp_extension; print(' '.join(cpp_extension._prepare_ldflags([],True,False)),end='');"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE TORCH_LINK)
if (NOT _PYTHON_SUCCESS MATCHES 0)
message(FATAL_ERROR "PyTorch link config Error.")
endif()
endif()


include_directories(
${COMMON_HEADER_DIRS}
)
Expand All @@ -124,10 +168,17 @@ if(BUILD_TF)
add_custom_target(copy ALL COMMENT "Copying tensorflow test scripts")
add_custom_command(TARGET copy
POST_BUILD
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/*.py ${PROJECT_SOURCE_DIR}/build/
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/utils ${PROJECT_SOURCE_DIR}/build/ -r
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/scripts ${PROJECT_SOURCE_DIR}/build/ -r
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow_bert ${PROJECT_SOURCE_DIR}/build/ -r
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/ ${PROJECT_BINARY_DIR} -r
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow_bert ${PROJECT_BINARY_DIR}/tensorflow -r
)
endif()

if(BUILD_THE OR BUILD_THS OR BUILD_THSOP)
add_custom_target(copy ALL COMMENT "Copying pytorch test scripts")
add_custom_command(TARGET copy
POST_BUILD
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/pytorch/ ${PROJECT_BINARY_DIR} -r
COMMAND mkdir -p ${PROJECT_BINARY_DIR}/pytorch/translation/data/
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/utils/translation/test.* ${PROJECT_BINARY_DIR}/pytorch/translation/data/
)
endif()
Loading