In [1]:
import torch

# Single target device for every module/tensor built in this notebook.
# CUDA is required: Taichi is also initialised with arch=ti.cuda in the next cell.
device = torch.device("cuda")

In [2]:
import numpy as np
import taichi as ti

# Taichi runs on CUDA. device_memory_GB sizes Taichi's preallocated GPU memory
# pool; 36 GB is machine-specific and must fit the local GPU.
ti.init(arch=ti.cuda, device_memory_GB=36.0)

# Hash-grid encoder configuration. The first three axes share the spatial
# resolutions; the fourth axis uses the *_t values (presumably time — confirm).
base_resolution = 16
base_resolution_t = 16
finest_resolution = 256
finest_resolution_t = 128
num_levels = 16
log2_hashmap_size = 19

# Learning-rate settings (presumably for an optimizer; not used in the cells shown).
lrate = 0.01
lrate_decay = 10000

# Kept after ti.init() in case the encoder module touches Taichi at import time.
from encoder import HashEncoderHyFluid

ENCODER = HashEncoderHyFluid(
    min_res=np.array(3 * [base_resolution] + [base_resolution_t]),
    max_res=np.array(3 * [finest_resolution] + [finest_resolution_t]),
    num_scales=num_levels,
    max_params=1 << log2_hashmap_size,  # == 2 ** log2_hashmap_size
).to(device)
ENCODER_params = [*ENCODER.parameters()]

[Taichi] version 1.7.2, llvm 15.0.1, commit 0131dce9, win, python 3.11.0
[Taichi] Starting on arch=cuda


  @custom_fwd(cast_inputs=torch.float32)
  @custom_bwd
  @custom_fwd(cast_inputs=torch.float32)
  @custom_bwd


In [3]:
class NeRFSmall(torch.nn.Module):
    """Small bias-free MLP mapping per-sample input features to one non-negative value.

    Only ``sigma_net`` is used by :meth:`forward`. The ``rgb`` parameter and the
    ``color_net`` layers are created but never read by ``forward``.

    Args:
        num_layers: number of Linear layers in the sigma MLP.
        hidden_dim: width of the hidden sigma layers.
        geo_feat_dim: stored on the instance; not used by any layer built here.
        num_layers_color: number of Linear layers in ``color_net``.
        hidden_dim_color: width of the hidden ``color_net`` layers.
        input_ch: number of input features per sample (last dimension of ``x``).
    """

    def __init__(self,
                 num_layers=3,
                 hidden_dim=64,
                 geo_feat_dim=15,
                 num_layers_color=2,
                 hidden_dim_color=16,
                 input_ch=3,
                 ):
        super().__init__()

        self.input_ch = input_ch
        # Not read by forward(), but it is a registered Parameter, so it shows up in
        # parameters() and state_dict() (and therefore in saved checkpoints).
        self.rgb = torch.nn.Parameter(torch.tensor([0.0]))

        # Sigma network: input_ch -> hidden_dim -> ... -> 1, all without bias.
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.geo_feat_dim = geo_feat_dim

        sigma_net = []
        for l in range(num_layers):
            in_dim = self.input_ch if l == 0 else hidden_dim
            # The last layer emits a single channel (sigma only); no extra
            # geometry/colour features are produced by this network.
            out_dim = 1 if l == num_layers - 1 else hidden_dim
            sigma_net.append(torch.nn.Linear(in_dim, out_dim, bias=False))

        self.sigma_net = torch.nn.ModuleList(sigma_net)

        # NOTE(review): color_net is a plain Python list, so these layers are NOT
        # registered submodules. They are absent from parameters()/state_dict() and
        # are not moved by .to(device). forward() never uses them.
        # Wrapping them in torch.nn.ModuleList would add state_dict keys and break
        # strict loading of existing checkpoints, so the list is left as-is. The
        # layers are still constructed, in the same order, so the RNG stream (and
        # hence sigma_net's initial weights under a fixed seed) is unchanged.
        self.color_net = []
        for l in range(num_layers_color):
            in_dim = 1 if l == 0 else hidden_dim_color
            out_dim = 1 if l == num_layers_color - 1 else hidden_dim_color
            self.color_net.append(torch.nn.Linear(in_dim, out_dim, bias=True))

    def forward(self, x):
        """Run ``x`` through the sigma MLP.

        Every layer, including the last, is followed by a ReLU. The result is
        therefore non-negative.

        Args:
            x: tensor whose last dimension has ``input_ch`` features (assumed
               shape ``(..., input_ch)`` — TODO confirm against callers).

        Returns:
            Tensor whose last dimension has size 1. It is ``x`` itself when
            ``num_layers == 0``.
        """
        h = x
        for layer in self.sigma_net:
            h = torch.nn.functional.relu(layer(h), inplace=True)
        sigma = h
        return sigma


