In [1]:
import pytest
import torch

from nerfacc import OccupancyGrid, ray_marching, unpack_info

import taichi as ti
ti.init(arch=ti.cuda)

[Taichi] version 1.2.2, llvm 10.0.0, commit 608e4b57, linux, python 3.9.13
[I 01/04/23 14:57:03.752 2940179] [shell.py:_shell_pop_print@33] Graphical python shell detected, using wrapped sys.stdout
[Taichi] Starting on arch=cuda


In [2]:
# Shared settings for the ray-marching timing cell below:
# `device` is where the random rays and the occupancy grid are placed,
# `batch_size` is the number of rays generated per call.
device = "cuda:0"
batch_size = 128

In [3]:
def test_marching_with_grid_cuda():
    """Run one ``ray_marching`` call against a grid whose ``_binary`` buffer is all True.

    Random origins (uniform) and random unit-length directions (normalised
    Gaussian samples) are created for ``batch_size`` rays on ``device``.
    After the call the function waits on ``torch.cuda.synchronize()`` so that
    timing it with ``%timeit`` includes the queued GPU work. Nothing is
    returned and nothing is asserted.
    """
    # Keep the RNG call order (rand first, then randn) so the draws are unchanged.
    origins = torch.rand((batch_size, 3), device=device)
    raw_directions = torch.randn((batch_size, 3), device=device)
    unit_directions = raw_directions / raw_directions.norm(dim=-1, keepdim=True)

    occupancy = OccupancyGrid(roi_aabb=[0, 0, 0, 1, 1, 1]).to(device)
    occupancy._binary[:] = True

    marching_options = {
        "grid": occupancy,
        "near_plane": 0.0,
        "far_plane": 1.0,
        "render_step_size": 1e-2,
    }
    # The three results are unpacked (and discarded) exactly as before.
    packed_info, t_starts, t_ends = ray_marching(
        origins, unit_directions, **marching_options
    )
    torch.cuda.synchronize()

In [7]:
%timeit test_marching_with_grid_cuda() # grid_binary.flatten() with block

13.8 ms ± 25.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
%timeit test_marching_with_grid_cuda() # packed_info.flatten() t_starts.flatten() t_ends.flatten()

14.5 ms ± 139 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
%timeit test_marching_with_grid_cuda() # no block

14.4 ms ± 90.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
%timeit test_marching_with_grid_cuda() # grid_binary shape 1

14.7 ms ± 325 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
%timeit test_marching_with_grid_cuda() #cuda

13.9 ms ± 66.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [2]:
import tinycudann as tcnn
import numpy as np

In [3]:
# Hash-grid constants: `n_levels` resolutions growing geometrically from
# `base_resolution` up to `targe_resolution`.
n_levels = 16
# NOTE(review): `targe_resolution` looks like a typo for "target_resolution";
# it is only read on the line that computes `b` in this cell.
targe_resolution = 1024
base_resolution = 16
# Growth factor chosen so that base_resolution * b**(n_levels-1) == targe_resolution.
b = np.exp(np.log(targe_resolution/base_resolution)/(n_levels-1))
per_level_scale = b

# Multi-resolution hash encoding of 3-D positions.
# NOTE(review): `per_level_scale` computed above is NOT passed in; the config
# uses a hard-coded literal instead. Presumably that literal is the
# single-precision rounding of `b` -- confirm before replacing it.
posi_encoder = tcnn.Encoding(
    n_input_dims=3,
    encoding_config={
        "otype": "HashGrid",
        "n_levels": n_levels,
        "n_features_per_level": 2,
        "log2_hashmap_size": 19,
        "base_resolution": base_resolution,
        "per_level_scale": 1.3195079565048218,
    }
)
# Small fused MLP with 32 inputs and 16 outputs. 32 equals
# n_levels * n_features_per_level (16 * 2), so it is presumably meant to
# consume the encoder output -- TODO confirm.
xyz_encoder = tcnn.Network(
    n_input_dims=32, n_output_dims=16,
    network_config={
        "otype": "FullyFusedMLP",
        "activation": "ReLU",
        "output_activation": "None",
        "n_neurons": 64,
        "n_hidden_layers": 1,
    }
)

In [3]:
posi_encoder.params.shape

torch.Size([11445040])

In [9]:
31 %2

1

In [6]:
posi_encoder.params.shape

torch.Size([11445040])

In [10]:
xyz_encoder.native_tcnn_module.n_input_dims()

