Merged
8 changes: 7 additions & 1 deletion CMakeLists.txt
@@ -118,7 +118,12 @@ option(USE_LITE_PROTO "Use lite protobuf instead of full." OFF)
option(USE_LMDB "Use LMDB" OFF)
option(USE_METAL "Use Metal for iOS build" ON)
option(USE_NATIVE_ARCH "Use -march=native" OFF)
option(USE_NCCL "Use NCCL" ON)
cmake_dependent_option(
USE_NCCL "Use NCCL" ON
"USE_CUDA" OFF)
cmake_dependent_option(
USE_RCCL "Use RCCL" ON
"USE_ROCM" OFF)
option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF)
option(USE_NNAPI "Use NNAPI" OFF)
option(USE_NNPACK "Use NNPACK" ON)
@@ -170,6 +175,7 @@ if (BUILD_ATEN_ONLY)
set(USE_GFLAGS OFF)
set(USE_GLOG OFF)
set(USE_NCCL OFF)
set(USE_RCCL OFF)
set(USE_NNPACK OFF)
set(USE_NUMPY OFF)
set(USE_OPENCV OFF)
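USE_NCCL and USE_RCCL are now declared with cmake_dependent_option, so each flag stays user-settable only while its backend flag (USE_CUDA or USE_ROCM) is on, and is forced to OFF otherwise. Below is a minimal Python sketch of that resolution logic, for illustration only; the real behavior comes from CMake's CMakeDependentOption module.

```python
def dependent_option(user_value, default, depends_ok, force_value=False):
    """Mimic cmake_dependent_option: honor the user's choice only when the
    dependency condition holds, otherwise force the fallback value."""
    if depends_ok:
        return default if user_value is None else user_value
    return force_value

# USE_NCCL defaults ON, but only when USE_CUDA is ON; USE_RCCL/USE_ROCM has the same shape.
print(dependent_option(user_value=None, default=True, depends_ok=True))   # True  (CUDA build)
print(dependent_option(user_value=True, default=True, depends_ok=False))  # False (no CUDA -> forced OFF)
```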
6 changes: 6 additions & 0 deletions cmake/Dependencies.cmake
@@ -810,12 +810,18 @@ if(USE_ROCM)

set(Caffe2_HIP_INCLUDE
${hip_INCLUDE_DIRS} ${hcc_INCLUDE_DIRS} ${hsa_INCLUDE_DIRS} ${rocrand_INCLUDE_DIRS} ${hiprand_INCLUDE_DIRS} ${rocblas_INCLUDE_DIRS} ${miopen_INCLUDE_DIRS} ${thrust_INCLUDE_DIRS} $<INSTALL_INTERFACE:include> ${Caffe2_HIP_INCLUDE})
if(USE_RCCL)
list(APPEND Caffe2_HIP_INCLUDE ${rccl_INCLUDE_DIRS})
endif(USE_RCCL)

# This is needed for library added by hip_add_library (same for hip_add_executable)
hip_include_directories(${Caffe2_HIP_INCLUDE})

set(Caffe2_HIP_DEPENDENCY_LIBS
${rocrand_LIBRARIES} ${hiprand_LIBRARIES} ${hipsparse_LIBRARIES} ${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES})
if(USE_RCCL)
list(APPEND Caffe2_HIP_DEPENDENCY_LIBS ${PYTORCH_RCCL_LIBRARIES})
endif(USE_RCCL)

# Note [rocblas & rocfft cmake bug]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1 change: 1 addition & 0 deletions cmake/Summary.cmake
@@ -103,6 +103,7 @@ function (caffe2_print_configuration_summary)
if(${USE_NCCL})
message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}")
endif()
message(STATUS " USE_RCCL : ${USE_RCCL}")
message(STATUS " USE_NNPACK : ${USE_NNPACK}")
message(STATUS " USE_NUMPY : ${USE_NUMPY}")
message(STATUS " USE_OBSERVERS : ${USE_OBSERVERS}")
10 changes: 10 additions & 0 deletions cmake/public/LoadHIP.cmake
@@ -100,6 +100,13 @@ ELSE()
SET(HCC_AMDGPU_TARGET $ENV{HCC_AMDGPU_TARGET})
ENDIF()

# RCCL PATH
IF(NOT DEFINED ENV{RCCL_PATH})
SET(RCCL_PATH ${ROCM_PATH}/rccl)
ELSE()
SET(RCCL_PATH $ENV{RCCL_PATH})
ENDIF()

# Add HIP to the CMAKE Module Path
set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH})

@@ -140,6 +147,7 @@ IF(HIP_FOUND)
set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft)
set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse)
set(rocsparse_DIR ${ROCSPARSE_PATH}/lib/cmake/rocsparse)
set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl)

find_package_and_print_version(rocrand REQUIRED)
find_package_and_print_version(hiprand REQUIRED)
@@ -149,6 +157,7 @@ IF(HIP_FOUND)
find_package_and_print_version(rocfft REQUIRED)
#find_package_and_print_version(hipsparse REQUIRED)
find_package_and_print_version(rocsparse REQUIRED)
find_package_and_print_version(rccl)

# TODO: hip_hcc has an interface include flag "-hc" which is only
# recognizable by hcc, but not gcc and clang. Right now in our
@@ -158,6 +167,7 @@ IF(HIP_FOUND)
# TODO: miopen_LIBRARIES should return fullpath to the library file,
# however currently it's just the lib name
FIND_LIBRARY(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${MIOPEN_PATH}/lib)
FIND_LIBRARY(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib)
FIND_LIBRARY(hiprand_LIBRARIES hiprand HINTS ${HIPRAND_PATH}/lib)
FIND_LIBRARY(rocsparse_LIBRARIES rocsparse HINTS ${ROCSPARSE_PATH}/lib)
FIND_LIBRARY(hipsparse_LIBRARIES hipsparse HINTS ${HIPSPARSE_PATH}/lib)
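LoadHIP.cmake resolves RCCL the same way it resolves the other ROCm components: take RCCL_PATH from the environment if set, otherwise default to ${ROCM_PATH}/rccl, then locate the CMake package (not REQUIRED, so a missing RCCL does not fail configure) and the library under that prefix. A small Python sketch of just the path-resolution step; the /opt/rocm default is an assumption used only for illustration.

```python
import os

ROCM_PATH = os.environ.get("ROCM_PATH", "/opt/rocm")  # assumed default prefix
RCCL_PATH = os.environ.get("RCCL_PATH", os.path.join(ROCM_PATH, "rccl"))

rccl_cmake_dir = os.path.join(RCCL_PATH, "lib", "cmake", "rccl")  # what rccl_DIR points at
rccl_lib_dir = os.path.join(RCCL_PATH, "lib")                     # HINTS for the library lookup
print(rccl_cmake_dir, rccl_lib_dir)
```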
11 changes: 11 additions & 0 deletions setup.py
@@ -134,6 +134,11 @@
# NCCL_INCLUDE_DIR
# specify where nccl is installed
#
# RCCL_ROOT_DIR
# RCCL_LIB_DIR
# RCCL_INCLUDE_DIR
# specify where rccl is installed
#
# NVTOOLSEXT_PATH (Windows only)
# specify where nvtoolsext is installed
#
@@ -169,6 +174,7 @@
from tools.setup_helpers.rocm import USE_ROCM
from tools.setup_helpers.miopen import USE_MIOPEN, MIOPEN_LIBRARY, MIOPEN_INCLUDE_DIR
from tools.setup_helpers.nccl import USE_NCCL, USE_SYSTEM_NCCL, NCCL_SYSTEM_LIB, NCCL_INCLUDE_DIR
from tools.setup_helpers.rccl import USE_RCCL, RCCL_LIB_DIR, RCCL_INCLUDE_DIR, RCCL_ROOT_DIR, RCCL_SYSTEM_LIB
from tools.setup_helpers.dist_check import USE_DISTRIBUTED
################################################################################
# Parameters parsed from environment
@@ -343,6 +349,11 @@ def run(self):
report('-- Building NCCL library')
else:
report('-- Not using NCCL')
if USE_RCCL:
print('-- Detected RCCL library at ' +
RCCL_SYSTEM_LIB + ', ' + RCCL_INCLUDE_DIR)
else:
print('-- Not using RCCL')
if USE_DISTRIBUTED:
report('-- Building with THD distributed package ')
if IS_LINUX:
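setup.py now documents RCCL_ROOT_DIR, RCCL_LIB_DIR and RCCL_INCLUDE_DIR alongside the NCCL variables and reports the detected RCCL library during the build. As a hedged sketch of how a ROCm user might point the build at a non-default RCCL install (paths are examples, not prescriptive):

```python
# Illustrative only: configure the RCCL env vars documented above, then build.
import os
import subprocess

os.environ["RCCL_ROOT_DIR"] = "/opt/rocm/rccl"            # example install prefix
os.environ["RCCL_LIB_DIR"] = "/opt/rocm/rccl/lib"
os.environ["RCCL_INCLUDE_DIR"] = "/opt/rocm/rccl/include"

subprocess.check_call(["python", "setup.py", "install"])
```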
2 changes: 2 additions & 0 deletions test/test_cuda.py
@@ -1076,6 +1076,7 @@ def test_broadcast_coalesced(self):
self._test_broadcast_coalesced(self, tensors, num_bytes * 5 // 2)

@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
@skipIfRocm
def test_broadcast_coalesced_dense_only(self):
numel = 5
num_bytes = numel * 8
@@ -1146,6 +1147,7 @@ def test_reduce_add_coalesced(self):
self._test_reduce_add_coalesced(self, tensors, num_bytes * 5 // 2)

@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
@skipIfRocm
def test_reduce_add_coalesced_dense_only(self):
numel = 5
num_bytes = numel * 8
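The multi-GPU coalesced broadcast/reduce tests gain @skipIfRocm, the decorator PyTorch's test suites use to skip cases that are not yet expected to pass on ROCm. A rough sketch of how such a decorator can be built on unittest.skipIf; the real one lives in the shared test utilities, and the environment variable below is an assumption.

```python
import os
import unittest

# Assumed opt-in flag for running the suite against a ROCm build.
TEST_WITH_ROCM = os.getenv("PYTORCH_TEST_WITH_ROCM", "0") == "1"

def skipIfRocm(fn):
    """Skip the decorated test when running against a ROCm build."""
    return unittest.skipIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")(fn)

class Example(unittest.TestCase):
    @skipIfRocm
    def test_something_multigpu(self):
        self.assertTrue(True)
```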
9 changes: 9 additions & 0 deletions test/test_nn.py
@@ -3237,6 +3237,7 @@ def test_broadcast_no_grad(self):
self.assertFalse(output.requires_grad)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@skipIfRocm
def test_replicate(self):
module = nn.Linear(10, 5).float().cuda()
input = Variable(torch.randn(2, 10).float().cuda())
@@ -3345,6 +3346,7 @@ def local_test(out):
local_test(out)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@skipIfRocm
def test_data_parallel_small_back(self):
l = nn.Linear(10, 5).float().cuda()
i = Variable(torch.randn(20, 10).float().cuda())
@@ -3391,6 +3393,7 @@ def forward(self, x):
self.assertRaises(AssertionError, lambda: dp.data_parallel(l, i, (0, 1)))

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@skipIfRocm
def test_data_parallel(self):
l = nn.Linear(10, 5).float().cuda()
i = Variable(torch.randn(20, 10).float().cuda(1))
@@ -3494,6 +3497,7 @@ def forward(self, *input):

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
@skipIfRocm
def test_data_parallel_module(self, dtype=torch.float):
l = nn.Linear(10, 5).to("cuda", dtype)
i = torch.randn(20, 10, device="cuda", dtype=dtype)
@@ -3505,6 +3509,7 @@ def test_data_parallel_module(self, dtype=torch.float):

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
@skipIfRocm
def test_data_parallel_module_kwargs_only(self, dtype=torch.float):
class Net(nn.Module):
def __init__(self):
@@ -3524,6 +3529,7 @@ def forward(self, input):

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
@skipIfRocm
def test_data_parallel_module_kwargs_only_empty_list(self, dtype=torch.float):
class Net(nn.Module):
def __init__(self):
@@ -3543,6 +3549,7 @@ def forward(self, input):

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
@skipIfRocm
def test_data_parallel_module_kwargs_only_empty_dict(self, dtype=torch.float):
class Net(nn.Module):
def __init__(self):
@@ -3562,6 +3569,7 @@ def forward(self, input):

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
@skipIfRocm
def test_data_parallel_module_kwargs_only_empty_tuple(self, dtype=torch.float):
class Net(nn.Module):
def __init__(self):
@@ -3580,6 +3588,7 @@ def forward(self, input):
self.assertEqual(out.data, expected_out, dtype2prec[dtype])

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@skipIfRocm
def test_data_parallel_device_args(self):
cuda0 = torch.device('cuda:0')
cuda1 = torch.device('cuda:1')
4 changes: 3 additions & 1 deletion tools/amd_build/build_amd.py
@@ -98,12 +98,14 @@
"csrc/autograd/profiler.h",
"csrc/autograd/profiler.cpp",
"csrc/cuda/cuda_check.h",
"torch/lib/c10d/ProcessGroupGloo.hpp",
"torch/lib/c10d/ProcessGroupGloo.cpp",
# These files are compatible with both cuda and hip
"csrc/autograd/engine.cpp"
]
for root, _directories, files in os.walk(os.path.join(proj_dir, "torch")):
for filename in files:
if filename.endswith(".cpp") or filename.endswith(".h"):
if filename.endswith(".cpp") or filename.endswith(".h") or filename.endswith(".hpp"):
source = os.path.join(root, filename)
# Disabled files
if reduce(lambda result, exclude: source.endswith(exclude) or result, ignore_files, False):
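build_amd.py now also walks .hpp files so the c10d ProcessGroup headers can be hipified, while the ignore list keeps the Gloo process group sources untouched. As an aside, the chained endswith checks and the reduce-based exclusion test can be written a bit more directly; the sketch below is an equivalent formulation, not a proposed change to the PR.

```python
import os

ignore_files = [
    "torch/lib/c10d/ProcessGroupGloo.hpp",
    "torch/lib/c10d/ProcessGroupGloo.cpp",
]

def should_hipify(root, filename):
    # endswith accepts a tuple, which replaces the chained 'or' checks
    if not filename.endswith((".cpp", ".h", ".hpp")):
        return False
    source = os.path.join(root, filename)
    # any() over the excludes is equivalent to the reduce(lambda ...) expression
    return not any(source.endswith(exclude) for exclude in ignore_files)

print(should_hipify("torch/lib/c10d", "ProcessGroupNCCL.cpp"))  # True
print(should_hipify("torch/lib/c10d", "ProcessGroupGloo.cpp"))  # False
```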
6 changes: 6 additions & 0 deletions tools/amd_build/disabled_features.json
@@ -1,5 +1,11 @@
{
"disable_unsupported_hip_calls": [
{
"path": "torch/lib/c10d/NCCLUtils.hpp",
"s_constants": {
"<nccl.h>": "\"c10d/rccl1_compat.h\""
}
}
],
"disabled_modules": [
],
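The new disabled_features.json entry makes hipify substitute the <nccl.h> include in torch/lib/c10d/NCCLUtils.hpp with the bundled rccl1_compat.h shim, instead of mapping it to <rccl.h> as in every other file. The effect is a plain string substitution on that one file; a minimal illustration (the real substitution machinery is more involved):

```python
# Illustrative only: the effect of the s_constants entry on a line of NCCLUtils.hpp.
s_constants = {"<nccl.h>": '"c10d/rccl1_compat.h"'}

line = "#include <nccl.h>"
for old, new in s_constants.items():
    line = line.replace(old, new)

print(line)  # #include "c10d/rccl1_compat.h"
```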
5 changes: 3 additions & 2 deletions tools/amd_build/pyHIPIFY/constants.py
@@ -49,8 +49,9 @@
API_BLAS = 39
API_SPARSE = 40
API_RAND = 41
API_LAST = 42
API_FFT = 43
API_FFT = 42
API_RCCL = 43
API_LAST = 44

HIP_UNSUPPORTED = 43
API_PYTORCH = 1337
39 changes: 39 additions & 0 deletions tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
@@ -238,6 +238,10 @@
("curandStateXORWOW_t", ("hiprandStateXORWOW_t", CONV_TYPE, API_RAND)),
("curandState_t", ("hiprandState_t", CONV_TYPE, API_RAND)),
("curandState", ("hiprandState_t", CONV_TYPE, API_RAND)),
("ncclResult_t", ("rcclResult_t", CONV_TYPE, API_RCCL)),
("ncclComm_t", ("rcclComm_t", CONV_TYPE, API_RCCL)),
("ncclDataType_t", ("rcclDataType_t", CONV_TYPE, API_RCCL)),
("ncclRedOp_t", ("rcclRedOp_t", CONV_TYPE, API_RCCL)),
])

CUDA_INCLUDE_MAP = collections.OrderedDict([
@@ -276,6 +280,7 @@
("cusparse.h", ("hipsparse.h", CONV_INCLUDE, API_RAND)),
("cufft.h", ("hipfft.h", CONV_INCLUDE, API_BLAS)),
("cufftXt.h", ("hipfft.h", CONV_INCLUDE, API_BLAS)),
("<nccl.h>", ("<rccl.h>", CONV_INCLUDE, API_RCCL)), #PyTorch also has a source file named "nccl.h", so we need to use "<"">" to differentiate
])

CUDA_IDENTIFIER_MAP = collections.OrderedDict([
@@ -2172,6 +2177,40 @@
("cufftDestroy", ("hipfftDestroy", CONV_MATH_FUNC, API_FFT)),
("cufftGetVersion", ("hipfftGetVersion", CONV_MATH_FUNC, API_FFT)),
("cufftGetProperty", ("hipfftGetProperty", CONV_MATH_FUNC, API_FFT, HIP_UNSUPPORTED)),
("ncclGetErrorString", ("rcclGetErrorString", CONV_ERROR, API_RCCL)),
("ncclCommInitAll", ("rcclCommInitAll", CONV_SPECIAL_FUNC, API_RCCL)),
("ncclCommInitRank", ("rcclCommInitRank", CONV_SPECIAL_FUNC, API_RCCL)),
("ncclCommDestroy", ("rcclCommDestroy", CONV_SPECIAL_FUNC, API_RCCL)),
("ncclBcast", ("rcclBcast", CONV_SPECIAL_FUNC, API_RCCL)),
("ncclReduce", ("rcclReduce", CONV_SPECIAL_FUNC, API_RCCL)),
("ncclAllReduce", ("rcclAllReduce", CONV_SPECIAL_FUNC, API_RCCL)),
("ncclAllGather", ("rcclAllGather", CONV_SPECIAL_FUNC, API_RCCL)),
("ncclReduceScatter", ("rcclReduceScatter", CONV_SPECIAL_FUNC, API_RCCL, HIP_UNSUPPORTED)),
("ncclSuccess", ("rcclSuccess", CONV_TYPE, API_RCCL)),
("ncclChar", ("rcclChar", CONV_TYPE, API_RCCL)),
("ncclInt8", ("rcclChar", CONV_TYPE, API_RCCL)),
("ncclUint8", ("rcclChar", CONV_TYPE, API_RCCL)), #FIXME: This should be mapped to an unsigned int8 type

[Inline review thread on the ncclUint8 mapping]
Reviewer: for my education: why is it not?
Author: rccl doesn't have one.
Reviewer: Cool - can we get that adjusted?

("ncclInt", ("rcclInt", CONV_TYPE, API_RCCL)),
("ncclInt32", ("rcclInt", CONV_TYPE, API_RCCL)),
("ncclUint32", ("rcclInt", CONV_TYPE, API_RCCL)), #FIXME: This should be mapped to an unsigned int32 type
("ncclInt64", ("rcclInt64", CONV_TYPE, API_RCCL)),
("ncclUint64", ("rcclUint64", CONV_TYPE, API_RCCL)),
("ncclHalf", ("rcclHalf", CONV_TYPE, API_RCCL)),
("ncclFloat16", ("rcclHalf", CONV_TYPE, API_RCCL)),
("ncclFloat", ("rcclFloat", CONV_TYPE, API_RCCL)),
("ncclFloat32", ("rcclFloat", CONV_TYPE, API_RCCL)),
("ncclDouble", ("rcclDouble", CONV_TYPE, API_RCCL)),
("ncclFloat64", ("rcclDouble", CONV_TYPE, API_RCCL)),
("ncclSum", ("rcclSum", CONV_TYPE, API_RCCL)),
("ncclProd", ("rcclProd", CONV_TYPE, API_RCCL)),
("ncclMin", ("rcclMin", CONV_TYPE, API_RCCL)),
("ncclMax", ("rcclMax", CONV_TYPE, API_RCCL)),
("ncclUniqueId", ("rcclUniqueId", CONV_TYPE, API_RCCL)),
("ncclGetUniqueId", ("rcclGetUniqueId", CONV_TYPE, API_RCCL)),
("ncclGroupStart", ("rcclGroupStart", CONV_TYPE, API_RCCL)),
("ncclGroupEnd", ("rcclGroupEnd", CONV_TYPE, API_RCCL)),
("NCCL_UNIQUE_ID_BYTES", ("RCCL_UNIQUE_ID_BYTES", CONV_TYPE, API_RCCL)),
("USE_NCCL", ("USE_RCCL", CONV_DEF, API_RCCL)),
])

CUDA_SPARSE_MAP = collections.OrderedDict([
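The new API_RCCL entries translate NCCL types, enum values, and collective calls to their RCCL 1.x counterparts; as the FIXMEs note, the unsigned integer types collapse onto signed RCCL types because RCCL 1.x has no unsigned variants. A small illustration of applying a few of these entries to an NCCL call site, using plain word-boundary replacement rather than the real hipify engine:

```python
import re

# A handful of entries from CUDA_IDENTIFIER_MAP, reduced to old -> new strings.
rccl_map = {
    "ncclAllReduce": "rcclAllReduce",
    "ncclComm_t": "rcclComm_t",
    "ncclFloat": "rcclFloat",
    "ncclSum": "rcclSum",
    "ncclUint8": "rcclChar",  # collapses to a signed type; see the FIXME in the table
}

src = "ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm, stream);"
for old, new in rccl_map.items():
    src = re.sub(r"\b" + re.escape(old) + r"\b", new, src)

print(src)  # rcclAllReduce(sendbuff, recvbuff, count, rcclFloat, rcclSum, comm, stream);
```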
18 changes: 16 additions & 2 deletions tools/amd_build/pyHIPIFY/hipify_python.py
@@ -869,7 +869,8 @@ def pattern(self):
CAFFE2_TRIE.add(src)
CAFFE2_MAP[src] = dst
RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern())
RE_PYTORCH_PREPROCESSOR = re.compile(r'\b{0}\b'.format(PYTORCH_TRIE.pattern()))
# Use \W instead of \b so that even if the pattern contains non-word characters, the replacement still succeeds
RE_PYTORCH_PREPROCESSOR = re.compile(r'(\W)({0})(\W)'.format(PYTORCH_TRIE.pattern()))

RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')
Expand All @@ -890,7 +891,7 @@ def preprocessor(output_directory, filepath, stats):
# unsupported_calls statistics reporting is broken atm
if is_pytorch_file(filepath):
def pt_repl(m):
return PYTORCH_MAP[m.group(0)]
return m.group(1) + PYTORCH_MAP[m.group(2)] + m.group(3)
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
else:
def c2_repl(m):
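The PYTORCH preprocessor regex is switched from \b anchors to captured \W characters because patterns such as <nccl.h> begin and end with non-word characters, where \b can never match. A quick sketch of the difference, using a simple two-entry alternation in place of the trie that hipify builds:

```python
import re

mapping = {"<nccl.h>": "<rccl.h>", "ncclAllReduce": "rcclAllReduce"}
alternation = "|".join(re.escape(k) for k in mapping)

src = "#include <nccl.h>\nstatus = ncclAllReduce(sendbuff, recvbuff, count, dtype, op, comm, stream);\n"

# \b needs a word/non-word transition, so it never fires around "<" or ">".
assert re.search(r"\b(?:{0})\b".format(alternation), src).group(0) == "ncclAllReduce"
assert re.search(r"\b<nccl\.h>\b", src) is None

# The \W form captures the surrounding characters and puts them back in the replacement.
pattern = re.compile(r"(\W)({0})(\W)".format(alternation))

def repl(m):
    return m.group(1) + mapping[m.group(2)] + m.group(3)

print(pattern.sub(repl, src))
# #include <rccl.h>
# status = rcclAllReduce(sendbuff, recvbuff, count, dtype, op, comm, stream);
```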
@@ -1201,3 +1202,16 @@ def hipify(
all_files,
show_detailed=show_detailed,
show_progress=show_progress)

# copy rccl compat file to c10d
rccl_compat_file = "rccl1_compat.h"
rccl_compat_src_filepath = os.path.join(os.path.dirname(__file__), rccl_compat_file)
if not os.path.exists(rccl_compat_src_filepath):
print("ERROR: File does not exist: " + rccl_compat_src_filepath)
sys.exit(1)
rccl_compat_dst_dir = os.path.join(output_directory, "torch", "lib", "c10d")
if not os.path.exists(rccl_compat_dst_dir):
print("ERROR: Directory does not exist: " + rccl_compat_dst_dir)
sys.exit(1)
rccl_compat_dst_filepath = os.path.join(rccl_compat_dst_dir, rccl_compat_file)
shutil.copy(rccl_compat_src_filepath, rccl_compat_dst_filepath)
48 changes: 48 additions & 0 deletions tools/amd_build/pyHIPIFY/rccl1_compat.h
@@ -0,0 +1,48 @@
#ifndef RCCL1_COMPAT_H
#define RCCL1_COMPAT_H

#include <rccl.h>

#ifndef RCCL_MAJOR // RCCL 1.x
#define RCCL_MAJOR 1
#define RCCL_MINOR 0

#define rcclNumOps rccl_NUM_OPS
#define rcclNumTypes rccl_NUM_TYPES

static rcclResult_t rcclGroupStart() { return rcclSuccess; }
static rcclResult_t rcclGroupEnd() { return rcclSuccess; }

#define CHECKCOUNT(count) if (count > INT_MAX) return rcclInvalidArgument;

/*
static rcclResult_t rcclReduce(const void* sendbuff, void* recvbuff, size_t count, rcclDataType_t datatype,
rcclRedOp_t op, int root, rcclComm_t comm, hipStream_t stream) {
CHECKCOUNT(count);
return rcclReduce(sendbuff, recvbuff, (int)count, datatype, op, root, comm, stream);
}
static rcclResult_t rcclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
rcclDataType_t datatype, rcclRedOp_t op, rcclComm_t comm, hipStream_t stream) {
CHECKCOUNT(count);
return rcclAllReduce(sendbuff, recvbuff, (int)count, datatype, op, comm, stream);
}
static rcclResult_t rcclBcast(void* buff, size_t count, rcclDataType_t datatype, int root,
rcclComm_t comm, hipStream_t stream) {
CHECKCOUNT(count);
return rcclBcast(buff, (int)count, datatype, root, comm, stream);
}
static rcclResult_t rcclReduceScatter(const void* sendbuff, void* recvbuff,
size_t recvcount, rcclDataType_t datatype, rcclRedOp_t op, rcclComm_t comm,
hipStream_t stream) {
CHECKCOUNT(recvcount);
return rcclReduceScatter(sendbuff, recvbuff, (int)recvcount, datatype, op, comm, stream);
}
*/
static rcclResult_t rcclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount,
rcclDataType_t datatype, rcclComm_t comm, hipStream_t stream) {
CHECKCOUNT(sendcount);
return rcclAllGather(sendbuff, (int)sendcount, datatype, recvbuff, comm, stream);
}
#endif

#endif