Skip to content

Commit

Permalink
libtorch: work on some cuda refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnorBaker committed Aug 22, 2023
1 parent e24fd3e commit 225a72a
Showing 1 changed file with 48 additions and 25 deletions.
73 changes: 48 additions & 25 deletions pkgs/development/libraries/science/math/libtorch/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
fetchFromGitHub,
fetchpatch,
pkgs,
symlinkJoin,
# nativeBuildInputs
asmjit,
blas,
Expand All @@ -23,6 +22,7 @@
mpi,
ninja,
numactl,
onnx,
protobuf,
psimd,
pthreadpool,
Expand All @@ -43,7 +43,7 @@
useXnnpack ? true,
useZstd ? true,
}: let
inherit (lib) lists;
inherit (lib) lists strings;
setBool = bool:
if bool
then "ON"
Expand Down Expand Up @@ -82,28 +82,6 @@
};
});

cuda-redist = symlinkJoin {
name = "cuda-redist";
paths = with cudaPackages;
[
autoAddOpenGLRunpathHook
cuda_cccl # <thrust> and CUB
cuda_cudart
cuda_cupti # Needed by Kineto for GPU profiling
cuda_nvcc
cuda_nvml_dev
cuda_nvrtc
cuda_nvtx
libcublas
libcufft
libcurand
libcusolver
libcusparse
nccl.dev
]
++ lists.optionals useCudnn [cudnn];
};

mkDerivation =
if useCuda
then cudaPackages.backendStdenv.mkDerivation
Expand Down Expand Up @@ -171,6 +149,7 @@ in
rm -rf FXdiv*
rm -rf gloo*
rm -rf ideep/mkl-dnn*
rm -rf onnx*
rm -rf protobuf*
rm -rf psimd*
rm -rf pthreadpool*
Expand Down Expand Up @@ -235,6 +214,7 @@ in
fxdiv
gflags
glog
onnx
protobuf
psimd
pthreadpool
Expand All @@ -248,7 +228,14 @@ in
zlib
]
# Optional dependencies
++ lists.optionals useCuda [cuda-redist]
++ lists.optionals useCuda (
# TODO(@connorbaker): Is this correct that we need both cudart and nvcc as native dependencies?
with cudaPackages; [
autoAddOpenGLRunpathHook
cuda_cudart # cuda_runtime.h
cuda_nvcc # crt/host_config.h
]
)
++ lists.optionals useGloo [gloo]
++ lists.optionals useMagma [magma]
++ lists.optionals useMkldnn [oneDNN.dev] # oneDNN is the new name for MKL-DNN
Expand All @@ -257,13 +244,48 @@ in
++ lists.optionals useXnnpack [xnnpack]
++ lists.optionals useZstd [zstd.dev];

# TODO(@connorbaker): Currently CUDA build fails with:
# CMake Error at cmake/public/cuda.cmake:65 (message):
# Found two conflicting CUDA installs:
#
# V11.8.89 in
# '/nix/store/rsjxr5b5zifa0wbpziwqfzg7lncfz0f0-cuda_cudart-11.8.89/include'
# and
#
# V11.8.89 in
# '/nix/store/rsjxr5b5zifa0wbpziwqfzg7lncfz0f0-cuda_cudart-11.8.89/include;/nix/store/nljxvgbp6fy0q7cbrp5l5igv57p5fa3v-cuda_nvcc-11.8.89/include;/nix/store/mfk63jcw2r77asgai82rzbzbph10dhh8-cuda_cccl-11.8.89/include;/nix/store/0xhbghrnf7x289m78c8ha2dm6n83wfbg-cuda_cupti-11.8.87/include;/nix/store/4x7gb192a6pskj2skwn9s3m0vnn73bff-cuda_nvml_dev-11.8.86/include;/nix/store/00p0i6kqw6qjbrc4fddqfnv07zcg7gi1-cuda_nvrtc-11.8.89/include;/nix/store/953p97p0inb7wdj50qcz47dy3lh58vhq-cuda_nvtx-11.8.86/include;/nix/store/qsm8bjydfnapr77wzlyzyzcsnkc0yrh2-libcublas-11.11.3.6/include;/nix/store/fszipvg6jw9dsj2lz1izwy7363mwh4fj-libcufft-10.9.0.58/include;/nix/store/8r9kj0rh0kk9iqi32kkm1bdxqb8jipbr-libcurand-10.3.0.86/include;/nix/store/f0d08h7g4apgngbyrgqvpjxmlp3azf0m-libcusolver-11.4.1.48/include;/nix/store/141gw8r2ypg27186mzg81rhndl402l80-libcusparse-11.7.5.86/include;/nix/store/z5ppzlnw5wzy5bbvhm76kfmjmirpkqhb-cuda_profiler_api-11.8.86/include'
buildInputs = lists.optionals useCuda (with cudaPackages;
[
(lib.getDev nccl)
cuda_cccl # <thrust/*>
cuda_cupti
cuda_nvml_dev # <nvml.h>
cuda_nvrtc
cuda_nvtx # -llibNVToolsExt
libcublas
libcufft
libcurand
libcusolver
libcusparse
nccl
]
++ lists.optionals useCudnn [cudnn]
++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
cuda_nvprof # <cuda_profiler_api.h>
]
++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [
cuda_profiler_api # <cuda_profiler_api.h>
]);

cmakeFlags =
# Core configuration options
[
"-DATEN_NO_TEST:BOOL=ON"
"-DBUILD_PYTHON:BOOL=OFF"
"-DBUILD_SHARED_LIBS:BOOL=ON"
"-DCMAKE_BUILD_TYPE:STRING=Release"
"-DCMAKE_C_STANDARD:STRING=17"
"-DCMAKE_CXX_STANDARD:STRING=17"
"-DUSE_PRECOMPILED_HEADERS:BOOL=ON"
]
# Core dependencies
Expand All @@ -279,6 +301,7 @@ in
"-DUSE_SYSTEM_FMT:BOOL=ON"
"-DUSE_SYSTEM_FP16:BOOL=ON"
"-DUSE_SYSTEM_FXDIV:BOOL=ON"
"-DUSE_SYSTEM_ONNX:BOOL=ON"
"-DUSE_SYSTEM_PSIMD:BOOL=ON"
"-DUSE_SYSTEM_PTHREADPOOL:BOOL=ON"
"-DUSE_SYSTEM_PYBIND11:BOOL=ON"
Expand Down

0 comments on commit 225a72a

Please sign in to comment.