32

In [18]:
# Read back the per-level scale that the tcnn module reports and print it next
# to the value `b` computed in Python, so the two can be compared digit by digit.
b_tcnn = posi_encoder.native_tcnn_module.hyperparams()['per_level_scale']
print(f"python per_level_scale: {b}")
print(f"tcnn per_level_scale: {b_tcnn}")

python per_level_scale: 1.3195079107728942
tcnn per_level_scale: 1.3195079565048218


In [8]:
# Walk the hash-grid levels and tally how many entries each one occupies.
# `offset` accumulates the running total over all levels; `max_params` caps a
# single level at 2**19 entries.
offset = 0
max_params = 2**19
for i in range(n_levels):
    # Grid resolution of level i: ceil(base_resolution * scale**i - 1) + 1.
    resolution = int(
        np.ceil(
            base_resolution * np.exp(
                i*np.log(1.3195079565048218)
            ) - 1.0
        )
    ) + 1
    # Dense cell count rounded up to the next multiple of 8. Pure integer
    # ceil-division replaces the former float round trip, so the result stays
    # exact for arbitrarily large resolutions and needs no special case for
    # counts that are already aligned.
    params_in_level = (resolution ** 3 + 7) // 8 * 8
    params_in_level = min(max_params, params_in_level)
    offset += params_in_level

# Only the count of the last level is shown; the running total is in `offset`.
print(params_in_level)

524288


In [4]:
1e-2

0.01

In [4]:
# Per-level grid resolutions, ceil(16 * scale**i - 1) + 1, computed once with
# the scale reported by tcnn and once with the scale computed in Python.
ress = [
    int(np.ceil(16 * np.exp(i*np.log(b_tcnn)) - 1.0)) + 1
    for i in range(n_levels)
]
print(f"tcnn res: {ress}")

ress = [
    int(np.ceil(16 * np.exp(i*np.log(b)) - 1.0)) + 1
    for i in range(n_levels)
]
print(f"python res: {ress}")

tcnn res: [16, 27, 45, 74, 123, 204, 338, 562, 934, 1553, 2581, 4290, 7132, 11857, 19711, 32769]
python res: [16, 27, 45, 74, 123, 204, 338, 562, 934, 1553, 2581, 4290, 7132, 11857, 19711, 32768]


In [14]:
posi_encoder.params.shape

torch.Size([12196240])

In [10]:
3*64+64*64+64*64+64*3

8576

In [70]:
2**21

2097152

In [77]:
38**4 - 2**21

-12016

In [29]:
26 / 2

13.0

In [50]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 150
from ipywidgets import interact, widgets
from PIL import Image
import glob

import re
def atoi(text):
    """Return ``int(text)`` when ``text`` is a run of decimal digits, else ``text``.

    ``str.isdecimal`` is used rather than ``str.isdigit``: ``isdigit`` also
    accepts characters such as superscript digits ("²") that ``int`` cannot
    parse, which made the conversion raise ``ValueError`` instead of passing
    the text through unchanged.
    """
    return int(text) if text.isdecimal() else text

def natural_keys(text):
    """Build a sort key that orders embedded numbers by value ("human" order).

    The text is split into alternating non-digit and digit runs, and each run
    is passed through ``atoi``. Use as ``alist.sort(key=natural_keys)``.

    Reference: http://nedbatchelder.com/blog/200712/human_sorting.html
    (see Toothy's implementation in the comments).
    """
    pieces = re.split(r'(\d+)', text)
    return list(map(atoi, pieces))

In [51]:
def read_images(file_list):
    """Open every path in ``file_list`` with ``Image.open`` and return the results as a list.

    The order of the returned list follows the order of ``file_list``.
    """
    return [Image.open(path) for path in file_list]
        
        
