Commit

Merge pull request #43 from NVlabs/pytorch-improvements

PyTorch improvements

Tom94 committed Feb 17, 2022
2 parents 0a9f5c2 + 8170062 commit e7826c3
Showing 40 changed files with 450 additions and 90 deletions.
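For context, a minimal usage sketch of the Python-facing API this PR touches. This is an illustration, not part of the diff: it assumes the extension is built and importable as tinycudann, and the frequency-encoding config keys mirror the hyperparams() output added further down.

import torch
import tinycudann as tcnn

# Encoding-only module; config keys follow the Frequency hyperparams shown in this commit.
encoding = tcnn.Encoding(
    n_input_dims=3,
    encoding_config={"otype": "Frequency", "n_frequencies": 12},
)

x = torch.rand(128, 3, device="cuda")  # inputs must live on the GPU
y = encoding(x)                        # shape: (128, encoding.n_output_dims)

# New in this commit: a readable repr that includes the module's hyperparameters.
print(encoding)
print(encoding.native_tcnn_module.hyperparams())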
63 changes: 31 additions & 32 deletions bindings/torch/tinycudann/bindings.cpp
@@ -66,21 +66,6 @@ void* void_data_ptr(torch::Tensor& tensor) {
class Module {
public:
Module(tcnn::cpp::Module* module) : m_module{module} {}
virtual ~Module() {}

#if !defined(TCNN_NO_NETWORKS)
// Helper constructor to create a NetworkWithInputEncoding module
Module(uint32_t n_input_dims, uint32_t n_output_dims, const nlohmann::json& encoding, const nlohmann::json& network)
: Module{tcnn::cpp::create_network_with_input_encoding(n_input_dims, n_output_dims, encoding, network)} {}

// Helper constructor to create a Network module
Module(uint32_t n_input_dims, uint32_t n_output_dims, const nlohmann::json& network)
: Module{tcnn::cpp::create_network(n_input_dims, n_output_dims, network)} {}
#endif

// Helper constructor to create a Encoding module
Module(uint32_t n_input_dims, const nlohmann::json& encoding, tcnn::cpp::EPrecision requested_precision)
: Module{tcnn::cpp::create_encoding(n_input_dims, encoding, requested_precision)} {}

std::tuple<tcnn::cpp::Context, torch::Tensor> fwd(torch::Tensor input, torch::Tensor params) {
// Check for correct types
@@ -219,10 +204,32 @@ class Module {
return torch_type(output_precision());
}

nlohmann::json hyperparams() const {
return m_module->hyperparams();
}

std::string name() const {
return m_module->name();
}

private:
std::unique_ptr<tcnn::cpp::Module> m_module;
};

#if !defined(TCNN_NO_NETWORKS)
Module create_network_with_input_encoding(uint32_t n_input_dims, uint32_t n_output_dims, const nlohmann::json& encoding, const nlohmann::json& network) {
return Module{tcnn::cpp::create_network_with_input_encoding(n_input_dims, n_output_dims, encoding, network)};
}

Module create_network(uint32_t n_input_dims, uint32_t n_output_dims, const nlohmann::json& network) {
return Module{tcnn::cpp::create_network(n_input_dims, n_output_dims, network)};
}
#endif

Module create_encoding(uint32_t n_input_dims, const nlohmann::json& encoding, tcnn::cpp::EPrecision requested_precision) {
return Module{tcnn::cpp::create_encoding(n_input_dims, encoding, requested_precision)};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::enum_<tcnn::cpp::EPrecision>(m, "Precision")
.value("Fp32", tcnn::cpp::EPrecision::Fp32)
@@ -243,23 +250,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// under the hood. The bindings don't need to concern
// themselves with these implementation details, though.
py::class_<Module>(m, "Module")
#if !defined(TCNN_NO_NETWORKS)
.def(
py::init<uint32_t, uint32_t, const nlohmann::json&, const nlohmann::json&>(),
"Constructor for Encoding+Network combo",
py::arg("n_input_dims"), py::arg("n_output_dims"), py::arg("encoding_config"), py::arg("network_config")
)
.def(
py::init<uint32_t, uint32_t, const nlohmann::json&>(),
"Constructor for just the Network",
py::arg("n_input_dims"), py::arg("n_output_dims"), py::arg("network_config")
)
#endif
.def(
py::init<uint32_t, const nlohmann::json&, tcnn::cpp::EPrecision>(),
"Constructor for just the Encoding",
py::arg("n_input_dims"), py::arg("encoding_config"), py::arg("precision")=tcnn::cpp::preferred_precision()
)
.def("fwd", &Module::fwd)
.def("bwd", &Module::bwd)
.def("initial_params", &Module::initial_params)
@@ -268,5 +258,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("param_precision", &Module::param_precision)
.def("n_output_dims", &Module::n_output_dims)
.def("output_precision", &Module::output_precision)
.def("hyperparams", &Module::hyperparams)
.def("name", &Module::name)
;

#if !defined(TCNN_NO_NETWORKS)
m.def("create_network_with_input_encoding", &create_network_with_input_encoding);
m.def("create_network", &create_network);
#endif

m.def("create_encoding", &create_encoding);
}
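The net effect of this file's change is that the overloaded Module constructors are replaced by module-level factory functions. A rough sketch of calling the new low-level API directly (illustrative only; it assumes the compiled extension is importable, and real code goes through the tinycudann.modules wrappers shown in the next file):

import torch
from tinycudann_bindings import _C

# The constructors previously bound on Module are gone; native modules are now created via factories.
native = _C.create_encoding(3, {"otype": "Frequency", "n_frequencies": 12}, _C.Precision.Fp32)

params = native.initial_params(1337)    # seeded parameter initialization
x = torch.rand(256, 3, device="cuda")   # the low-level fwd expects the batch already padded to a multiple of 256
ctx, y = native.fwd(x, params)          # returns the native context and the encoded output

print(native.name(), native.hyperparams())  # also newly exposed in this commit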
44 changes: 35 additions & 9 deletions bindings/torch/tinycudann/modules.py
@@ -7,32 +7,55 @@
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import torch
from torch.autograd.function import once_differentiable
from tinycudann_bindings import _C

class _module_func(torch.autograd.Function):
def _torch_precision(tcnn_precision):
if tcnn_precision == _C.Precision.Fp16:
return torch.half
elif tcnn_precision == _C.Precision.Fp32:
return torch.float
else:
raise ValueError(f"Unknown precision {tcnn_precision}")

class _module_function(torch.autograd.Function):
@staticmethod
def forward(ctx, native_tcnn_module, input, params, loss_scale):
# If no output gradient is provided, no need to
# automatically materialize it as torch.zeros.
ctx.set_materialize_grads(False)

native_ctx, output = native_tcnn_module.fwd(input, params)
ctx.save_for_backward(input, params, output)
ctx.native_tcnn_module = native_tcnn_module
ctx.native_ctx = native_ctx
ctx.loss_scale = loss_scale

return output

@staticmethod
@once_differentiable
def backward(ctx, doutput):
input, weights, output = ctx.saved_tensors
if doutput is None:
return None, None, None, None

input, params, output = ctx.saved_tensors
with torch.no_grad():
scaled_grad = doutput * ctx.loss_scale
input_grad, weight_grad = ctx.native_tcnn_module.bwd(ctx.native_ctx, input, weights, output, scaled_grad)
return None, None if input_grad is None else (input_grad / ctx.loss_scale), None if weight_grad is None else (weight_grad / ctx.loss_scale), None
input_grad, weight_grad = ctx.native_tcnn_module.bwd(ctx.native_ctx, input, params, output, scaled_grad)
input_grad = None if input_grad is None else (input_grad / ctx.loss_scale)
weight_grad = None if weight_grad is None else (weight_grad / ctx.loss_scale)

return None, input_grad, weight_grad, None

class Module(torch.nn.Module):
def __init__(self, seed=1337):
super(Module, self).__init__()

self.native_tcnn_module = self._native_tcnn_module()
self.dtype = _torch_precision(self.native_tcnn_module.param_precision())

self.seed = seed
initial_params = self.native_tcnn_module.initial_params(seed)
self.params = torch.nn.Parameter(initial_params, requires_grad=True)
self.register_parameter(name="params", param=self.params)
@@ -45,10 +68,10 @@ def forward(self, x):
padded_batch_size = (batch_size + 255) // 256 * 256

x_padded = x if batch_size == padded_batch_size else torch.nn.functional.pad(x, [0, 0, 0, padded_batch_size - batch_size])
output = _module_func.apply(
output = _module_function.apply(
self.native_tcnn_module,
x_padded.to(torch.float).contiguous(),
self.params.to(torch.half if self.native_tcnn_module.param_precision() == _C.Precision.Fp16 else torch.float32).contiguous(),
self.params.to(_torch_precision(self.native_tcnn_module.param_precision())).contiguous(),
self.loss_scale
)
return output[:batch_size, :self.n_output_dims]
@@ -65,6 +88,9 @@ def __setstate__(self, state):
# Reconstruct native entries
self.native_tcnn_module = self._native_tcnn_module()

def extra_repr(self):
return f"n_input_dims={self.n_input_dims}, n_output_dims={self.n_output_dims}, seed={self.seed}, dtype={self.dtype}, hyperparams={self.native_tcnn_module.hyperparams()}"

class NetworkWithInputEncoding(Module):
"""
Input encoding, followed by a neural network.
@@ -102,7 +128,7 @@ def __init__(self, n_input_dims, n_output_dims, encoding_config, network_config,
super(NetworkWithInputEncoding, self).__init__(seed=seed)

def _native_tcnn_module(self):
return _C.Module(self.n_input_dims, self.n_output_dims, self.encoding_config, self.network_config)
return _C.create_network_with_input_encoding(self.n_input_dims, self.n_output_dims, self.encoding_config, self.network_config)

class Network(Module):
"""
Expand Down Expand Up @@ -134,7 +160,7 @@ def __init__(self, n_input_dims, n_output_dims, network_config, seed=1337):
super(Network, self).__init__(seed=seed)

def _native_tcnn_module(self):
return _C.Module(self.n_input_dims, self.n_output_dims, self.network_config)
return _C.create_network(self.n_input_dims, self.n_output_dims, self.network_config)

class Encoding(Module):
"""
Expand Down Expand Up @@ -179,4 +205,4 @@ def __init__(self, n_input_dims, encoding_config, seed=1337, dtype=None):
self.n_output_dims = self.native_tcnn_module.n_output_dims()

def _native_tcnn_module(self):
return _C.Module(self.n_input_dims, self.encoding_config, self.precision)
return _C.create_encoding(self.n_input_dims, self.encoding_config, self.precision)
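To see the rewritten autograd function exercised end to end, here is a rough training-step sketch. It is illustrative only: the network_config keys are assumptions not taken from this diff.

import torch
import tinycudann as tcnn

model = tcnn.Network(
    n_input_dims=3,
    n_output_dims=1,
    network_config={            # illustrative config; keys are assumptions not shown in this diff
        "otype": "FullyFusedMLP",
        "activation": "ReLU",
        "output_activation": "None",
        "n_neurons": 64,
        "n_hidden_layers": 2,
    },
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.rand(1024, 3, device="cuda")
target = torch.rand(1024, 1, device="cuda")

prediction = model(x)             # forward pads the batch to a multiple of 256 internally
loss = torch.nn.functional.mse_loss(prediction.float(), target)
optimizer.zero_grad()
loss.backward()                   # _module_function.backward rescales gradients by 1/loss_scale before returning them
optimizer.step()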
4 changes: 4 additions & 0 deletions include/tiny-cuda-nn/cpp_api.h
@@ -33,6 +33,7 @@
#include <json/json.hpp>

#include <memory>
#include <string>

namespace tcnn {
struct Context {
@@ -79,6 +80,9 @@ class Module {
return m_output_precision;
}

virtual json hyperparams() const = 0;
virtual std::string name() const = 0;

private:
EPrecision m_param_precision;
EPrecision m_output_precision;
6 changes: 4 additions & 2 deletions include/tiny-cuda-nn/encoding.h
@@ -37,13 +37,15 @@

TCNN_NAMESPACE_BEGIN

enum InterpolationType {
enum class InterpolationType {
Nearest,
Linear,
Smoothstep,
};

InterpolationType string_to_interpolation_type(std::string interpolation_type);
InterpolationType string_to_interpolation_type(const std::string& interpolation_type);

std::string to_string(InterpolationType interpolation_type);

template <typename T>
class Encoding : public ParametricObject<T> {
12 changes: 12 additions & 0 deletions include/tiny-cuda-nn/encodings/composite.h
@@ -214,6 +214,18 @@ class CompositeEncoding : public Encoding<T> {
return total;
}

json hyperparams() const override {
json::array_t nested;
for (auto& n : m_nested) {
nested.emplace_back(n->hyperparams());
}

return {
{"otype", "Composite"},
{"nested", nested}
};
}

private:
std::vector<std::unique_ptr<Encoding<T>>> m_nested;
uint32_t m_n_dims_to_encode;
7 changes: 7 additions & 0 deletions include/tiny-cuda-nn/encodings/frequency.h
@@ -199,6 +199,13 @@ class FrequencyEncoding : public Encoding<T> {
return 1;
}

json hyperparams() const override {
return {
{"otype", "Frequency"},
{"n_frequencies", m_n_frequencies},
};
}

private:
uint32_t m_n_frequencies;
uint32_t m_n_dims_to_encode;
39 changes: 36 additions & 3 deletions include/tiny-cuda-nn/encodings/grid.h
@@ -47,9 +47,13 @@

TCNN_NAMESPACE_BEGIN

enum GridType { Hash, Dense, Tiled };
enum class GridType {
Hash,
Dense,
Tiled,
};

inline GridType string_to_grid_type(std::string grid_type) {
inline GridType string_to_grid_type(const std::string& grid_type) {
if (equals_case_insensitive(grid_type, "Hash")) {
return GridType::Hash;
} else if (equals_case_insensitive(grid_type, "Dense")) {
@@ -61,6 +65,15 @@ inline GridType string_to_grid_type(const std::string& grid_type) {
throw std::runtime_error{std::string{"Invalid grid type: "} + grid_type};
}

inline std::string to_string(GridType grid_type) {
switch (grid_type) {
case GridType::Hash: return "Hash";
case GridType::Dense: return "Dense";
case GridType::Tiled: return "Tiled";
default: throw std::runtime_error{std::string{"Invalid grid type"}};
}
}

template <uint32_t N_DIMS>
__device__ uint32_t fast_hash(const uint32_t pos_grid[N_DIMS]) {
static_assert(N_DIMS <= 7, "fast_hash can only hash up to 7 dimensions.");
@@ -534,6 +547,7 @@ class GridEncodingTemplated : public GridEncoding<T> {
GridType grid_type
) :
m_n_features{n_features},
m_log2_hashmap_size{log2_hashmap_size},
m_base_resolution{base_resolution},
m_per_level_scale{per_level_scale},
m_stochastic_interpolation{stochastic_interpolation},
@@ -565,7 +579,7 @@ class GridEncodingTemplated : public GridEncoding<T> {
// If hash table needs fewer params than dense, then use fewer and rely on the hash.
params_in_level = std::min(params_in_level, (1u << log2_hashmap_size));
} else {
throw std::runtime_error{std::string{"GridEncoding: invalid grid type "} + std::to_string(grid_type)};
throw std::runtime_error{std::string{"GridEncoding: invalid grid type "} + to_string(grid_type)};
}

offsets_table_host[i] = offset;
@@ -834,11 +848,30 @@ class GridEncodingTemplated : public GridEncoding<T> {
return N_FEATURES_PER_LEVEL;
}

json hyperparams() const override {
json result = {
{"otype", "Grid"},
{"type", to_string(m_grid_type)},
{"n_levels", m_n_levels},
{"n_features_per_level", N_FEATURES_PER_LEVEL},
{"base_resolution", m_base_resolution},
{"per_level_scale", m_per_level_scale},
{"interpolation", to_string(m_interpolation_type)},
};

if (m_grid_type == GridType::Hash) {
result["log2_hashmap_size"] = m_log2_hashmap_size;
}

return result;
}

private:
uint32_t m_n_features;
uint32_t m_n_levels;
uint32_t m_n_params;
GPUMemory<uint32_t> m_hashmap_offsets_table;
uint32_t m_log2_hashmap_size;
uint32_t m_base_resolution;

uint32_t m_n_dims_to_pass_through;
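As an aside, the new hyperparams() override above reports log2_hashmap_size only for hash grids. A quick check from Python (illustrative; the HashGrid config keys are assumptions insofar as they are not part of this diff, though they mirror the hyperparams() output):

import tinycudann as tcnn

grid = tcnn.Encoding(
    n_input_dims=3,
    encoding_config={            # assumed HashGrid config; keys mirror the hyperparams() output above
        "otype": "HashGrid",
        "n_levels": 16,
        "n_features_per_level": 2,
        "log2_hashmap_size": 19,
        "base_resolution": 16,
        "per_level_scale": 2.0,
    },
)

# Expected to include {"otype": "Grid", "type": "Hash", ..., "log2_hashmap_size": 19};
# a Dense or Tiled grid would omit the log2_hashmap_size entry.
print(grid.native_tcnn_module.hyperparams())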
8 changes: 8 additions & 0 deletions include/tiny-cuda-nn/encodings/identity.h
@@ -173,6 +173,14 @@ class IdentityEncoding : public Encoding<T> {
return 1;
}

json hyperparams() const override {
return {
{"otype", "Identity"},
{"scale", m_scale},
{"offset", m_offset},
};
}

private:
uint32_t m_n_dims_to_encode;

7 changes: 7 additions & 0 deletions include/tiny-cuda-nn/encodings/oneblob.h
@@ -234,6 +234,13 @@ class OneBlobEncoding : public Encoding<T> {
return m_n_bins;
}

json hyperparams() const override {
return {
{"otype", "OneBlob"},
{"n_bins", m_n_bins},
};
}

private:
uint32_t m_n_bins;
uint32_t m_n_dims_to_encode;