In [1]:
import torch
import taichi as ti
import numpy as np
import math

# Torch device handle; later cells pass it to tensor constructors and `.to(...)` calls.
device = torch.device("cuda")
# Initialise Taichi on its CUDA backend.
# NOTE(review): the 36 GB device-memory setting is hard-coded — confirm it fits the target GPU.
ti.init(arch=ti.cuda, device_memory_GB=36.0)

[Taichi] version 1.7.2, llvm 15.0.1, commit 0131dce9, win, python 3.11.0
[Taichi] Starting on arch=cuda


In [2]:
from types import SimpleNamespace

def _npz_entry_to_python(entry):
    """Convert one ``args.npz`` entry to a plain Python value.

    One-element arrays become scalars (``.item()``), other arrays become
    lists (``.tolist()``), and non-array values are returned unchanged.
    """
    if not isinstance(entry, np.ndarray):
        return entry
    if entry.size == 1:
        return entry.item()
    return entry.tolist()


# SECURITY NOTE: allow_pickle=True lets NumPy unpickle object arrays, which can
# execute arbitrary code. Only load "args.npz" files from a trusted source.
args_npz = np.load("args.npz", allow_pickle=True)
ARGs = SimpleNamespace(**{name: _npz_entry_to_python(entry) for name, entry in args_npz.items()})
del args_npz

print(type(ARGs))

<class 'types.SimpleNamespace'>


In [3]:
# Training archive: images, camera poses, intrinsics (H, W, focal), render
# trajectory, voxel transform/scale and the near/far ray bounds.
with np.load("train_dataset.npz") as pinf_data:
    IMAGE_TRAIN_np = pinf_data['images_train']
    POSES_TRAIN_np = pinf_data['poses_train']
    HWF_np = pinf_data['hwf']
    RENDER_POSE_np = pinf_data['render_poses']
    RENDER_TIMESTEPs_np = pinf_data['render_timesteps']
    VOXEL_TRAN_np = pinf_data['voxel_tran']
    VOXEL_SCALE_np = pinf_data['voxel_scale']
    NEAR_float = pinf_data['near'].item()
    FAR_float = pinf_data['far'].item()
del pinf_data

# Held-out archive: images and poses of the test view(s).
with np.load("test_dataset.npz") as pinf_data_test:
    IMAGE_TEST_np = pinf_data_test['images_test']
    POSES_TEST_np = pinf_data_test['poses_test']
del pinf_data_test

# One-line summary of every loaded array/scalar.
print('\n'.join([
    f'IMAGE_TRAIN_np.shape: {IMAGE_TRAIN_np.shape}',
    f'POSES_TRAIN_np.shape: {POSES_TRAIN_np.shape}',
    f'HWF_np: {HWF_np}',
    f'RENDER_POSE_np.shape: {RENDER_POSE_np.shape}',
    f'RENDER_TIMESTEPs_np.shape: {RENDER_TIMESTEPs_np.shape}',
    f'VOXEL_TRAN_np.shape: {VOXEL_TRAN_np.shape}',
    f'VOXEL_SCALE_np.shape: {VOXEL_SCALE_np.shape}',
    f'NEAR_float: {NEAR_float}',
    f'FAR_float: {FAR_float}',
    f'IMAGE_TEST_np.shape: {IMAGE_TEST_np.shape}',
    f'POSES_TEST_np.shape: {POSES_TEST_np.shape}',
]))

IMAGE_TRAIN_np.shape: (120, 4, 960, 540, 3)
POSES_TRAIN_np.shape: (4, 4, 4)
HWF_np: [ 960.      540.     1306.8817]
RENDER_POSE_np.shape: (120, 4, 4)
RENDER_TIMESTEPs_np.shape: (120,)
VOXEL_TRAN_np.shape: (4, 4)
VOXEL_SCALE_np.shape: (3,)
NEAR_float: 1.1
FAR_float: 1.5
IMAGE_TEST_np.shape: (120, 1, 960, 540, 3)
POSES_TEST_np.shape: (1, 4, 4)


In [4]:
from encoder import HashEncoderHyFluid

# Hash encoder over (x, y, z, t): the three spatial axes share one resolution
# setting, the time axis has its own.
ENCODER = HashEncoderHyFluid(
    min_res=np.array([ARGs.base_resolution] * 3 + [ARGs.base_resolution_t]),
    max_res=np.array([ARGs.finest_resolution] * 3 + [ARGs.finest_resolution_t]),
    num_scales=ARGs.num_levels,
    max_params=2 ** ARGs.log2_hashmap_size,
)
ENCODER = ENCODER.to(device)
ENCODER_params = [*ENCODER.parameters()]

  @custom_fwd(cast_inputs=torch.float32)
  @custom_bwd
  @custom_fwd(cast_inputs=torch.float32)
  @custom_bwd