def browse_images(digits):
    """Show the images in ``digits`` one at a time behind an integer slider.

    The slider ranges over the valid indices of ``digits`` and is created with
    an initial value of 15. Each selected image is drawn with matplotlib and
    titled ``rgb_<index>.png``.
    """
    def view_image(x):
        # Draw the image at slider position x.
        plt.imshow(digits[x], cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title(f'rgb_{x}.png')
        plt.show()

    last_index = len(digits) - 1
    slider = widgets.IntSlider(value=15, min=0, max=last_index)
    interact(view_image, x=slider)

In [55]:
# Collect the rendered rgb*.png frames of this run, order them by their
# embedded number, load them and open the slider viewer.
root = "/home/loyot/workspace/code/training_results/nerfacc/ngp_dnerf/trex/lr_0-01/hl1_pf_nopw_l-huber_dive1024.0_op_te_ta/rgb*.png"
images_list = sorted(glob.glob(root), key=natural_keys)
image_data = read_images(images_list)
browse_images(image_data)

interactive(children=(IntSlider(value=15, description='x', max=19), Output()), _dom_classes=('widget-interact'…

In [85]:
# Same viewer as above, pointed at the rgb*.png frames in the "training"
# folder of a different run.
root = "/home/loyot/workspace/code/training_results/nerfacc/ngp_dnerf/trex/lr_0-01/pf_nopw_l-huber/training/rgb*.png"
images_list = sorted(glob.glob(root), key=natural_keys)
image_data = read_images(images_list)
browse_images(image_data)

interactive(children=(IntSlider(value=0, description='x', max=20), Output()), _dom_classes=('widget-interact',…

In [31]:
# Reload and browse whatever `images_list` currently holds. This cell does not
# set `images_list` itself, so it depends on a cell that did having run first.
image_data = read_images(images_list)
browse_images(image_data)

interactive(children=(IntSlider(value=15, description='x', max=19), Output()), _dom_classes=('widget-interact'…

In [15]:
import taichi as ti
import numpy as np
ti.init(arch=ti.cuda)

[Taichi] Starting on arch=cuda


In [13]:
@ti.func
def scalbn(x, exponent):
    """Return ``x`` multiplied by two raised to the power ``exponent``."""
    power_of_two = ti.math.pow(2, exponent)
    return x * power_of_two

@ti.func
def frexp_hard(x, base):
    """Find which doubling interval above ``base`` contains ``x``.

    Six consecutive intervals ``[base * 2**i, base * 2**(i + 1))`` for
    ``i = 0..5`` are tested. Returns ``i + 1`` for the interval containing
    ``x`` and ``0`` when none of them does.

    The loop is unrolled at compile time with ``ti.static``. The earlier
    version left it with a ``break`` placed inside a runtime ``if``, which
    Taichi does not accept in a static loop. The intervals never overlap, so
    at most one iteration can assign ``exponent`` and the ``break`` is not
    needed for correctness.
    """
    exponent = 0
    start = base
    for i in ti.static(range(6)):
        start_next = start * 2
        if start <= x and x < start_next:
            exponent = i + 1
        start = start_next
    return exponent

@ti.func
def frexp_bit(x):
    """Extract a binary exponent from ``x`` by inspecting its bit pattern.

    Returns 0 when ``x == 0.0``. Otherwise ``x`` is reinterpreted as a 32-bit
    unsigned integer, bits 23..30 are isolated and shifted down, and 127 is
    subtracted. The masks and the bias match the IEEE-754 single-precision
    layout, so this assumes ``x`` is a 32-bit float -- TODO confirm at call
    sites. The sign bit is masked out, so the sign of ``x`` does not affect
    the result.

    The extracted exponent is also printed (debug output).
    """
    exponent = 0
    if x != 0.0:
        # Reinterpret the float's bits as an unsigned 32-bit integer.
        bits = ti.bit_cast(x, ti.u32)
        # exponent = (ti.u32(bits >> 23)) & 0x7f
        # Keep only bits 23..30, shift them down and remove the bias of 127.
        exponent = (ti.i32(bits & ti.u32(0x7f800000)) >> 23) - 127
        # Keep the low 23 bits and overwrite the field above them with
        # 0x3f800000 (a stored exponent of 127), then reinterpret as a float.
        # `frac` is not used afterwards except by the commented-out print.
        bits &= ti.u32(0x7fffff)
        bits |= ti.u32(0x3f800000)
        frac = ti.bit_cast(bits, ti.f32)
        print("exponent: ", exponent)
        # print("frac: ", frac)
        
    return exponent

@ti.func
def frexp(x):
    """Return the binary exponent ``e`` such that ``x = frac * 2**e`` with ``0.5 <= |frac| < 1``.

    Returns 0 for ``x == 0.0``. The exponent, the fraction and the value
    rebuilt from them are printed as debug output.

    Changes from the earlier version:
    * the logarithm is taken of ``ti.abs(x)``, so negative inputs no longer
      feed a negative number to ``log2``;
    * the logarithm is floored instead of truncated toward zero; truncation
      produced an exponent one too large for ``|x| < 1`` (e.g. 0.1 gave -2,
      i.e. a fraction of 0.4 instead of 0.8);
    * the power of two uses a floating-point base so that a negative exponent
      is never evaluated as an integer power.

    Note: the result relies on the accuracy of ``log2``; an input that is an
    exact power of two could land on the neighbouring exponent if ``log2``
    rounds just below the exact integer.
    """
    exponent = 0
    if x != 0.0:
        exponent = ti.i32(ti.floor(ti.math.log2(ti.abs(x)))) + 1
        frac = x / ti.pow(2.0, exponent)
        ori = frac * ti.pow(2.0, exponent)
        print("exponent: ", exponent)
        print("frac: ", frac)
        print("ori: ", ori)
    return exponent

# exponent = static_cast<int>(std::log2(std::abs(x))) + 1;
#     x /= std::pow(2, exponent);

@ti.func
def mip_from_pos(xyz, cascades):
    """Pick a cascade index from a position.

    Takes the largest absolute component of ``xyz``, obtains its exponent via
    ``frexp_bit``, adds one and clamps the result to ``[0, cascades - 1]``.
    """
    largest_abs = ti.abs(xyz).max()
    level = frexp_bit(largest_abs) + 1
    level_floor_clamped = ti.max(level, 0)
    return ti.min(level_floor_clamped, cascades - 1)

@ti.func
def mip_from_dt(dt, grid_size, cascades):
    """Pick a cascade index from a step size.

    Obtains the exponent of ``dt * grid_size`` via ``frexp_bit`` and clamps it
    to ``[0, cascades - 1]``.
    """
    level = frexp_bit(dt * grid_size)
    level_floor_clamped = ti.max(level, 0)
    return ti.min(level_floor_clamped, cascades - 1)


In [14]:
# Grid set-up for the mip-level experiment.
scale = 16
grid_size = 128
grid_size3 = grid_size ** 3
# 1 + ceil(log2(2 * scale)) cascades, never fewer than one.
cascades = max(1 + int(np.ceil(np.log2(2 * scale))), 1)
print("cascades: ", cascades)

# Sample position and step size that the `test` kernel reads.
x = ti.math.vec3(0.0689, 0.2918, 0.5009)
dt = 0.0076
@ti.kernel
def test():
    """Print the cascade chosen from the position, from the step size, and the larger of the two.

    Reads the notebook-level ``x``, ``dt``, ``grid_size`` and ``cascades``.
    """
    print("----mip_pos----")
    level_from_pos = mip_from_pos(x, cascades)
    print("mip_pos: ", level_from_pos)

    print("----mip_dt----")
    level_from_dt = mip_from_dt(dt, grid_size, cascades)
    print("mip_dt: ", level_from_dt)

    # The final level is whichever of the two is coarser (larger index).
    final_level = ti.max(level_from_pos, level_from_dt)
    print("----final_mip----")
    print("mip: ", final_level)
    
test()

cascades:  6
----mip_pos----
exponent:  -1
mip_pos:  0
----mip_dt----
exponent:  -1
mip_dt:  0
----final_mip----
mip:  0


In [7]:
0.0077 * 128

0.9856

In [1]:
0.0078 * 128

0.9984

In [None]:
# Three xyz positions noted down for comparison (presumably copied from the
# Taichi run -- confirm). The constructor call was written as `torch.(`, which
# does not parse; `torch.tensor(` builds the intended 3x3 tensor.
xyz_save_taichi = torch.tensor([[ -3.3070,   1.5409,   3.1265],
                   [ -6.2736,   1.6399,   7.3511],
                   [-13.9541,   2.9160,  15.9904]])
# NOTE(review): three values noted next to the positions; their meaning is not
# stated anywhere in this notebook.
0.0202, 0.0378, 0.0769

In [5]:
grid_size3*6

12582912

In [19]:
@ti.kernel
def test_frexp():
    """Print the exponent that ``frexp_bit`` reports for 0.1."""
    exponent_of_tenth = frexp_bit(0.1)
    print('exponent: ', exponent_of_tenth)

test_frexp()

exponent:  -2


In [22]:
0.4*2**(-2)

0.1

In [21]:
@ti.kernel
def test_bitcast():
    """Print the unsigned 32-bit integer obtained by bit-casting the literal 1.1."""
    raw_pattern = ti.bit_cast(1.1, ti.u32)
    print("bits: ", raw_pattern)

test_bitcast()

bits:  1066192077
