Skip to content

Commit

Permalink
Merge pull request #333 from NVlabs/extended-vector
Browse files Browse the repository at this point in the history
Extend vector code, log callbacks instead of iostream, misc refactor / bugfixes
  • Loading branch information
Tom94 committed Jul 9, 2023
2 parents 1ee0787 + beec22f commit a2ca883
Show file tree
Hide file tree
Showing 75 changed files with 3,801 additions and 1,821 deletions.
15 changes: 2 additions & 13 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,7 @@ function(TCNN_AUTODETECT_CUDA_ARCHITECTURES OUT_VARIABLE)
"}\n"
)

if (CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language
try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file} RUN_OUTPUT_VARIABLE compute_capabilities)
else()
try_run(
run_result compile_result ${PROJECT_BINARY_DIR} ${file}
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}"
LINK_LIBRARIES ${CUDA_LIBRARIES}
RUN_OUTPUT_VARIABLE compute_capabilities
)
endif()

try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file} RUN_OUTPUT_VARIABLE compute_capabilities)
if (run_result EQUAL 0)
# If the user has multiple GPUs with the same compute capability installed, list that capability only once.
list(REMOVE_DUPLICATES compute_capabilities)
Expand Down Expand Up @@ -257,8 +247,7 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_MINSIZEREL ${CMAKE_BINARY_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_BINARY_DIR})

