-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
nerf_network.h
505 lines (407 loc) · 20 KB
/
nerf_network.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
*
* NVIDIA CORPORATION and its licensors retain all intellectual property
* and proprietary rights in and to this software, related documentation
* and any modifications thereto. Any use, reproduction, disclosure or
* distribution of this software and related documentation without an express
* license agreement from NVIDIA CORPORATION is strictly prohibited.
*/
/** @file nerf_network.h
* @author Thomas Müller, NVIDIA
* @brief A network that first processes 3D position to density and
* subsequently direction to color.
*/
#pragma once
#include <tiny-cuda-nn/common.h>
#include <tiny-cuda-nn/encoding.h>
#include <tiny-cuda-nn/gpu_matrix.h>
#include <tiny-cuda-nn/gpu_memory.h>
#include <tiny-cuda-nn/multi_stream.h>
#include <tiny-cuda-nn/network.h>
#include <tiny-cuda-nn/network_with_input_encoding.h>
namespace ngp {
// Copies one density value per element into a strided destination buffer:
//   rgbd[i * rgbd_stride] = density[i * density_stride]   for all i < n_elements.
// The caller is expected to pre-offset `rgbd` so that it points at the
// destination slot of element 0.
// Uses only the x dimension of the launch configuration; threads with an
// index >= n_elements exit without touching memory.
template <typename T>
__global__ void extract_density(
	const uint32_t n_elements,
	const uint32_t density_stride,
	const uint32_t rgbd_stride,
	const T* __restrict__ density,
	T* __restrict__ rgbd
) {
	const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
	if (i >= n_elements) return;

	// Widen before multiplying: the 32-bit product i * stride wraps around
	// once batch size times stride reaches 2^32 elements.
	const size_t src = (size_t)i * density_stride;
	const size_t dst = (size_t)i * rgbd_stride;
	rgbd[dst] = density[src];
}
// Copies the first three entries of every element from `rgbd` into `rgb`.
// `n_elements` counts scalar values (3 per element); thread i handles channel
// (i % 3) of element (i / 3):
//   rgb[elem * rgb_stride + dim] = rgbd[elem * output_stride + dim]
// Uses only the x dimension of the launch configuration; threads with an
// index >= n_elements exit without touching memory.
template <typename T>
__global__ void extract_rgb(
	const uint32_t n_elements,
	const uint32_t rgb_stride,
	const uint32_t output_stride,
	const T* __restrict__ rgbd,
	T* __restrict__ rgb
) {
	const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
	if (i >= n_elements) return;

	const uint32_t elem_idx = i / 3;
	const uint32_t dim_idx = i - elem_idx * 3;

	// Widen before multiplying: the 32-bit product elem_idx * stride wraps
	// around for very large batches with wide strides.
	const size_t dst = (size_t)elem_idx * rgb_stride + dim_idx;
	const size_t src = (size_t)elem_idx * output_stride + dim_idx;
	rgb[dst] = rgbd[src];
}
// Accumulates the fourth entry (index 3) of every `rgbd` element onto the
// corresponding strided `density` entry:
//   density[i * density_stride] += rgbd[i * rgbd_stride + 3]   for all i < n_elements.
// Uses only the x dimension of the launch configuration; threads with an
// index >= n_elements exit without touching memory.
template <typename T>
__global__ void add_density_gradient(
	const uint32_t n_elements,
	const uint32_t rgbd_stride,
	const T* __restrict__ rgbd,
	const uint32_t density_stride,
	T* __restrict__ density
) {
	const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
	if (i >= n_elements) return;

	// Widen before multiplying: the 32-bit product i * stride wraps around
	// once batch size times stride reaches 2^32 elements.
	const size_t src = (size_t)i * rgbd_stride + 3;
	const size_t dst = (size_t)i * density_stride;
	density[dst] += rgbd[src];
}
template <typename T>
class NerfNetwork : public Network<float, T> {
public:
using json = nlohmann::json;
// Builds the four sub-modules (position encoding, density network, direction
// encoding, RGB network) from their JSON configs and wires their widths together.
//   n_pos_dims   - number of input dims fed to the position encoding
//   n_dir_dims   - number of direction dims fed to the direction encoding
//   n_extra_dims - extra dims encoded together with the direction
//   dir_offset   - row of the input at which the direction (+extra) dims start
NerfNetwork(
	uint32_t n_pos_dims,
	uint32_t n_dir_dims,
	uint32_t n_extra_dims,
	uint32_t dir_offset,
	const json& pos_encoding,
	const json& dir_encoding,
	const json& density_network,
	const json& rgb_network
) :
	m_n_pos_dims{n_pos_dims},
	m_n_dir_dims{n_dir_dims},
	m_n_extra_dims{n_extra_dims},
	m_dir_offset{dir_offset}
{
	// The position encoding is aligned to 16 when the density network is one of
	// the two named fused MLP types, and to 8 otherwise.
	const bool density_net_is_fused =
		density_network.contains("otype") &&
		(equals_case_insensitive(density_network["otype"], "FullyFusedMLP") || equals_case_insensitive(density_network["otype"], "MegakernelMLP"));
	const uint32_t pos_alignment = density_net_is_fused ? 16u : 8u;
	m_pos_encoding.reset(create_encoding<T>(n_pos_dims, pos_encoding, pos_alignment));

	// Direction and extra dims share one encoding
	uint32_t rgb_alignment = minimum_alignment(rgb_network);
	m_dir_encoding.reset(create_encoding<T>(m_n_dir_dims + m_n_extra_dims, dir_encoding, rgb_alignment));

	// Density network: consumes the encoded position; 16 outputs unless the config says otherwise
	json density_config = density_network;
	density_config["n_input_dims"] = m_pos_encoding->padded_output_width();
	if (!density_network.contains("n_output_dims")) {
		density_config["n_output_dims"] = 16;
	}
	m_density_network.reset(create_network<T>(density_config));

	// RGB network: consumes density output and encoded direction side by side, emits 3 channels
	const uint32_t concatenated_width = m_dir_encoding->padded_output_width() + m_density_network->padded_output_width();
	m_rgb_network_input_width = next_multiple(concatenated_width, rgb_alignment);

	json rgb_config = rgb_network;
	rgb_config["n_input_dims"] = m_rgb_network_input_width;
	rgb_config["n_output_dims"] = 3;
	m_rgb_network.reset(create_network<T>(rgb_config));

	// Aggregate used by the density-only inference path
	m_density_model = std::make_shared<NetworkWithInputEncoding<T>>(m_pos_encoding, m_density_network);
}
// Nothing to release by hand: all sub-modules are held by shared_ptr members.
virtual ~NerfNetwork() = default;
// Inference pass: encodes the position, runs the density network, encodes the
// direction, runs the RGB network into `output`, and finally copies the first
// density network output into slot 3 of each output element.
void inference_mixed_precision_impl(cudaStream_t stream, const GPUMatrixDynamic<float>& input, GPUMatrixDynamic<T>& output, bool use_inference_params = true) override {
	const uint32_t n_samples = input.n();
	const uint32_t density_width = m_density_network->padded_output_width();

	GPUMatrixDynamic<T> encoded_pos{m_pos_encoding->padded_output_width(), n_samples, stream, m_pos_encoding->preferred_output_layout()};
	GPUMatrixDynamic<T> rgb_in{m_rgb_network_input_width, n_samples, stream, m_dir_encoding->preferred_output_layout()};

	// The density network writes straight into the leading rows of the RGB network input
	GPUMatrixDynamic<T> density_out = rgb_in.slice_rows(0, density_width);
	// View onto the caller's output buffer with the RGB network's padded width
	GPUMatrixDynamic<T> rgb_out{output.data(), m_rgb_network->padded_output_width(), n_samples, output.layout()};

	m_pos_encoding->inference_mixed_precision(stream, input.slice_rows(0, m_pos_encoding->input_width()), encoded_pos, use_inference_params);
	m_density_network->inference_mixed_precision(stream, encoded_pos, density_out, use_inference_params);

	// The encoded direction fills the rows right after the density output
	auto encoded_dir = rgb_in.slice_rows(density_width, m_dir_encoding->padded_output_width());
	m_dir_encoding->inference_mixed_precision(stream, input.slice_rows(m_dir_offset, m_dir_encoding->input_width()), encoded_dir, use_inference_params);

	m_rgb_network->inference_mixed_precision(stream, rgb_in, rgb_out, use_inference_params);

	// Must run after the RGB network: it writes slot 3 of the same output buffer.
	const bool output_is_aos = output.layout() == AoS;
	const auto src_stride = density_out.layout() == AoS ? density_out.stride() : 1;
	const auto dst_stride = output_is_aos ? padded_output_width() : 1;
	const auto dst_offset = 3 * (output_is_aos ? 1 : n_samples);
	linear_kernel(extract_density<T>, 0, stream,
		n_samples,
		src_stride,
		dst_stride,
		density_out.data(),
		output.data() + dst_offset
	);
}
// Padded output width of the density network on its own.
uint32_t padded_density_output_width() const {
	const uint32_t density_width = m_density_network->padded_output_width();
	return density_width;
}
// Forward pass that keeps all intermediate buffers and sub-contexts in a
// ForwardContext so that backward_impl can reuse them.
// If `output` is non-null, the RGB network writes into it and the first density
// network output is copied to offset 3 of each element, using
// padded_output_width() as the destination stride.
// NOTE(review): unlike inference_mixed_precision_impl, the density copy here
// does not branch on output->layout() — confirm callers only pass AoS outputs.
std::unique_ptr<Context> forward_impl(cudaStream_t stream, const GPUMatrixDynamic<float>& input, GPUMatrixDynamic<T>* output = nullptr, bool use_inference_params = false, bool prepare_input_gradients = false) override {
	// Temporary buffers are sized for the batch size of the given input
	const uint32_t n_samples = input.n();
	const uint32_t density_width = m_density_network->padded_output_width();

	auto fwd = std::make_unique<ForwardContext>();

	fwd->density_network_input = GPUMatrixDynamic<T>{m_pos_encoding->padded_output_width(), n_samples, stream, m_pos_encoding->preferred_output_layout()};
	fwd->rgb_network_input = GPUMatrixDynamic<T>{m_rgb_network_input_width, n_samples, stream, m_dir_encoding->preferred_output_layout()};

	// Position encoding -> density network input
	fwd->pos_encoding_ctx = m_pos_encoding->forward(
		stream,
		input.slice_rows(0, m_pos_encoding->input_width()),
		&fwd->density_network_input,
		use_inference_params,
		prepare_input_gradients
	);

	// The density network writes into the leading rows of the RGB network input
	fwd->density_network_output = fwd->rgb_network_input.slice_rows(0, density_width);
	fwd->density_network_ctx = m_density_network->forward(
		stream,
		fwd->density_network_input,
		&fwd->density_network_output,
		use_inference_params,
		prepare_input_gradients
	);

	// The encoded direction fills the rows right after the density output
	auto encoded_dir = fwd->rgb_network_input.slice_rows(density_width, m_dir_encoding->padded_output_width());
	fwd->dir_encoding_ctx = m_dir_encoding->forward(
		stream,
		input.slice_rows(m_dir_offset, m_dir_encoding->input_width()),
		&encoded_dir,
		use_inference_params,
		prepare_input_gradients
	);

	if (output) {
		fwd->rgb_network_output = GPUMatrixDynamic<T>{output->data(), m_rgb_network->padded_output_width(), n_samples, output->layout()};
	}

	fwd->rgb_network_ctx = m_rgb_network->forward(
		stream,
		fwd->rgb_network_input,
		output ? &fwd->rgb_network_output : nullptr,
		use_inference_params,
		prepare_input_gradients
	);

	// Must run after the RGB network: it writes slot 3 of the same output buffer.
	if (output) {
		const auto src_stride = m_dir_encoding->preferred_output_layout() == AoS ? fwd->density_network_output.stride() : 1;
		linear_kernel(extract_density<T>, 0, stream,
			n_samples,
			src_stride,
			padded_output_width(),
			fwd->density_network_output.data(),
			output->data()+3
		);
	}

	return fwd;
}
// Backward pass through RGB network, direction encoding, density network and
// position encoding, in that order. `ctx` must be the ForwardContext returned
// by forward_impl (the dynamic_cast throws otherwise).
// Input gradients are only written when `dL_dinput` is non-null.
void backward_impl(
	cudaStream_t stream,
	const Context& ctx,
	const GPUMatrixDynamic<float>& input,
	const GPUMatrixDynamic<T>& output,
	const GPUMatrixDynamic<T>& dL_doutput,
	GPUMatrixDynamic<float>* dL_dinput = nullptr,
	bool use_inference_params = false,
	GradientMode param_gradients_mode = GradientMode::Overwrite
) override {
	const auto& fwd = dynamic_cast<const ForwardContext&>(ctx);

	// Make sure our temporary buffers have the correct size for the given batch size
	const uint32_t n_samples = input.n();
	const uint32_t density_width = m_density_network->padded_output_width();

	// Gradient w.r.t. the RGB network output: zero-filled, then the 3 color
	// channels of dL_doutput are copied in.
	GPUMatrix<T> rgb_grad{m_rgb_network->padded_output_width(), n_samples, stream};
	CUDA_CHECK_THROW(cudaMemsetAsync(rgb_grad.data(), 0, rgb_grad.n_bytes(), stream));
	linear_kernel(extract_rgb<T>, 0, stream,
		n_samples*3, rgb_grad.m(), dL_doutput.m(), dL_doutput.data(), rgb_grad.data()
	);

	// View onto the caller's output buffer with the RGB network's padded width
	const GPUMatrixDynamic<T> rgb_out_view{(T*)output.data(), m_rgb_network->padded_output_width(), n_samples, output.layout()};
	GPUMatrixDynamic<T> rgb_input_grad{m_rgb_network_input_width, n_samples, stream, m_dir_encoding->preferred_output_layout()};
	m_rgb_network->backward(stream, *fwd.rgb_network_ctx, fwd.rgb_network_input, rgb_out_view, rgb_grad, &rgb_input_grad, use_inference_params, param_gradients_mode);

	// Backprop through dir encoding if it is trainable or if we need input gradients
	if (m_dir_encoding->n_params() > 0 || dL_dinput) {
		const auto dir_width = m_dir_encoding->padded_output_width();

		GPUMatrixDynamic<T> dir_encoding_grad = rgb_input_grad.slice_rows(density_width, dir_width);
		GPUMatrixDynamic<float> dir_input_grad;
		if (dL_dinput) {
			dir_input_grad = dL_dinput->slice_rows(m_dir_offset, m_dir_encoding->input_width());
		}

		m_dir_encoding->backward(
			stream,
			*fwd.dir_encoding_ctx,
			input.slice_rows(m_dir_offset, m_dir_encoding->input_width()),
			fwd.rgb_network_input.slice_rows(density_width, dir_width),
			dir_encoding_grad,
			dL_dinput ? &dir_input_grad : nullptr,
			use_inference_params,
			param_gradients_mode
		);
	}

	// The density network's output gradient is the leading slice of the RGB
	// network's input gradient, plus the density channel (index 3) of dL_doutput.
	GPUMatrixDynamic<T> density_out_grad = rgb_input_grad.slice_rows(0, density_width);
	const auto density_grad_stride = density_out_grad.layout() == RM ? 1 : density_out_grad.stride();
	linear_kernel(add_density_gradient<T>, 0, stream,
		n_samples,
		dL_doutput.m(),
		dL_doutput.data(),
		density_grad_stride,
		density_out_grad.data()
	);

	// Only allocated when something downstream of the density network needs it
	GPUMatrixDynamic<T> density_in_grad;
	if (m_pos_encoding->n_params() > 0 || dL_dinput) {
		density_in_grad = GPUMatrixDynamic<T>{m_pos_encoding->padded_output_width(), n_samples, stream, m_pos_encoding->preferred_output_layout()};
	}

	m_density_network->backward(
		stream,
		*fwd.density_network_ctx,
		fwd.density_network_input,
		fwd.density_network_output,
		density_out_grad,
		density_in_grad.data() ? &density_in_grad : nullptr,
		use_inference_params,
		param_gradients_mode
	);

	// Backprop through pos encoding if it is trainable or if we need input gradients
	if (density_in_grad.data()) {
		GPUMatrixDynamic<float> pos_input_grad;
		if (dL_dinput) {
			pos_input_grad = dL_dinput->slice_rows(0, m_pos_encoding->input_width());
		}

		m_pos_encoding->backward(
			stream,
			*fwd.pos_encoding_ctx,
			input.slice_rows(0, m_pos_encoding->input_width()),
			fwd.density_network_input,
			density_in_grad,
			dL_dinput ? &pos_input_grad : nullptr,
			use_inference_params,
			param_gradients_mode
		);
	}
}
// Density-only inference: runs the position rows of `input` through the
// aggregated position-encoding + density-network model and writes to `output`.
// Throws std::runtime_error if `input` is not in column major (CM) layout.
void density(cudaStream_t stream, const GPUMatrixDynamic<float>& input, GPUMatrixDynamic<T>& output, bool use_inference_params = true) {
	if (input.layout() != CM) {
		throw std::runtime_error("NerfNetwork::density input must be in column major format.");
	}

	// No temporary matrix is created here: the previous one was constructed
	// for the full batch and then never read or written, since m_density_model
	// is handed the input slice and `output` directly.
	m_density_model->inference_mixed_precision(stream, input.slice_rows(0, m_pos_encoding->input_width()), output, use_inference_params);
}
// Density-only forward pass: position encoding followed by the density network.
// Returns a ForwardContext holding the encoded input and both sub-contexts, for
// use with density_backward. If `output` is non-null the density network writes
// into it. Throws std::runtime_error if `input` is not in column major (CM) layout.
std::unique_ptr<Context> density_forward(cudaStream_t stream, const GPUMatrixDynamic<float>& input, GPUMatrixDynamic<T>* output = nullptr, bool use_inference_params = false, bool prepare_input_gradients = false) {
	const bool input_is_cm = input.layout() == CM;
	if (!input_is_cm) {
		throw std::runtime_error("NerfNetwork::density_forward input must be in column major format.");
	}

	// Temporary buffers are sized for the batch size of the given input
	const uint32_t n_samples = input.n();

	auto fwd = std::make_unique<ForwardContext>();

	fwd->density_network_input = GPUMatrixDynamic<T>{m_pos_encoding->padded_output_width(), n_samples, stream, m_pos_encoding->preferred_output_layout()};

	fwd->pos_encoding_ctx = m_pos_encoding->forward(
		stream,
		input.slice_rows(0, m_pos_encoding->input_width()),
		&fwd->density_network_input,
		use_inference_params,
		prepare_input_gradients
	);

	// View onto the caller's buffer with the density network's padded width
	if (output) {
		fwd->density_network_output = GPUMatrixDynamic<T>{output->data(), m_density_network->padded_output_width(), n_samples, output->layout()};
	}

	fwd->density_network_ctx = m_density_network->forward(
		stream,
		fwd->density_network_input,
		output ? &fwd->density_network_output : nullptr,
		use_inference_params,
		prepare_input_gradients
	);

	return fwd;
}
// Backward pass matching density_forward: density network first, then the
// position encoding. `ctx` must be a ForwardContext (the dynamic_cast throws
// otherwise). Throws std::runtime_error if `input`, or `dL_dinput` when given,
// is not in column major (CM) layout.
void density_backward(
	cudaStream_t stream,
	const Context& ctx,
	const GPUMatrixDynamic<float>& input,
	const GPUMatrixDynamic<T>& output,
	const GPUMatrixDynamic<T>& dL_doutput,
	GPUMatrixDynamic<float>* dL_dinput = nullptr,
	bool use_inference_params = false,
	GradientMode param_gradients_mode = GradientMode::Overwrite
) {
	const bool input_is_cm = input.layout() == CM;
	const bool input_grad_is_cm = !dL_dinput || dL_dinput->layout() == CM;
	if (!(input_is_cm && input_grad_is_cm)) {
		throw std::runtime_error("NerfNetwork::density_backward input must be in column major format.");
	}

	const auto& fwd = dynamic_cast<const ForwardContext&>(ctx);

	// Temporary buffers are sized for the batch size of the given input
	const uint32_t n_samples = input.n();

	// Only allocated when something downstream of the density network needs it
	GPUMatrixDynamic<T> density_in_grad;
	if (m_pos_encoding->n_params() > 0 || dL_dinput) {
		density_in_grad = GPUMatrixDynamic<T>{m_pos_encoding->padded_output_width(), n_samples, stream, m_pos_encoding->preferred_output_layout()};
	}

	m_density_network->backward(
		stream,
		*fwd.density_network_ctx,
		fwd.density_network_input,
		output,
		dL_doutput,
		density_in_grad.data() ? &density_in_grad : nullptr,
		use_inference_params,
		param_gradients_mode
	);

	// Nothing left to do unless the pos encoding is trainable or input gradients were requested
	if (!density_in_grad.data()) {
		return;
	}

	GPUMatrixDynamic<float> pos_input_grad;
	if (dL_dinput) {
		pos_input_grad = dL_dinput->slice_rows(0, m_pos_encoding->input_width());
	}

	m_pos_encoding->backward(
		stream,
		*fwd.pos_encoding_ctx,
		input.slice_rows(0, m_pos_encoding->input_width()),
		fwd.density_network_input,
		density_in_grad,
		dL_dinput ? &pos_input_grad : nullptr,
		use_inference_params,
		param_gradients_mode
	);
}
// Points every sub-module at its slice of the three parameter buffers.
// The aggregate density model is handed the start of the buffers first; the
// components are then given consecutive slices in the order
// density network, rgb network, pos encoding, dir encoding.
void set_params_impl(T* params, T* inference_params, T* gradients) override {
	m_density_model->set_params(params, inference_params, gradients);

	size_t cursor = 0;
	// Assigns the slice starting at `cursor` and advances past it
	auto assign_slice = [&](auto& component) {
		component->set_params(params + cursor, inference_params + cursor, gradients + cursor);
		cursor += component->n_params();
	};

	assign_slice(m_density_network);
	assign_slice(m_rgb_network);
	assign_slice(m_pos_encoding);
	assign_slice(m_dir_encoding);
}
// Initializes the full-precision parameters of every sub-module, walking the
// buffer in the order density network, rgb network, pos encoding, dir encoding.
void initialize_params(pcg32& rnd, float* params_full_precision, float scale = 1) override {
	float* cursor = params_full_precision;
	// Initializes one component at `cursor` and advances past its parameters
	auto init_component = [&](auto& component) {
		component->initialize_params(rnd, cursor, scale);
		cursor += component->n_params();
	};

	init_component(m_density_network);
	init_component(m_rgb_network);
	init_component(m_pos_encoding);
	init_component(m_dir_encoding);
}
// Total parameter count over both encodings and both networks.
size_t n_params() const override {
	size_t total = m_pos_encoding->n_params();
	total += m_density_network->n_params();
	total += m_dir_encoding->n_params();
	total += m_rgb_network->n_params();
	return total;
}
// Padded width of the combined output: the RGB network's padded width, but never less than 4.
uint32_t padded_output_width() const override {
	const uint32_t rgb_width = m_rgb_network->padded_output_width();
	const uint32_t min_width = 4;
	return rgb_width < min_width ? min_width : rgb_width;
}
// Number of input rows: everything up to the direction offset, plus direction and extra dims.
uint32_t input_width() const override {
	const uint32_t dir_and_extra_dims = m_n_dir_dims + m_n_extra_dims;
	return m_dir_offset + dir_and_extra_dims;
}
// Unpadded output width: three color channels plus density.
uint32_t output_width() const override {
	constexpr uint32_t n_rgbd_channels = 4;
	return n_rgbd_channels;
}
// Number of extra dims that are encoded together with the direction.
uint32_t n_extra_dims() const {
	return this->m_n_extra_dims;
}
// The input goes through encodings first, so it needs no particular alignment.
uint32_t required_input_alignment() const override {
	constexpr uint32_t no_alignment = 1;
	return no_alignment;
}
// Layer sizes of the density network followed by those of the RGB network.
std::vector<std::pair<uint32_t, uint32_t>> layer_sizes() const override {
	auto all_layers = m_density_network->layer_sizes();
	for (const auto& rgb_layer : m_rgb_network->layer_sizes()) {
		all_layers.push_back(rgb_layer);
	}
	return all_layers;
}
// Width of the activation at index `layer`, using the combined numbering:
//   0                      -> position encoding output
//   1 .. n_density         -> density network activations
//   n_density + 1          -> RGB network input
//   n_density + 2 onwards  -> RGB network activations
uint32_t width(uint32_t layer) const override {
	const uint32_t n_density = m_density_network->num_forward_activations();

	if (layer == 0) {
		return m_pos_encoding->padded_output_width();
	}
	if (layer < n_density + 1) {
		return m_density_network->width(layer - 1);
	}
	if (layer == n_density + 1) {
		return m_rgb_network_input_width;
	}
	return m_rgb_network->width(layer - 2 - n_density);
}
// Activations of both networks, plus two: the density network input and the RGB network input.
uint32_t num_forward_activations() const override {
	const uint32_t density_count = m_density_network->num_forward_activations();
	const uint32_t rgb_count = m_rgb_network->num_forward_activations();
	return density_count + rgb_count + 2;
}
// Pointer and layout of the activation at index `layer`, using the same
// combined numbering as width(). `ctx` must be a ForwardContext (the
// dynamic_cast throws otherwise).
std::pair<const T*, MatrixLayout> forward_activations(const Context& ctx, uint32_t layer) const override {
	const auto& fwd = dynamic_cast<const ForwardContext&>(ctx);
	const uint32_t n_density = m_density_network->num_forward_activations();

	if (layer == 0) {
		return {fwd.density_network_input.data(), m_pos_encoding->preferred_output_layout()};
	}
	if (layer < n_density + 1) {
		return m_density_network->forward_activations(*fwd.density_network_ctx, layer - 1);
	}
	if (layer == n_density + 1) {
		return {fwd.rgb_network_input.data(), m_dir_encoding->preferred_output_layout()};
	}
	return m_rgb_network->forward_activations(*fwd.rgb_network_ctx, layer - 2 - n_density);
}
// Read-only access to the position encoding.
const std::shared_ptr<Encoding<T>>& pos_encoding() const {
	return this->m_pos_encoding;
}
// Read-only access to the direction (+extra dims) encoding.
const std::shared_ptr<Encoding<T>>& dir_encoding() const {
	return this->m_dir_encoding;
}
// Read-only access to the density network.
const std::shared_ptr<Network<T>>& density_network() const {
	return this->m_density_network;
}
// Read-only access to the RGB network.
const std::shared_ptr<Network<T>>& rgb_network() const {
	return this->m_rgb_network;
}
// Hyperparameters of this network and its four sub-modules as JSON.
// The density network's "n_output_dims" is overwritten with its padded output width.
json hyperparams() const override {
	json density_hyperparams = m_density_network->hyperparams();
	density_hyperparams["n_output_dims"] = m_density_network->padded_output_width();

	json result;
	result["otype"] = "NerfNetwork";
	result["pos_encoding"] = m_pos_encoding->hyperparams();
	result["dir_encoding"] = m_dir_encoding->hyperparams();
	result["density_network"] = density_hyperparams;
	result["rgb_network"] = m_rgb_network->hyperparams();
	return result;
}
private:
std::shared_ptr<Network<T>> m_density_network;
std::shared_ptr<Network<T>> m_rgb_network;
std::shared_ptr<Encoding<T>> m_pos_encoding;
std::shared_ptr<Encoding<T>> m_dir_encoding;
// Aggregates m_pos_encoding and m_density_network
std::shared_ptr<NetworkWithInputEncoding<T>> m_density_model;
uint32_t m_rgb_network_input_width;
uint32_t m_n_pos_dims;
uint32_t m_n_dir_dims;
uint32_t m_n_extra_dims; // extra dimensions are assumed to be part of a compound encoding with dir_dims
uint32_t m_dir_offset;
// Storage of forward pass data.
// Filled by forward_impl / density_forward and read back by backward_impl,
// density_backward and forward_activations.
struct ForwardContext : public Context {
// Output of the position encoding, fed to the density network
GPUMatrixDynamic<T> density_network_input;
// Density network output: a slice of rgb_network_input in forward_impl,
// a view onto the caller's output buffer in density_forward
GPUMatrixDynamic<T> density_network_output;
// Density network output and encoded direction, one after the other
GPUMatrixDynamic<T> rgb_network_input;
// View onto the caller's output buffer; only set when an output was passed
GPUMatrix<T> rgb_network_output;
// Contexts returned by the forward calls of the respective sub-modules
std::unique_ptr<Context> pos_encoding_ctx;
std::unique_ptr<Context> dir_encoding_ctx;
std::unique_ptr<Context> density_network_ctx;
std::unique_ptr<Context> rgb_network_ctx;
};
};
}