In [5]:
class NeRFSmall(torch.nn.Module):
    """Small bias-free MLP mapping encoded features to one non-negative density value.

    Args:
        num_layers: number of linear layers in the density network.
        hidden_dim: width of the hidden density layers.
        geo_feat_dim: stored on the instance; not used by any layer here.
        num_layers_color: number of linear layers built for ``color_net``.
        hidden_dim_color: width of the hidden ``color_net`` layers.
        input_ch: size of the last dimension of the input to ``forward``.
    """

    def __init__(self,
                 num_layers=3,
                 hidden_dim=64,
                 geo_feat_dim=15,
                 num_layers_color=2,
                 hidden_dim_color=16,
                 input_ch=3,
                 ):
        super().__init__()

        self.input_ch = input_ch
        # Single learnable scalar; registered as a parameter of the module.
        self.rgb = torch.nn.Parameter(torch.tensor([0.0]))

        # Density ("sigma") network hyper-parameters.
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.geo_feat_dim = geo_feat_dim

        # First layer consumes input_ch features, the last one emits a single
        # channel (density only — no extra geometry/SH features are produced).
        self.sigma_net = torch.nn.ModuleList([
            torch.nn.Linear(
                self.input_ch if depth == 0 else hidden_dim,
                1 if depth == num_layers - 1 else hidden_dim,
                bias=False,
            )
            for depth in range(num_layers)
        ])

        # NOTE(review): color_net is a plain Python list, so its layers are NOT
        # registered as sub-modules (absent from parameters()/state_dict(), not
        # moved by .to(device)). forward() never uses it.
        self.color_net = [
            torch.nn.Linear(
                1 if depth == 0 else hidden_dim_color,
                1 if depth == num_layers_color - 1 else hidden_dim_color,
                bias=True,
            )
            for depth in range(num_layers_color)
        ]

    def forward(self, x):
        """Run the density network; ReLU follows every layer, including the last."""
        activations = x
        for depth in range(self.num_layers):
            activations = torch.nn.functional.relu(self.sigma_net[depth](activations), inplace=True)
        return activations


# Density network fed by the hash encoder (ENCODER.num_scales * 2 input channels).
MODEL = NeRFSmall(
    input_ch=ENCODER.num_scales * 2,
    num_layers=2,
    hidden_dim=64,
    geo_feat_dim=15,
    num_layers_color=2,
    hidden_dim_color=16,
)
MODEL = MODEL.to(device)
# Registered parameters of the model, kept in a list for the optimizer / grad reset.
GRAD_vars = [*MODEL.parameters()]

In [6]:
from torch.optim.optimizer import Optimizer