# Sigma MLP on the GPU. Its input width is twice the encoder's level count,
# presumably the encoder's output width (2 features per level) — TODO confirm.
MODEL = NeRFSmall(
    num_layers=2,
    hidden_dim=64,
    geo_feat_dim=15,
    num_layers_color=2,
    hidden_dim_color=16,
    input_ch=2 * ENCODER.num_scales,
).to(device)
# Registered parameters of MODEL (rgb and sigma_net only; color_net is unregistered).
GRAD_vars = [*MODEL.parameters()]

In [4]:
def save_checkpoint(path, network_fn, embed_fn, optimizer=None, global_step=0):
    """Serialize network/encoder weights (and optionally optimizer state) with torch.save.

    Uses the checkpoint keys 'global_step', 'network_fn_state_dict',
    'embed_fn_state_dict' and 'optimizer_state_dict', so existing loaders that
    read those keys keep working. 'optimizer_state_dict' is only written when
    an optimizer is passed.

    Args:
        path: destination file for torch.save.
        network_fn: module whose state_dict is stored under 'network_fn_state_dict'.
        embed_fn: module whose state_dict is stored under 'embed_fn_state_dict'.
        optimizer: optional optimizer whose state_dict is also stored.
        global_step: training step recorded in the checkpoint.
    """
    checkpoint = {
        'global_step': global_step,
        'network_fn_state_dict': network_fn.state_dict(),
        'embed_fn_state_dict': embed_fn.state_dict(),
    }
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(checkpoint, path)


# No optimizer, step counter or render_kwargs exist at this point in the notebook,
# so only the weights are saved. Pass optimizer=/global_step= when calling this from
# a training loop.
# NOTE(review): assumed mapping network_fn -> MODEL, embed_fn -> ENCODER — confirm
# against the code that loads these checkpoints.
CKPT_PATH = 'checkpoint.tar'  # TODO(review): set the intended output location
save_checkpoint(CKPT_PATH, network_fn=MODEL, embed_fn=ENCODER)

OrderedDict([('rgb', tensor([0.], device='cuda:0')), ('sigma_net.0.weight', tensor([[ 0.0558, -0.0628,  0.1638,  ..., -0.0514,  0.1751, -0.0634],
        [ 0.1027, -0.1041,  0.1392,  ..., -0.0491,  0.1338, -0.0573],
        [-0.0621, -0.1157,  0.1138,  ..., -0.1672, -0.0187, -0.1348],
        ...,
        [-0.0463,  0.0328,  0.0727,  ...,  0.0618,  0.1649,  0.1239],
        [ 0.1692, -0.1448, -0.0416,  ...,  0.1656, -0.0307, -0.0014],
        [ 0.0795, -0.1002, -0.1327,  ...,  0.0399,  0.1681,  0.0923]],
       device='cuda:0')), ('sigma_net.1.weight', tensor([[-0.0694,  0.0107, -0.0133, -0.1072, -0.1051,  0.0851,  0.0395,  0.0086,
         -0.0155, -0.1220, -0.0845, -0.0150,  0.0041,  0.0150, -0.0963, -0.0140,
         -0.0322,  0.0933,  0.0602, -0.0018,  0.0137,  0.1026,  0.0538,  0.1168,
         -0.0708,  0.0440, -0.1083,  0.0329,  0.0144, -0.1174, -0.0546, -0.0962,
         -0.0291, -0.0273,  0.0809,  0.0407,  0.0645,  0.1233, -0.0516,  0.0397,
          0.0058, -0.0120,  0.0295, 