set(TCNN_SOURCES
src/common.cu
src/common_device.cu
src/common_host.cu
src/cpp_api.cu
src/cutlass_mlp.cu
src/encoding.cu
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ using namespace tcnn;

auto model = create_from_config(n_input_dims, n_output_dims, config);

// Train the model (batch_size must be a multiple of tcnn::batch_size_granularity)
// Train the model (batch_size must be a multiple of tcnn::BATCH_SIZE_GRANULARITY)
GPUMatrix<float> training_batch_inputs(n_input_dims, batch_size);
GPUMatrix<float> training_batch_targets(n_output_dims, batch_size);

Expand Down
37 changes: 26 additions & 11 deletions bindings/torch/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from setuptools import setup
from pkg_resources import parse_version
import subprocess
import shutil
import sys
import torch
from glob import glob
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
Expand Down Expand Up @@ -75,6 +77,8 @@ def find_cl_path():
# won't try to activate a developer command prompt a second time.
os.environ["DISTUTILS_USE_SDK"] = "1"

cpp_standard = 14

# Get CUDA version and make sure the targeted compute capability is compatible
if os.system("nvcc --version") == 0:
nvcc_out = subprocess.check_output(["nvcc", "--version"]).decode()
Expand All @@ -83,6 +87,9 @@ def find_cl_path():
if cuda_version:
cuda_version = parse_version(cuda_version.group(1))
print(f"Detected CUDA version {cuda_version}")
if cuda_version >= parse_version("11.0"):
cpp_standard = 17

supported_compute_capabilities = [
cc for cc in compute_capabilities if cc >= min_supported_compute_capability(cuda_version) and cc <= max_supported_compute_capability(cuda_version)
]
Expand All @@ -96,8 +103,10 @@ def find_cl_path():

min_compute_capability = min(compute_capabilities)

print(f"Targeting C++ standard {cpp_standard}")

base_nvcc_flags = [
"-std=c++14",
f"-std=c++{cpp_standard}",
"--extended-lambda",
"--expt-relaxed-constexpr",
# The following definitions must be undefined
Expand All @@ -108,13 +117,13 @@ def find_cl_path():
]

if os.name == "posix":
base_cflags = ["-std=c++14"]
base_cflags = [f"-std=c++{cpp_standard}"]
base_nvcc_flags += [
"-Xcompiler=-Wno-float-conversion",
"-Xcompiler=-fno-strict-aliasing",
]
elif os.name == "nt":
base_cflags = ["/std:c++14"]
base_cflags = [f"/std:c++{cpp_standard}"]


# Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine.
Expand All @@ -123,15 +132,21 @@ def find_cl_path():
# List of sources.
bindings_dir = os.path.dirname(__file__)
root_dir = os.path.abspath(os.path.join(bindings_dir, "../.."))
base_definitions = []

base_definitions = [
# PyTorch-supplied parameters may be unaligned. TCNN must be made aware of this such that
# it does not optimize for aligned memory accesses.
"-DTCNN_PARAMS_UNALIGNED",
]

base_source_files = [
"tinycudann/bindings.cpp",
"../../dependencies/fmt/src/format.cc",
"../../dependencies/fmt/src/os.cc",
"../../src/cpp_api.cu",
"../../src/common.cu",
"../../src/common_device.cu",
"../../src/common_host.cu",
"../../src/encoding.cu",
"../../src/object.cu",
]

if include_networks:
Expand All @@ -158,11 +173,11 @@ def make_extension(compute_capability):
name=f"tinycudann_bindings._{compute_capability}_C",
sources=source_files,
include_dirs=[
"%s/include" % root_dir,
"%s/dependencies" % root_dir,
"%s/dependencies/cutlass/include" % root_dir,
"%s/dependencies/cutlass/tools/util/include" % root_dir,
"%s/dependencies/fmt/include" % root_dir,
f"{root_dir}/include",
f"{root_dir}/dependencies",
f"{root_dir}/dependencies/cutlass/include",
f"{root_dir}/dependencies/cutlass/tools/util/include",
f"{root_dir}/dependencies/fmt/include",
],
extra_compile_args={"cxx": cflags, "nvcc": nvcc_flags},
libraries=["cuda"],
Expand Down
67 changes: 29 additions & 38 deletions bindings/torch/tinycudann/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include <json/json.hpp>

#include <pybind11_json/pybind11_json.hpp>
#include <pybind11/functional.h>

#include <tiny-cuda-nn/cpp_api.h>

Expand All @@ -53,10 +54,10 @@
#define CHECK_THROW(x) \
do { if (!(x)) throw std::runtime_error(std::string(FILE_LINE " check failed " #x)); } while(0)

c10::ScalarType torch_type(tcnn::cpp::EPrecision precision) {
c10::ScalarType torch_type(tcnn::cpp::Precision precision) {
switch (precision) {
case tcnn::cpp::EPrecision::Fp32: return torch::kFloat32;
case tcnn::cpp::EPrecision::Fp16: return torch::kHalf;
case tcnn::cpp::Precision::Fp32: return torch::kFloat32;
case tcnn::cpp::Precision::Fp16: return torch::kHalf;
default: throw std::runtime_error{"Unknown precision tcnn->torch"};
}
}
Expand Down Expand Up @@ -246,41 +247,19 @@ class Module {
return output;
}

uint32_t n_input_dims() const {
return m_module->n_input_dims();
}
uint32_t n_input_dims() const { return m_module->n_input_dims(); }

uint32_t n_params() const {
return (uint32_t)m_module->n_params();
}
uint32_t n_params() const { return (uint32_t)m_module->n_params(); }
tcnn::cpp::Precision param_precision() const { return m_module->param_precision(); }
c10::ScalarType c10_param_precision() const { return torch_type(param_precision()); }

tcnn::cpp::EPrecision param_precision() const {
return m_module->param_precision();
}
uint32_t n_output_dims() const { return m_module->n_output_dims(); }
tcnn::cpp::Precision output_precision() const { return m_module->output_precision(); }
c10::ScalarType c10_output_precision() const { return torch_type(output_precision()); }

c10::ScalarType c10_param_precision() const {
return torch_type(param_precision());
}
nlohmann::json hyperparams() const { return m_module->hyperparams(); }
std::string name() const { return m_module->name(); }

uint32_t n_output_dims() const {
return m_module->n_output_dims();
}

tcnn::cpp::EPrecision output_precision() const {
return m_module->output_precision();
}

c10::ScalarType c10_output_precision() const {
return torch_type(output_precision());
}

nlohmann::json hyperparams() const {
return m_module->hyperparams();
}

std::string name() const {
return m_module->name();
}

private:
std::unique_ptr<tcnn::cpp::Module> m_module;
Expand All @@ -296,22 +275,34 @@ Module create_network(uint32_t n_input_dims, uint32_t n_output_dims, const nlohm
}
#endif

Module create_encoding(uint32_t n_input_dims, const nlohmann::json& encoding, tcnn::cpp::EPrecision requested_precision) {
Module create_encoding(uint32_t n_input_dims, const nlohmann::json& encoding, tcnn::cpp::Precision requested_precision) {
return Module{tcnn::cpp::create_encoding(n_input_dims, encoding, requested_precision)};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::enum_<tcnn::cpp::EPrecision>(m, "Precision")
.value("Fp32", tcnn::cpp::EPrecision::Fp32)
.value("Fp16", tcnn::cpp::EPrecision::Fp16)
py::enum_<tcnn::cpp::LogSeverity>(m, "LogSeverity")
.value("Info", tcnn::cpp::LogSeverity::Info)
.value("Debug", tcnn::cpp::LogSeverity::Debug)
.value("Warning", tcnn::cpp::LogSeverity::Warning)
.value("Error", tcnn::cpp::LogSeverity::Error)
.value("Success", tcnn::cpp::LogSeverity::Success)
.export_values()
;

py::enum_<tcnn::cpp::Precision>(m, "Precision")
.value("Fp32", tcnn::cpp::Precision::Fp32)
.value("Fp16", tcnn::cpp::Precision::Fp16)
.export_values()
;

m.def("batch_size_granularity", &tcnn::cpp::batch_size_granularity);
m.def("default_loss_scale", &tcnn::cpp::default_loss_scale);
m.def("free_temporary_memory", &tcnn::cpp::free_temporary_memory);
m.def("has_networks", &tcnn::cpp::has_networks);
m.def("preferred_precision", &tcnn::cpp::preferred_precision);

m.def("set_log_callback", &tcnn::cpp::set_log_callback);

// Encapsulates an abstract context of an operation
// (commonly the forward pass) to be passed on to other
// operations (commonly the backward pass).
Expand Down
11 changes: 10 additions & 1 deletion bindings/torch/tinycudann/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import gc
import importlib
import os
import warnings

import torch
Expand Down Expand Up @@ -57,6 +58,14 @@ def _get_system_compute_capability():
if _C is None:
raise EnvironmentError(f"Could not find compatible tinycudann extension for compute capability {system_compute_capability}.")

# Pipe tcnn warnings and errors into Python
# def _log(severity, msg):
# if severity == _C.LogSeverity.Warning:
# warnings.warn(f"tinycudann warning: {msg}")
# elif severity == _C.LogSeverity.Error:
# warnings.warn(f"tinycudann error: {msg}")

# _C.set_log_callback(_log)
def _torch_precision(tcnn_precision):
if tcnn_precision == _C.Precision.Fp16:
return torch.half
Expand Down Expand Up @@ -162,7 +171,7 @@ def __init__(self, seed=1337):
self.params = torch.nn.Parameter(initial_params, requires_grad=True)
self.register_parameter(name="params", param=self.params)

self.loss_scale = 128.0 if self.native_tcnn_module.param_precision() == _C.Precision.Fp16 else 1.0
self.loss_scale = _C.default_loss_scale(self.native_tcnn_module.param_precision())

def forward(self, x):
if not x.is_cuda:
Expand Down
12 changes: 3 additions & 9 deletions dependencies/pcg32/pcg32.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@
#define PCG32_DEFAULT_STREAM 0xda3e39cb94b95bdbULL
#define PCG32_MULT 0x5851f42d4c957f2dULL

#include <cmath>
#include <cassert>


TCNN_NAMESPACE_BEGIN
namespace tcnn {

/// PCG32 Pseudorandom number generator
struct pcg32 {
Expand Down Expand Up @@ -171,8 +167,6 @@ struct pcg32 {

/// Compute the distance between two PCG32 pseudorandom number generators
TCNN_HOST_DEVICE int64_t operator-(const pcg32 &other) const {
assert(inc == other.inc);

uint64_t
cur_mult = PCG32_MULT,
cur_plus = inc,
Expand All @@ -185,7 +179,7 @@ struct pcg32 {
cur_state = cur_state * cur_mult + cur_plus;
distance |= the_bit;
}
assert((state & the_bit) == (cur_state & the_bit));

the_bit <<= 1;
cur_plus = (cur_mult + 1ULL) * cur_plus;
cur_mult *= cur_mult;
Expand All @@ -204,4 +198,4 @@ struct pcg32 {
uint64_t inc; // Controls which RNG sequence (stream) is selected. Must *always* be odd.
};

TCNN_NAMESPACE_END
}
Loading

0 comments on commit a2ca883

Please sign in to comment.