class RAdam(Optimizer):
    """Adam-style optimizer whose step size is rectified by the SMA length ``N_sma``.

    While ``N_sma < 5`` (the early steps) the adaptive update is skipped; the
    parameter is then either left untouched or, when ``degenerated_to_sgd`` is
    True, updated with the bias-corrected first moment only.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False):
        """Validate hyper-parameters and register parameter groups.

        Args:
            params: iterable of parameters or list of parameter-group dicts.
            lr: learning rate (must be >= 0).
            betas: decay rates for the first/second moment, each in [0, 1).
            eps: value added to the denominator (must be >= 0).
            weight_decay: decay coefficient; 0 disables it.
            degenerated_to_sgd: use a momentum-only update while ``N_sma < 5``.

        Raises:
            ValueError: if ``lr``, ``eps`` or either beta is out of range.
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        # Groups that override `betas` get their own step-size cache, because the
        # cached values depend on the betas.
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
        # `buffer` is a 10-slot cache of [step, N_sma, step_size], indexed by step % 10.
        # Groups that do not define their own buffer receive this default object.
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state; defers entirely to the base class."""
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform one optimisation step over all parameter groups.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                # Parameters that received no gradient are skipped entirely.
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # Work on a float32 view/copy of the parameter; copied back after the update.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    # First time this parameter is seen: create step counter and moment buffers.
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # In-place exponential moving averages of grad^2 and grad.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                # Reuse N_sma / step_size if another parameter already computed them for this step.
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        # Rectification term combined with the first-moment bias correction.
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        # Momentum-only step: just the first-moment bias correction.
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        # Sentinel: negative step size means "do not update" (see `elif step_size > 0` below).
                        step_size = -1
                    buffered[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    # Adaptive update: weight decay is applied to the weights, then the Adam-style step.
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    # Only reached with degenerated_to_sgd=True: update without the second moment.
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)

        return loss


# Rebuild both parameter lists from their owners instead of extending GRAD_vars
# in place. The previous `GRAD_vars += ...` made this cell non-idempotent: on a
# second run GRAD_vars already contained the encoder parameters, so they ended
# up in both optimizer groups and the optimizer rejected the duplicate.
MODEL_params = list(MODEL.parameters())
optimizer = RAdam([
    {'params': MODEL_params, 'weight_decay': 1e-6},   # density MLP (+ rgb scalar), light weight decay
    {'params': ENCODER_params, 'eps': 1e-15}          # hash-encoder tables, smaller epsilon
], lr=ARGs.lrate, betas=(0.9, 0.99))
# Every trainable tensor; the training loop resets `.grad` on these each iteration.
GRAD_vars = MODEL_params + list(ENCODER_params)

print(f'len(GRAD_vars): {len(GRAD_vars)}')

len(GRAD_vars): 4


In [7]:
# Unpack image height, width and focal length from the (H, W, focal) triple.
H, W = (int(extent) for extent in HWF_np[:2])
FOCAL = float(HWF_np[2])
# Pinhole intrinsics with the principal point at the image centre.
K = np.array([
    [FOCAL, 0, 0.5 * W],
    [0, FOCAL, 0.5 * H],
    [0, 0, 1],
])
print(f'H: {H}, W: {W}, FOCAL: {FOCAL}, K: {K}')

H: 960, W: 540, FOCAL: 1306.8817138671875, K: [[1.30688171e+03 0.00000000e+00 2.70000000e+02]
 [0.00000000e+00 1.30688171e+03 4.80000000e+02]
 [0.00000000e+00 0.00000000e+00 1.00000000e+00]]


In [8]:
def pos_world2smoke(Pworld, w2s, scale_vector):
    """Map world-space points into the scaled smoke/simulation frame.

    Applies the rotation ``w2s[:3, :3]`` and translation ``w2s[:3, -1]`` to
    ``Pworld`` (last dimension = xyz), then divides by ``scale_vector``.
    """
    rotation, translation = w2s[:3, :3], w2s[:3, -1]
    # Row-wise dot product: equivalent to rotation @ p for every point p.
    rotated = (Pworld.unsqueeze(-2) * rotation).sum(dim=-1)
    shifted = rotated + translation.expand(rotated.shape)
    return shifted / scale_vector


class BBox_Tool(object):
    """World <-> simulation coordinate conversions plus an axis-aligned box test.

    Args:
        smoke_tran_inv: 4x4 world-to-smoke matrix.
        smoke_scale: per-axis scale (3 values); must provide ``.copy()``.
        in_min / in_max: lower / upper box corner in simulation coordinates.
    """

    def __init__(self, smoke_tran_inv, smoke_scale, in_min=[0.15, 0.0, 0.15], in_max=[0.85, 1., 0.85]):
        def as_f32(data):
            # All tensors live on the notebook-wide `device` in float32.
            return torch.tensor(data, device=device, dtype=torch.float32)

        self.s_w2s = as_f32(smoke_tran_inv).expand([4, 4])
        self.s2w = torch.inverse(self.s_w2s)
        self.s_scale = as_f32(smoke_scale.copy()).expand([3])
        self.s_min = as_f32(in_min)
        self.s_max = as_f32(in_max)

    def world2sim(self, pts_world):
        """Transform points (rotation + translation) into simulation space."""
        homogeneous = torch.cat([pts_world, torch.ones_like(pts_world[..., :1])], dim=-1)
        in_smoke = torch.matmul(self.s_w2s, homogeneous.unsqueeze(-1)).squeeze(-1)[..., :3]
        return in_smoke / self.s_scale  # 3.target to 2.simulation

    def world2sim_rot(self, pts_world):
        """Transform directions (rotation only) into simulation space."""
        rotated = torch.matmul(self.s_w2s[:3, :3], pts_world.unsqueeze(-1)).squeeze(-1)
        return rotated / self.s_scale  # 3.target to 2.simulation

    def sim2world(self, pts_sim):
        """Inverse of ``world2sim``: undo the scale, then apply the inverse matrix."""
        unscaled = pts_sim * self.s_scale
        homogeneous = torch.cat([unscaled, torch.ones_like(unscaled[..., :1])], dim=-1)
        return torch.matmul(self.s2w, homogeneous.unsqueeze(-1)).squeeze(-1)[..., :3]

    def sim2world_rot(self, pts_sim):
        """Inverse of ``world2sim_rot``: undo the scale, then apply the inverse rotation block."""
        unscaled = pts_sim * self.s_scale
        return torch.matmul(self.s2w[:3, :3], unscaled.unsqueeze(-1)).squeeze(-1)

    def isInside(self, inputs_pts):
        """Boolean mask: True where a world-space point lies inside [s_min, s_max] (inclusive)."""
        target_pts = pos_world2smoke(inputs_pts, self.s_w2s, self.s_scale)
        inside = torch.ones_like(target_pts[..., 0], dtype=torch.bool)
        for axis in range(3):
            coord = target_pts[..., axis]
            inside = inside & (coord >= self.s_min[axis]) & (coord <= self.s_max[axis])
        return inside

    def insideMask(self, inputs_pts, to_float=True):
        """Same as ``isInside``; returned as float when ``to_float`` is True, else as bool."""
        mask = self.isInside(inputs_pts)
        if to_float:
            return mask.to(torch.float)
        return mask


# Invert the voxel transform; the result is handed to BBox_Tool as its world-to-smoke matrix.
voxel_tran_inv = np.linalg.inv(VOXEL_TRAN_np)
BBOX_MODEL_gpu = BBox_Tool(smoke_tran_inv=voxel_tran_inv, smoke_scale=VOXEL_SCALE_np)

In [9]:
def get_rays_np_continuous(c2w):
    """Build one jittered ray per pixel for the camera pose ``c2w``.

    Each pixel coordinate gets a uniform [0, 1) offset and is clipped to the
    image bounds. Uses the notebook-level ``H``, ``W`` and intrinsics ``K``.

    Returns:
        rays_o, rays_d: ray origins and directions, each (H, W, 3).
        i, j: the jittered horizontal / vertical pixel coordinates, each (H, W).
    """
    cols, rows = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
    # Draw order matters for reproducibility: horizontal jitter first, then vertical.
    jitter_cols = np.random.uniform(0, 1, size=(H, W))
    jitter_rows = np.random.uniform(0, 1, size=(H, W))
    i = np.clip(cols + jitter_cols, 0, W - 1)
    j = np.clip(rows + jitter_rows, 0, H - 1)

    focal_x, focal_y = K[0][0], K[1][1]
    center_x, center_y = K[0][2], K[1][2]
    # Camera-frame directions: y is flipped and the camera looks along -z.
    dirs = np.stack([(i - center_x) / focal_x, -(j - center_y) / focal_y, -np.ones_like(i)], -1)
    # Rotate into the world frame (row-wise dot product == c2w[:3, :3] @ dir).
    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
    # Every ray starts at the camera centre.
    rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d))
    return rays_o, rays_d, i, j


def sample_bilinear(img, xy):
    """
    Bilinearly interpolate ``img`` at continuous pixel coordinates.

    :param img: (T, V, H, W, 3)
    :param xy: (V, 2, H, W) — per view, channel 0 is the horizontal (u) and
               channel 1 the vertical (v) coordinate
    :return: img: (T, V, H, W, 3)
    """
    _, n_views, height, width, _ = img.shape

    # Keep coordinates inside the image so floor/ceil are valid indices.
    u = np.clip(xy[:, 0], 0, width - 1)
    v = np.clip(xy[:, 1], 0, height - 1)

    u_lo = np.floor(u).astype(int)
    u_hi = np.ceil(u).astype(int)
    v_lo = np.floor(v).astype(int)
    v_hi = np.ceil(v).astype(int)

    # Fractional parts, shaped to broadcast over the time and channel axes.
    frac_u = (u - u_lo)[None, ..., None]
    frac_v = (v - v_lo)[None, ..., None]

    view_idx = np.arange(n_views)[:, None, None]

    def gather(row_idx, col_idx):
        # Pick, for every view and every sample, the pixel at (row, col) across all frames.
        return img[:, view_idx, row_idx, col_idx]

    bottom = (1 - frac_u) * gather(v_lo, u_lo) + frac_u * gather(v_lo, u_hi)
    top = (1 - frac_u) * gather(v_hi, u_lo) + frac_u * gather(v_hi, u_hi)

    return (1 - frac_v) * bottom + frac_v * top


def do_resample_rays():
    """Draw fresh jittered rays for every training view and resample the images to match.

    Returns:
        ret_IMAGE_TRAIN_gpu: training frames sampled at the jittered pixels,
            with the view/height/width axes flattened into one.
        ret_RAYs_gpu: (V*H*W, 2, 3) tensor holding ray origin and direction.
        ret_RAY_IDX_gpu: random int32 permutation of the ray indices.
    """
    # One (rays_o, rays_d, i, j) tuple per training pose, generated in pose order.
    per_view = [get_rays_np_continuous(pose) for pose in POSES_TRAIN_np[:, :3, :4]]
    rays_list = [[origin, direction] for origin, direction, _, _ in per_view]
    ij = np.stack([[cols, rows] for _, _, cols, rows in per_view], 0)

    images_train_sample = sample_bilinear(IMAGE_TRAIN_np, ij)
    ret_IMAGE_TRAIN_gpu = torch.tensor(images_train_sample, device=device, dtype=torch.float32)
    ret_IMAGE_TRAIN_gpu = ret_IMAGE_TRAIN_gpu.flatten(start_dim=1, end_dim=3)

    # (V, 2, H, W, 3) -> (V, H, W, 2, 3) -> [VHW, ro+rd=2, 3]
    rays_np = (
        np.stack(rays_list, 0)
        .transpose(0, 2, 3, 1, 4)
        .reshape(-1, 2, 3)
        .astype(np.float32)
    )
    ret_RAYs_gpu = torch.tensor(rays_np, device=device, dtype=torch.float32)
    ret_RAY_IDX_gpu = torch.randperm(ret_RAYs_gpu.shape[0], device=device, dtype=torch.int32)

    return ret_IMAGE_TRAIN_gpu, ret_RAYs_gpu, ret_RAY_IDX_gpu


# Initial ray set: resampled frames, ray origins/directions and a shuffled ray order.
IMAGE_TRAIN_gpu, RAYs_gpu, RAY_IDX_gpu = do_resample_rays()
# Flag the training loop raises once it has consumed every ray (one epoch).
resample_rays = False
# Sanity print of shape / dtype / device for each tensor.
print(f'IMAGE_TRAIN_gpu: shape={IMAGE_TRAIN_gpu.shape}, dtype={IMAGE_TRAIN_gpu.dtype}, device={IMAGE_TRAIN_gpu.device}')
print(f'RAYs_gpu: shape={RAYs_gpu.shape}, dtype={RAYs_gpu.dtype}, device={RAYs_gpu.device}')
print(f'RAY_IDX_gpu: shape={RAY_IDX_gpu.shape}, dtype={RAY_IDX_gpu.dtype}, device={RAY_IDX_gpu.device}')

IMAGE_TRAIN_gpu: shape=torch.Size([120, 2073600, 3]), dtype=torch.float32, device=cuda:0
RAYs_gpu: shape=torch.Size([2073600, 2, 3]), dtype=torch.float32, device=cuda:0
RAY_IDX_gpu: shape=torch.Size([2073600]), dtype=torch.int32, device=cuda:0


In [10]:
import tqdm

# Training loop: each iteration renders a random batch of rays at a randomly
# jittered time, compares against the time-interpolated training frames and
# takes one optimizer step.
# Cursor into RAY_IDX_gpu marking where the next ray batch starts.
i_batch = 0
for i in tqdm.trange(1, ARGs.N_iters + 1):
    # --- Select this iteration's rays ---
    BATCH_RAY_IDX_gpu = RAY_IDX_gpu[i_batch:i_batch + ARGs.N_rand]  # (N_rand,)
    BATCH_RAYs_gpu = torch.transpose(RAYs_gpu[BATCH_RAY_IDX_gpu], 0, 1)  # (2, N_rand, 3)

    # --- Select continuous frame indices: random frame + standard-normal jitter ---
    # NOTE(review): the jitter is unbounded, so TIME_IDX_gpu (and TIME_STEP_gpu below)
    # can fall outside the valid frame range; only the integer indices are clamped.
    TIME_IDX_gpu = torch.randperm(IMAGE_TRAIN_gpu.shape[0], device=device, dtype=torch.float32)[
                   :ARGs.N_time] + torch.randn(ARGs.N_time, device=device, dtype=torch.float32)  # (N_time,)
    TIME_IDX_FLOOR_gpu = torch.clamp(torch.floor(TIME_IDX_gpu).long(), 0, IMAGE_TRAIN_gpu.shape[0] - 1)  # (N_time,)
    TIME_IDX_CEIL_gpu = torch.clamp(torch.ceil(TIME_IDX_gpu).long(), 0, IMAGE_TRAIN_gpu.shape[0] - 1)  # (N_time,)
    TIME_IDX_RESIDUAL_gpu = TIME_IDX_gpu - TIME_IDX_FLOOR_gpu.float()  # (N_time,)
    # Normalise the frame index by (number of frames - 1); a single-frame dataset uses time 0.
    TIME_STEP_gpu = TIME_IDX_gpu / (IMAGE_TRAIN_gpu.shape[0] - 1) if IMAGE_TRAIN_gpu.shape[0] > 1 else torch.zeros_like(
        TIME_IDX_gpu)  # (N_time,)

    # Linear blend of the two neighbouring frames (all rays; the batch is selected afterwards).
    FRAMES_INTERPOLATED_gpu = IMAGE_TRAIN_gpu[TIME_IDX_FLOOR_gpu] * (1 - TIME_IDX_RESIDUAL_gpu).unsqueeze(-1) + \
                              IMAGE_TRAIN_gpu[TIME_IDX_CEIL_gpu] * TIME_IDX_RESIDUAL_gpu.unsqueeze(
        -1)  # (N_time, 4 * WIDTH * HEIGHT, 3)

    # Ground-truth colours for the batch rays.
    TARGET_S_gpu = FRAMES_INTERPOLATED_gpu[:, BATCH_RAY_IDX_gpu].flatten(0, 1)  # (N_time * N_rand, 3)

    # --- Advance the batch cursor; at the end of an epoch request new rays ---
    i_batch += ARGs.N_rand
    if i_batch >= RAYs_gpu.shape[0]:
        print("Shuffle data after an epoch!")
        # This permutation is replaced by do_resample_rays() at the end of this iteration.
        RAY_IDX_gpu = torch.randperm(RAYs_gpu.shape[0], device=device)
        i_batch = 0
        resample_rays = True

    # --- Sample depths along each ray between NEAR_float and FAR_float ---
    RAYS_O_gpu, RAYS_D_gpu = BATCH_RAYs_gpu  # (N_rand, 3), (N_rand, 3)
    T_VALS_gpu = torch.linspace(0., 1., steps=ARGs.N_samples, device=device, dtype=torch.float32)  # (N_samples,)
    Z_VALS_GPU = NEAR_float * torch.ones_like(RAYS_D_gpu[..., :1]) * (1. - T_VALS_gpu) + FAR_float * torch.ones_like(
        RAYS_D_gpu[..., :1]) * T_VALS_gpu  # (N_rand, N_samples)

    ### Randomize the z-values
    # Stratified sampling: one uniform draw inside each interval between neighbouring midpoints.
    MIDs_gpu = .5 * (Z_VALS_GPU[..., 1:] + Z_VALS_GPU[..., :-1])  # (N_rand, N_samples - 1)
    UPPER_gpu = torch.cat([MIDs_gpu, Z_VALS_GPU[..., -1:]], -1)  # (N_rand, N_samples)
    LOWER_gpu = torch.cat([Z_VALS_GPU[..., :1], MIDs_gpu], -1)  # (N_rand, N_samples)
    T_RAND_gpu = torch.rand(Z_VALS_GPU.shape, device=device, dtype=torch.float32)  # (N_rand, N_samples)
    Z_VALS_GPU = LOWER_gpu + (UPPER_gpu - LOWER_gpu) * T_RAND_gpu  # (N_rand, N_samples)
    ### Randomize the z-values

    # --- Build (x, y, z, t) query points ---
    POINTS_gpu = RAYS_O_gpu[..., None, :] + RAYS_D_gpu[..., None, :] * Z_VALS_GPU[..., :, None]
    # TODO: check here — expanding the (N_time,) tensor to (N_rand, N_samples, 1) only works when N_time == 1.
    TIME_STEP_EXPENDED_gpu = TIME_STEP_gpu.expand(POINTS_gpu[..., :1].shape)  # TODO: check here
    POINTS_TIME_gpu = torch.cat([POINTS_gpu, TIME_STEP_EXPENDED_gpu], dim=-1)
    POINTS_TIME_FLAT_gpu = torch.reshape(POINTS_TIME_gpu, [-1, POINTS_TIME_gpu.shape[-1]])

    # --- Query the network only for points inside the bounding box; others keep raw value 0 ---
    out_dim = 1
    RAW_FLAT_gpu = torch.zeros([POINTS_TIME_FLAT_gpu.shape[0], out_dim], device=device, dtype=torch.float32)
    bbox_mask = BBOX_MODEL_gpu.insideMask(POINTS_TIME_FLAT_gpu[..., :3], to_float=False)
    # Guarantee at least one selected point so the encoder/model never receive an empty batch.
    if bbox_mask.sum() == 0:
        bbox_mask[0] = True
    POINTS_TIME_FLAT_FINAL_gpu = POINTS_TIME_FLAT_gpu[bbox_mask]
    RAW_FLAT_gpu[bbox_mask] = MODEL(ENCODER(POINTS_TIME_FLAT_FINAL_gpu))
    RAW_gpu = RAW_FLAT_gpu.reshape(*POINTS_TIME_gpu.shape[:-1], out_dim)

    # --- Volume rendering ---
    # Distance between consecutive samples; the last sample gets a very large (1e10) distance.
    DISTS_gpu = Z_VALS_GPU[..., 1:] - Z_VALS_GPU[..., :-1]
    DISTS_gpu = torch.cat([DISTS_gpu, torch.tensor([1e10], device=device).expand(DISTS_gpu[..., :1].shape)], -1)
    # Scale by the ray-direction length to get distances in world units.
    DISTS_gpu = DISTS_gpu * torch.norm(RAYS_D_gpu[..., None, :], dim=-1)
    # One learned grey value shared by all points, squashed into (0.2, 1.0) by tanh.
    RGB_TRAINED = torch.ones(3, device=device) * (0.6 + torch.tanh(MODEL.rgb) * 0.4)
    raw2alpha = lambda raw, dists, act_fn=torch.nn.functional.relu: 1. - torch.exp(-act_fn(raw) * dists)
    noise = 0.
    alpha = raw2alpha(RAW_gpu[..., -1] + noise, DISTS_gpu)
    # weights = alpha * transmittance, where transmittance is the exclusive cumulative product of (1 - alpha).
    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1), device=device), 1. - alpha + 1e-10], -1),
                                    -1)[:, :-1]
    rgb_map = torch.sum(weights[..., None] * RGB_TRAINED, -2)

    # --- Photometric loss ---
    img2mse = lambda x, y: torch.mean((x - y) ** 2)
    img_loss = img2mse(rgb_map, TARGET_S_gpu)
    loss = img_loss

    # --- Optimisation step ---
    for param in GRAD_vars:  # slightly faster than optimizer.zero_grad()
        param.grad = None
    loss.backward()
    optimizer.step()

    # Exponential learning-rate decay: a factor of decay_rate every decay_steps iterations.
    decay_rate = 0.1
    decay_steps = ARGs.lrate_decay
    new_lrate = ARGs.lrate * (decay_rate ** (i / decay_steps))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lrate

    # --- Epoch boundary: draw new jittered rays and resampled frames ---
    if resample_rays:
        print("Sampling new rays!")
        IMAGE_TRAIN_gpu, RAYs_gpu, RAY_IDX_gpu = do_resample_rays()
        i_batch = 0
        resample_rays = False

  0%|          | 0/10000 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [11]:
import os
import lpips


def get_rays(H, W, K, c2w):
    """Build one ray through every pixel centre of an H x W image.

    Args:
        H, W: image height and width in pixels.
        K: 3x3 pinhole intrinsics (focal lengths on the diagonal, principal point in column 2).
        c2w: camera-to-world matrix; rows 0-2 hold the rotation and translation.

    Returns:
        rays_o, rays_d: ray origins and directions in world space, each (H, W, 3).
    """
    pixel_x = torch.linspace(0, W - 1, W, device=device)
    pixel_y = torch.linspace(0, H - 1, H, device=device)
    # 'xy' indexing yields (H, W) grids directly: i varies along columns, j along rows.
    i, j = torch.meshgrid(pixel_x, pixel_y, indexing='xy')

    focal_x, focal_y = K[0][0], K[1][1]
    center_x, center_y = K[0][2], K[1][2]
    # Camera-frame directions: y is flipped and the camera looks along -z.
    dirs = torch.stack([(i - center_x) / focal_x, -(j - center_y) / focal_y, -torch.ones_like(i)], -1)
    # Rotate into the world frame (row-wise dot product == c2w[:3, :3] @ dir).
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
    # All rays share the camera centre as their origin.
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d


def merge_imgs(save_dir, framerate=30, prefix=''):
    """Assemble ``<save_dir>/<prefix>NNN.png`` frames into a GIF and an MP4 with ffmpeg.

    Three ffmpeg invocations are issued: palette generation, then the GIF and
    the MP4 (both using that palette). Outputs are ``<save_dir>/palette.png``,
    ``<save_dir>/_<prefix>.gif`` and ``<save_dir>/_<prefix>.mp4``.

    Commands are passed as argument lists (no shell), so directories or
    prefixes containing spaces or shell metacharacters are handled correctly.
    Like the previous ``os.system`` version this is best-effort: a failing or
    missing ffmpeg does not raise.

    Args:
        save_dir: directory containing the numbered PNG frames.
        framerate: frames per second of the produced animations.
        prefix: file-name prefix in front of the 3-digit frame number.
    """
    import subprocess  # local import keeps this cell self-contained

    frames = '{}/{}%03d.png'.format(save_dir, prefix)
    palette = '{}/palette.png'.format(save_dir)
    base_cmd = ['ffmpeg', '-hide_banner', '-loglevel', 'error', '-y']

    commands = [base_cmd + ['-i', frames, '-vf', 'palettegen', palette]]
    for extension in ('gif', 'mp4'):
        output = '{}/_{}.{}'.format(save_dir, prefix, extension)
        commands.append(
            base_cmd + ['-framerate', str(framerate), '-i', frames, '-i', palette, '-lavfi', 'paletteuse', output])

    try:
        for command in commands:
            # check=False: a non-zero ffmpeg exit status is tolerated, as before.
            subprocess.run(command, check=False)
    except FileNotFoundError:
        print('ffmpeg executable not found; skipping animation export')


# Render the held-out test view at every test timestep.
os.makedirs(os.path.join("output"), exist_ok=True)
with torch.no_grad():
    test_view_pose = torch.tensor(POSES_TEST_np[0], device=device, dtype=torch.float32)
    N_timesteps = IMAGE_TEST_np.shape[0]
    # Normalised time per frame; max(..., 1) avoids 0/0 for a single-frame
    # dataset (the training loop uses time 0 in that case as well).
    test_timesteps = torch.arange(N_timesteps, device=device) / max(N_timesteps - 1, 1)
    print(test_timesteps.shape)

    lpips_net = lpips.LPIPS().cuda()

    # --- Rays and sample points: identical for every timestep ---
    c2w = test_view_pose
    rays_o, rays_d = get_rays(H, W, K, c2w)
    rays_o = torch.reshape(rays_o, [-1, 3]).float()
    rays_d = torch.reshape(rays_d, [-1, 3]).float()
    near, far = NEAR_float * torch.ones_like(rays_d[..., :1]), FAR_float * torch.ones_like(rays_d[..., :1])
    t_vals = torch.linspace(0., 1., steps=ARGs.N_samples, device=device, dtype=torch.float32)  # (N_samples,)
    z_vals = near * (1. - t_vals) + far * t_vals  # (H * W, N_samples)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]  # (H * W, N_samples, 3)
    print(pts.shape)

    # --- Loop-invariant work hoisted out of the per-timestep loop ---
    out_dim = 1
    pts_flat = pts.reshape(-1, 3)
    # The bounding-box test depends only on xyz, so it is computed once.
    bbox_mask = BBOX_MODEL_gpu.insideMask(pts_flat, to_float=False)
    # Guarantee at least one selected point so the encoder/model never receive an empty batch.
    if bbox_mask.sum() == 0:
        bbox_mask[0] = True
    # Only points inside the box are sent through the network. Previously every
    # point was evaluated while the result was written into RAW_FLAT_gpu[bbox_mask],
    # which mismatches in size whenever any point lies outside the box.
    pts_inside = pts_flat[bbox_mask]

    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.tensor([1e10], device=device).expand(dists[..., :1].shape)], -1)
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
    RGB_TRAINED = torch.ones(3, device=device) * (0.6 + torch.tanh(MODEL.rgb) * 0.4)
    raw2alpha = lambda raw, _dists, act_fn=torch.nn.functional.relu: 1. - torch.exp(-act_fn(raw) * _dists)
    noise = 0.
    chunk = 512 * 64  # number of points per network call, bounds peak memory

    for i in tqdm.trange(0, test_timesteps.shape[0]):
        # Append this frame's time as the 4th coordinate of every in-box point.
        time_column = test_timesteps[i].expand(pts_inside[..., :1].shape)
        POINTS_TIME_FLAT_FINAL_gpu = torch.cat([pts_inside, time_column], dim=-1)

        all_part = []
        for j in range(0, POINTS_TIME_FLAT_FINAL_gpu.shape[0], chunk):
            part = MODEL(ENCODER(POINTS_TIME_FLAT_FINAL_gpu[j:j + chunk]))
            all_part.append(part)
        result = torch.cat(all_part, 0)

        # Scatter the densities back; points outside the box keep raw value 0.
        RAW_FLAT_gpu = torch.zeros([pts_flat.shape[0], out_dim], device=device, dtype=torch.float32)
        RAW_FLAT_gpu[bbox_mask] = result
        RAW_gpu = RAW_FLAT_gpu.reshape(*pts.shape[:-1], out_dim)

        # Volume rendering, same formulation as in training.
        alpha = raw2alpha(RAW_gpu[..., -1] + noise, dists)
        weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1), device=device), 1. - alpha + 1e-10], -1),
                                        -1)[:, :-1]
        rgb_map = torch.sum(weights[..., None] * RGB_TRAINED, -2)
        # NOTE(review): rgb_map is neither written to "output" nor scored against
        # IMAGE_TEST_np / lpips_net in this cell — confirm where that is meant to happen.

torch.Size([120])
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


Loading model from: C:\Users\imeho\Documents\VituralEnvs\HyFluid\Lib\site-packages\lpips\weights\v0.1\alex.pth
torch.Size([518400, 192, 3])


  0%|          | 0/120 [00:00<?, ?it/s]

0


  0%|          | 0/120 [00:54<?, ?it/s]
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x000001D8F7A117D0>>
Traceback (most recent call last):
  File "C:\Users\imeho\Documents\VituralEnvs\HyFluid\Lib\site-packages\ipykernel\ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


KeyboardInterrupt: 