## Load Data

In [1]:
import torch
import taichi as ti
import numpy as np
import pandas as pd
import imageio.v2 as imageio
import tqdm
import math
import json
import os
import lpips
import matplotlib.pyplot as plt

from torch.optim.optimizer import Optimizer
from skimage.metrics import structural_similarity

# All tensors in this notebook are meant to live on the GPU: `device` is
# passed explicitly to later tensor constructors, and the default tensor type
# is switched to CUDA float32 for constructors called without a device.
device = torch.device("cuda")
# NOTE(review): torch.set_default_tensor_type is reported as deprecated by
# recent PyTorch releases. Legacy `torch.Tensor(...)` calls later in this
# notebook rely on this default, so confirm they still land on the GPU before
# migrating to torch.set_default_device / torch.set_default_dtype.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

[Taichi] version 1.7.2, llvm 15.0.1, commit 0131dce9, win, python 3.11.0


NameError: name 'impo' is not defined

In [None]:
# Load the training split together with the scene metadata.
# NOTE(review): np.load returns a lazily-read archive object that is never
# closed here; it is kept bound to `pinf_data` exactly as before.
pinf_data = np.load("data/ScalarReal.npz")
(
    images_train_,
    poses_train,
    hwf,
    render_poses,
    render_timesteps,
    voxel_tran,
    voxel_scale,
    near,
    far,
) = (
    pinf_data[key]
    for key in (
        'images_train',
        'poses_train',
        'hwf',
        'render_poses',
        'render_timesteps',
        'voxel_tran',
        'voxel_scale',
        'near',
        'far',
    )
)

# Load the held-out test split.
pinf_data_test = np.load("data/ScalarRealTest.npz")
images_test, poses_test = (pinf_data_test[key] for key in ('images_test', 'poses_test'))


def _array_row(label, array):
    """Return one summary row (name, shape, size in MB) for a NumPy array."""
    return label, array.shape, array.nbytes / 1024 ** 2


# One row per loaded entry. `near` and `far` are shown by value and have no
# size column, so their size is None.
_summary_rows = [
    _array_row("images_train", images_train_),
    _array_row("poses_train", poses_train),
    _array_row("hwf", hwf),
    _array_row("render_poses", render_poses),
    _array_row("render_timesteps", render_timesteps),
    _array_row("voxel_tran", voxel_tran),
    _array_row("voxel_scale", voxel_scale),
    ("near", near, None),
    ("far", far, None),
    _array_row("images_test", images_test),
    _array_row("poses_test", poses_test),
]

# Column-oriented table data for the summary.
data = {
    "Name": [row[0] for row in _summary_rows],
    "Shape/Value": [row[1] for row in _summary_rows],
    "Size (MB)": [row[2] for row in _summary_rows],
}

# Build the DataFrame.
df = pd.DataFrame(data)

# Display the DataFrame (last expression of the cell, rendered by the notebook).
df.style.set_table_attributes("style='display:inline'").set_caption("Data Summary")

In [None]:
def pos_world2smoke(Pworld, w2s, scale_vector):
    """Map world-space points into the scaled smoke (simulation) frame.

    Args:
        Pworld: tensor whose last dimension holds the 3 world coordinates.
        w2s: world-to-smoke transform; only its top-left 3x3 block and the
            first three entries of its last column are read.
        scale_vector: divisor applied to the transformed coordinates.

    Returns:
        Tensor with the same leading shape as ``Pworld`` holding
        ``(R @ p + t) / scale_vector`` for every point ``p``.
    """
    rotation = w2s[:3, :3]
    translation = w2s[:3, -1]
    # Broadcasted row-wise dot product, i.e. R @ p for every point.
    rotated = (Pworld[..., None, :] * rotation).sum(-1)
    shifted = rotated + translation.expand(rotated.shape)
    return shifted / scale_vector


class BBox_Tool(object):
    """Coordinate conversion and inside/outside test for the smoke volume.

    The tool stores a world-to-smoke transform, its inverse, a per-axis scale
    and an axis-aligned box ``[in_min, in_max]`` expressed in the scaled
    simulation frame.

    Args:
        smoke_tran_inv: 4x4 world-to-smoke transform (array-like accepted by
            ``torch.tensor``).
        smoke_scale: per-axis scale; must provide ``.copy()`` (e.g. a NumPy
            array).
        in_min: lower corner of the box in simulation coordinates.
        in_max: upper corner of the box in simulation coordinates.
    """

    # Tuples are used as defaults so the default box can never be mutated
    # and shared between instances.
    def __init__(self, smoke_tran_inv, smoke_scale, in_min=(0.15, 0.0, 0.15), in_max=(0.85, 1., 0.85)):
        self.s_w2s = torch.tensor(smoke_tran_inv).expand([4, 4]).float()
        self.s2w = torch.inverse(self.s_w2s)
        self.s_scale = torch.tensor(smoke_scale.copy()).expand([3]).float()
        self.s_min = torch.Tensor(in_min)
        self.s_max = torch.Tensor(in_max)

    def world2sim(self, pts_world):
        """Transform world-space points (..., 3) to simulation coordinates."""
        pts_world_homo = torch.cat([pts_world, torch.ones_like(pts_world[..., :1])], dim=-1)
        pts_sim_ = torch.matmul(self.s_w2s, pts_world_homo[..., None]).squeeze(-1)[..., :3]
        pts_sim = pts_sim_ / (self.s_scale)  # 3.target to 2.simulation
        return pts_sim

    def world2sim_rot(self, pts_world):
        """Transform world-space directions (..., 3): rotation and scale only."""
        pts_sim_ = torch.matmul(self.s_w2s[:3, :3], pts_world[..., None]).squeeze(-1)
        pts_sim = pts_sim_ / (self.s_scale)  # 3.target to 2.simulation
        return pts_sim

    def sim2world(self, pts_sim):
        """Inverse of :meth:`world2sim` for points (..., 3)."""
        pts_sim_ = pts_sim * self.s_scale
        pts_sim_homo = torch.cat([pts_sim_, torch.ones_like(pts_sim_[..., :1])], dim=-1)
        pts_world = torch.matmul(self.s2w, pts_sim_homo[..., None]).squeeze(-1)[..., :3]
        return pts_world

    def sim2world_rot(self, pts_sim):
        """Inverse of :meth:`world2sim_rot` for directions (..., 3)."""
        pts_sim_ = pts_sim * self.s_scale
        pts_world = torch.matmul(self.s2w[:3, :3], pts_sim_[..., None]).squeeze(-1)
        return pts_world

    def isInside(self, inputs_pts):
        """Return a boolean mask that is True where a world point is in the box.

        Both bounds are inclusive on every axis.
        """
        target_pts = pos_world2smoke(inputs_pts, self.s_w2s, self.s_scale)
        above = torch.logical_and(target_pts[..., 0] >= self.s_min[0], target_pts[..., 1] >= self.s_min[1])
        above = torch.logical_and(above, target_pts[..., 2] >= self.s_min[2])
        below = torch.logical_and(target_pts[..., 0] <= self.s_max[0], target_pts[..., 1] <= self.s_max[1])
        below = torch.logical_and(below, target_pts[..., 2] <= self.s_max[2])
        outputs = torch.logical_and(below, above)
        return outputs

    def insideMask(self, inputs_pts, to_float=True):
        """Return :meth:`isInside` as a float mask (default) or a boolean mask."""
        mask = self.isInside(inputs_pts)
        return mask.to(torch.float) if to_float else mask


# BBox_Tool takes the world-to-smoke transform as its first argument, so the
# loaded `voxel_tran` is inverted before being handed over.
voxel_tran_inv = np.linalg.inv(voxel_tran)
# Bounding-box helper built with the default in_min / in_max box.
bbox_model = BBox_Tool(voxel_tran_inv, voxel_scale)

In [None]:
# Interpolation profile used by the hash-grid kernels: the identity, so the
# corner weights computed from it are plain linear weights.
@ti.func
def linear_step(t):
    return t


# Derivative of linear_step with respect to t: the constant 1.
@ti.func
def d_linear_step(t):
    return 1


# Copy every element of the external array `data` into `field` at the same
# index. Only the index range of `data` is iterated, so `field` may be larger.
@ti.kernel
def torch2ti(field: ti.template(), data: ti.types.ndarray()):
    for I in ti.grouped(data):
        field[I] = data[I]


# Copy `field` into the external array `data`, element by element, over the
# index range of `data`.
@ti.kernel
def ti2torch(field: ti.template(), data: ti.types.ndarray()):
    for I in ti.grouped(data):
        data[I] = field[I]


# Copy the gradient buffer of `field` (field.grad) into the external array
# `grad`, over the index range of `grad`.
@ti.kernel
def ti2torch_grad(field: ti.template(), grad: ti.types.ndarray()):
    for I in ti.grouped(grad):
        grad[I] = field.grad[I]


# Write the external array `grad` into the gradient buffer of `field`
# (field.grad), over the index range of `grad`.
@ti.kernel
def torch2ti_grad(field: ti.template(), grad: ti.types.ndarray()):
    for I in ti.grouped(grad):
        field.grad[I] = grad[I]


# Hash a 4-component integer grid position to a uint32: each component is
# multiplied by a per-dimension constant and the products are XOR-ed together.
# The first constant is 1, so the first coordinate enters the XOR unscaled.
@ti.func
def fast_hash(pos_grid_local):
    result = ti.uint32(0)
    primes = ti.math.uvec4(ti.uint32(1), ti.uint32(2654435761), ti.uint32(805459861), ti.uint32(3674653429))
    for i in ti.static(range(4)):
        result ^= ti.uint32(pos_grid_local[i]) * primes[i]
    return result


# Dense (collision-free) indexing for levels whose full grid fits in the table:
# ravel (i, j, k, t) to i + i_dim * j + (i_dim * j_dim) * k + (i_dim * j_dim * k_dim) * t
# where each *_dim is resolution + 1 (see the stride update below).
@ti.func
def under_hash(pos_grid_local, resolution):
    result = ti.uint32(0)
    stride = ti.uint32(1)
    for i in ti.static(range(4)):
        result += ti.uint32(pos_grid_local[i] * stride)
        stride *= resolution[i] + 1  # note the +1 here, because 256 x 256 grid actually has 257 x 257 entries
    return result


# Map a 4D grid corner to a slot index inside one level's table.
#   indicator == 1 -> dense raveled index (under_hash)
#   otherwise      -> hashed index (fast_hash)
# The result is wrapped into [0, map_size) with a modulo.
@ti.func
def grid_pos2hash_index(indicator, pos_grid_local, plane_res, map_size):
    hash_result = ti.uint32(0)
    if indicator == 1:
        hash_result = under_hash(pos_grid_local, plane_res)
    else:
        hash_result = fast_hash(pos_grid_local)

    return hash_result % map_size


# Forward pass of the multi-resolution 4D (x, y, z, t) hash encoding.
#
# For every (sample i, level) pair the kernel interpolates the two features
# stored at the 16 corners of the enclosing 4D grid cell and writes them to
# xyzts_embedding[i, 2 * level] and xyzts_embedding[i, 2 * level + 1].
#
#   xyzts                 : field read as xyzts[i, 0..3]
#   table                 : flat feature table, two consecutive entries per slot
#   hash_map_indicator    : per level, 1 -> dense indexing, else hashed
#   hash_map_sizes_field  : per level, number of slots
#   hash_map_shapes_field : per level, grid resolution along x, y, z, t
#   offsets               : per level, start of the level inside `table`
#   B, num_scales         : number of samples / levels to process
#
# NOTE(review): the caller rescales positions to [0, 1] before the call; for
# coordinates outside that range the float -> uint32 cast below happens before
# the clamp — confirm that out-of-range inputs cannot reach this kernel.
@ti.kernel
def hash_encode_kernel(
        xyzts: ti.template(),
        table: ti.template(),
        xyzts_embedding: ti.template(),
        hash_map_indicator: ti.template(),
        hash_map_sizes_field: ti.template(),
        hash_map_shapes_field: ti.template(),
        offsets: ti.template(),
        B: ti.i32,
        num_scales: ti.i32):
    ti.loop_config(block_dim=16)
    for i, level in ti.ndrange(B, num_scales):
        # Grid resolution of this level along each of the four axes.
        res_x = hash_map_shapes_field[level, 0]
        res_y = hash_map_shapes_field[level, 1]
        res_z = hash_map_shapes_field[level, 2]
        res_t = hash_map_shapes_field[level, 3]
        plane_res = ti.Vector([res_x, res_y, res_z, res_t])
        # Position expressed in grid units of this level.
        pos = ti.Vector([xyzts[i, 0], xyzts[i, 1], xyzts[i, 2], xyzts[i, 3]]) * plane_res

        # Lower corner of the enclosing cell, clamped so that corner + 1 stays
        # within [0, resolution].
        pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32)  # floor
        pos_grid_uint = ti.math.clamp(pos_grid_uint, 0, plane_res - 1)
        pos -= pos_grid_uint  # pos now represents frac
        pos = ti.math.clamp(pos, 0.0, 1.0)

        offset = offsets[level]

        indicator = hash_map_indicator[level]
        map_size = hash_map_sizes_field[level]

        local_feature_0 = 0.0
        local_feature_1 = 0.0

        # Visit the 16 corners of the 4D cell; bit d of idx selects the lower
        # (0) or upper (1) corner along dimension d.
        for idx in ti.static(range(16)):
            w = 1.
            pos_grid_local = ti.math.uvec4(0)

            for d in ti.static(range(4)):
                t = linear_step(pos[d])
                if (idx & (1 << d)) == 0:
                    pos_grid_local[d] = pos_grid_uint[d]
                    w *= 1 - t
                else:
                    pos_grid_local[d] = pos_grid_uint[d] + 1
                    w *= t

            index = grid_pos2hash_index(indicator, pos_grid_local, plane_res, map_size)
            index_table = offset + index * 2  # the flat index for the 1st entry
            index_table_int = ti.cast(index_table, ti.int32)
            # Accumulate the corner's two features weighted by w.
            local_feature_0 += w * table[index_table_int]
            local_feature_1 += w * table[index_table_int + 1]

        xyzts_embedding[i, level * 2] = local_feature_0
        xyzts_embedding[i, level * 2 + 1] = local_feature_1


# Hand-written backward pass of hash_encode_kernel.
#
# Given output_grad (gradient w.r.t. the embedding) it accumulates
#   table_grad[slot]  += w * output_grad            (gradient w.r.t. features)
#   xyzts_grad[i, d]  += feature * dw/dpos_d * res_d * output_grad
# for each of the 16 cell corners, using the same cell/corner/weight logic as
# the forward kernel. `xyzts_embedding` is accepted for signature parity with
# the forward kernel but is not read here.
#
# Both outputs are only accumulated into (+=); nothing in this kernel clears
# them, so the caller must zero them first.
@ti.kernel
def hash_encode_kernel_grad(
        xyzts: ti.template(),
        table: ti.template(),
        xyzts_embedding: ti.template(),
        hash_map_indicator: ti.template(),
        hash_map_sizes_field: ti.template(),
        hash_map_shapes_field: ti.template(),
        offsets: ti.template(),
        B: ti.i32,
        num_scales: ti.i32,
        xyzts_grad: ti.template(),
        table_grad: ti.template(),
        output_grad: ti.template()):
    # # # get hash table embedding

    ti.loop_config(block_dim=16)
    for i, level in ti.ndrange(B, num_scales):
        # Same cell lookup as in hash_encode_kernel.
        res_x = hash_map_shapes_field[level, 0]
        res_y = hash_map_shapes_field[level, 1]
        res_z = hash_map_shapes_field[level, 2]
        res_t = hash_map_shapes_field[level, 3]
        plane_res = ti.Vector([res_x, res_y, res_z, res_t])
        pos = ti.Vector([xyzts[i, 0], xyzts[i, 1], xyzts[i, 2], xyzts[i, 3]]) * plane_res

        pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32)  # floor
        pos_grid_uint = ti.math.clamp(pos_grid_uint, 0, plane_res - 1)
        pos -= pos_grid_uint  # pos now represents frac
        pos = ti.math.clamp(pos, 0.0, 1.0)

        offset = offsets[level]

        indicator = hash_map_indicator[level]
        map_size = hash_map_sizes_field[level]

        for idx in ti.static(range(16)):
            w = 1.
            pos_grid_local = ti.math.uvec4(0)
            # dw[d] holds the derivative of the d-th 1D weight factor:
            # -dt for the lower corner (factor 1 - t), +dt for the upper (factor t).
            dw = ti.Vector([0., 0., 0., 0.])
            # prods = ti.Vector([0., 0., 0.,0.])
            for d in ti.static(range(4)):
                t = linear_step(pos[d])
                dt = d_linear_step(pos[d])
                if (idx & (1 << d)) == 0:
                    pos_grid_local[d] = pos_grid_uint[d]
                    w *= 1 - t
                    dw[d] = -dt

                else:
                    pos_grid_local[d] = pos_grid_uint[d] + 1
                    w *= t
                    dw[d] = dt

            index = grid_pos2hash_index(indicator, pos_grid_local, plane_res, map_size)
            index_table = offset + index * 2  # the flat index for the 1st entry
            index_table_int = ti.cast(index_table, ti.int32)
            # Gradient w.r.t. the two table entries of this corner.
            # NOTE(review): several (i, level, corner) iterations can hit the
            # same slot; this relies on Taichi handling `+=` on fields safely
            # in parallel loops — confirm against the Taichi documentation.
            table_grad[index_table_int] += w * output_grad[i, 2 * level]
            table_grad[index_table_int + 1] += w * output_grad[i, 2 * level + 1]
            for d in ti.static(range(4)):
                # Earlier formulations of `prod`, kept for reference:
                # eps = 1e-15
                # prod = w / ((linear_step(pos[d]) if idx & (1 << d) > 0 else 1 - linear_step(pos[d])) + eps)
                # prod=1.0
                # for k in range(4):
                #     if k == d:
                #         prod *= dw[k]
                #     else:
                #         prod *= 1- linear_step(pos[k]) if (idx & (1 << k) == 0) else linear_step(pos[k])
                # prod = d(w)/d(pos[d]): the derivative of factor d times the
                # (undifferentiated) weight factors of the other three dimensions.
                prod = dw[d] * (
                    linear_step(pos[(d + 1) % 4]) if (idx & (1 << ((d + 1) % 4)) > 0) else 1 - linear_step(
                        pos[(d + 1) % 4])
                ) * (
                           linear_step(pos[(d + 2) % 4]) if (idx & (1 << ((d + 2) % 4)) > 0) else 1 - linear_step(
                               pos[(d + 2) % 4])
                       ) * (
                           linear_step(pos[(d + 3) % 4]) if (idx & (1 << ((d + 3) % 4)) > 0) else 1 - linear_step(
                               pos[(d + 3) % 4])
                       )
                # Chain rule: pos = xyzts * plane_res, hence the plane_res[d] factor.
                xyzts_grad[i, d] += table[index_table_int] * prod * plane_res[d] * output_grad[i, 2 * level]
                xyzts_grad[i, d] += table[index_table_int + 1] * prod * plane_res[d] * output_grad[i, 2 * level + 1]


class HashEncoderHyFluid(torch.nn.Module):
    """Multi-resolution 4D (x, y, z, t) hash-grid encoder backed by Taichi kernels.

    The learnable feature table lives in ``self.hash_table`` (a flat
    ``torch.nn.Parameter``). ``forward`` copies positions and the table into
    Taichi fields, runs ``hash_encode_kernel`` and returns an embedding with
    two features per level. Gradients are produced by
    ``hash_encode_kernel_grad``; that gradient computation is itself wrapped
    in an autograd Function whose backward calls
    ``hash_encode_kernel_grad.grad``, so second-order derivatives are wired up.

    Args:
        min_res: per-axis (x, y, z, t) resolution of the coarsest level.
        max_res: per-axis (x, y, z, t) resolution of the finest level.
        num_scales: number of resolution levels.
        max_params: upper bound on the number of slots per level.
        features_per_level: features stored per slot; used to size the table
            and the Taichi output fields.
        max_num_queries: capacity (rows) of the pre-allocated input/output
            fields and buffers.

    NOTE(review): the kernels and the ``output_embedding`` buffer hard-code
    two features per slot (``index * 2``, ``num_scales * 2``), so values of
    ``features_per_level`` other than 2 do not look supported — confirm.
    """

    def __init__(
            self,
            min_res: np.array,
            max_res: np.array,
            num_scales: int,
            max_params=2 ** 19,
            features_per_level: int = 2,
            max_num_queries=10000000,
    ):
        super().__init__()
        # Per-axis geometric growth factor so that level 0 has min_res and
        # level num_scales - 1 has max_res.
        b = np.exp((np.log(max_res) - np.log(min_res)) / (num_scales - 1))

        hash_map_shapes = []
        hash_map_sizes = []
        hash_map_indicator = []
        offsets = []
        total_hash_size = 0
        for scale_i in range(num_scales):
            res = np.ceil(min_res * np.power(b, scale_i)).astype(int)
            # Number of grid vertices of the dense 4D grid at this level.
            params_in_level_raw = np.int64(res[0] + 1) * np.int64(res[1] + 1) * np.int64(res[2] + 1) * np.int64(
                res[3] + 1)
            # Round up to a multiple of 8, then cap at max_params.
            params_in_level = int(params_in_level_raw) if params_in_level_raw % 8 == 0 else int(
                (params_in_level_raw + 8 - 1) / 8) * 8
            params_in_level = min(max_params, params_in_level)
            hash_map_shapes.append(res)
            hash_map_sizes.append(params_in_level)
            # 1 -> the dense grid fits, use collision-free indexing; 0 -> hash.
            hash_map_indicator.append(1 if params_in_level_raw <= params_in_level else 0)
            # Start of this level inside the flat table.
            offsets.append(total_hash_size)
            total_hash_size += params_in_level * features_per_level

        ####################################################################################################
        # Per-level metadata mirrored into Taichi fields for use inside the kernels.
        self.hash_map_shapes_field = ti.field(dtype=ti.i32, shape=(num_scales, 4))
        self.hash_map_shapes_field.from_numpy(np.array(hash_map_shapes))

        self.hash_map_sizes_field = ti.field(dtype=ti.i32, shape=(num_scales,))
        self.hash_map_sizes_field.from_numpy(np.array(hash_map_sizes))

        self.hash_map_indicator_field = ti.field(dtype=ti.i32, shape=(num_scales,))
        self.hash_map_indicator_field.from_numpy(np.array(hash_map_indicator))

        self.offsets_fields = ti.field(ti.i32, shape=(num_scales,))
        self.offsets_fields.from_numpy(np.array(offsets))

        # Learnable feature table, initialised uniformly in [-1e-4, 1e-4).
        self.hash_table = torch.nn.Parameter(
            (torch.rand(size=(total_hash_size,), dtype=torch.float32) * 2.0 - 1.0) * 1e-4, requires_grad=True)

        # Taichi-side staging fields. The *_grad fields receive the first-order
        # gradients written by hash_encode_kernel_grad; their own .grad buffers
        # are used when that kernel is differentiated again.
        self.parameter_fields = ti.field(dtype=ti.f32, shape=(total_hash_size,), needs_grad=True)
        self.parameter_fields_grad = ti.field(dtype=ti.f32, shape=(total_hash_size,), needs_grad=True)

        self.output_fields = ti.field(dtype=ti.f32, shape=(max_num_queries, num_scales * features_per_level),
                                      needs_grad=True)
        self.output_grad = ti.field(dtype=ti.f32, shape=(max_num_queries, num_scales * features_per_level),
                                    needs_grad=True)

        self.input_fields = ti.field(dtype=ti.f32, shape=(max_num_queries, 4), needs_grad=True)
        self.input_fields_grad = ti.field(dtype=ti.f32, shape=(max_num_queries, 4), needs_grad=True)

        self.num_scales = num_scales
        self.features_per_level = features_per_level
        ####################################################################################################

        # Torch-side buffers that receive results copied back from Taichi.
        # They are non-persistent, i.e. excluded from the state_dict.
        self.register_buffer('hash_grad', torch.zeros(total_hash_size, dtype=torch.float32), persistent=False)
        self.register_buffer('hash_grad2', torch.zeros(total_hash_size, dtype=torch.float32), persistent=False)
        self.register_buffer('input_grad', torch.zeros(max_num_queries, 4, dtype=torch.float32), persistent=False)
        self.register_buffer('input_grad2', torch.zeros(max_num_queries, 4, dtype=torch.float32), persistent=False)
        self.register_buffer('output_embedding', torch.zeros(max_num_queries, num_scales * 2, dtype=torch.float32),
                             persistent=False)

        ####################################################################################################
        # The autograd Functions below are defined inside __init__ so that they
        # can close over `self` and reach the fields and buffers created above.
        class ModuleFunction(torch.autograd.Function):
            @staticmethod
            @torch.amp.custom_fwd(cast_inputs=torch.float32,
                                  device_type='cuda')  # @custom_fwd(cast_inputs=torch.float32)
            def forward(ctx, input_pos, params):
                # NOTE(review): the result is written into a slice of the shared
                # `output_embedding` buffer; presumably the returned tensor
                # aliases that buffer, so a later forward call would overwrite
                # it — confirm callers consume it before encoding again.
                output_embedding = self.output_embedding[:input_pos.shape[0]].contiguous()
                torch2ti(self.input_fields, input_pos.contiguous())
                torch2ti(self.parameter_fields, params.contiguous())

                hash_encode_kernel(
                    self.input_fields,
                    self.parameter_fields,
                    self.output_fields,
                    self.hash_map_indicator_field,
                    self.hash_map_sizes_field,
                    self.hash_map_shapes_field,
                    self.offsets_fields,
                    input_pos.shape[0],
                    self.num_scales,
                )

                ti2torch(self.output_fields, output_embedding)
                ctx.save_for_backward(input_pos, params)
                return output_embedding

            @staticmethod
            @torch.amp.custom_bwd(device_type='cuda')  # @custom_bwd
            def backward(ctx, doutput):
                # Clear every accumulator the gradient kernel adds into.
                self.input_fields.grad.fill(0.)
                self.input_fields_grad.fill(0.)
                self.parameter_fields.grad.fill(0.)
                self.parameter_fields_grad.fill(0.)

                input_pos, params = ctx.saved_tensors
                # Delegate to a second autograd Function so that the gradient
                # computation is recorded in the graph as well.
                return self.module_function_grad.apply(input_pos, params, doutput)

        class ModuleFunctionGrad(torch.autograd.Function):
            @staticmethod
            @torch.amp.custom_fwd(cast_inputs=torch.float32,
                                  device_type='cuda')  # @custom_fwd(cast_inputs=torch.float32)
            def forward(ctx, input_pos, params, doutput):
                # Computes (d loss / d input_pos, d loss / d params) from doutput.
                torch2ti(self.input_fields, input_pos.contiguous())
                torch2ti(self.parameter_fields, params.contiguous())
                torch2ti(self.output_grad, doutput.contiguous())

                hash_encode_kernel_grad(
                    self.input_fields,
                    self.parameter_fields,
                    self.output_fields,
                    self.hash_map_indicator_field,
                    self.hash_map_sizes_field,
                    self.hash_map_shapes_field,
                    self.offsets_fields,
                    doutput.shape[0],
                    self.num_scales,
                    self.input_fields_grad,
                    self.parameter_fields_grad,
                    self.output_grad
                )

                # NOTE(review): the returned gradients are the shared buffers
                # `input_grad` / `hash_grad` (a slice for the former), not
                # copies — confirm nothing holds on to them across calls.
                ti2torch(self.input_fields_grad, self.input_grad.contiguous())
                ti2torch(self.parameter_fields_grad, self.hash_grad.contiguous())
                return self.input_grad[:doutput.shape[0]], self.hash_grad

            @staticmethod
            @torch.amp.custom_bwd(device_type='cuda')  # @custom_bwd
            def backward(ctx, d_input_grad, d_hash_grad):
                # Second-order pass: seed the .grad buffers of the first-order
                # gradient fields and run Taichi's generated adjoint of
                # hash_encode_kernel_grad.
                self.parameter_fields.grad.fill(0.)
                self.input_fields.grad.fill(0.)
                torch2ti_grad(self.input_fields_grad, d_input_grad.contiguous())
                torch2ti_grad(self.parameter_fields_grad, d_hash_grad.contiguous())

                hash_encode_kernel_grad.grad(
                    self.input_fields,
                    self.parameter_fields,
                    self.output_fields,
                    self.hash_map_indicator_field,
                    self.hash_map_sizes_field,
                    self.hash_map_shapes_field,
                    self.offsets_fields,
                    d_input_grad.shape[0],
                    self.num_scales,
                    self.input_fields_grad,
                    self.parameter_fields_grad,
                    self.output_grad
                )

                ti2torch_grad(self.input_fields, self.input_grad2.contiguous()[:d_input_grad.shape[0]])
                ti2torch_grad(self.parameter_fields, self.hash_grad2.contiguous())
                # set_trace(term_size=(120,30))
                # No gradient is returned for the third forward input (doutput).
                return self.input_grad2[:d_input_grad.shape[0]], self.hash_grad2, None

        self.module_function = ModuleFunction
        self.module_function_grad = ModuleFunctionGrad
        ####################################################################################################

    def forward(self, positions):
        """Encode positions and return the hash-grid embedding.

        Args:
            positions: (N, 4) tensor normalized to [-1, 1]; it is rescaled to
                [0, 1] before being handed to the kernels.

        Returns:
            The tensor produced by the encoding Function, with one row per
            position and ``num_scales * 2`` columns.
        """
        # positions: (N, 4), normalized to [-1, 1]
        positions = positions * 0.5 + 0.5
        return self.module_function.apply(positions, self.hash_table)


# Initialise Taichi on the CUDA backend before any field is allocated (the
# encoder constructor below creates its Taichi fields).
ti.init(ti.cuda)
# Finest / coarsest grid resolution per axis, ordered (x, y, z, t).
max_res = np.array([256, 256, 256, 128])
min_res = np.array([16, 16, 16, 16])
embed_fn = HashEncoderHyFluid(min_res=min_res, max_res=max_res, num_scales=16, max_params=2 ** 19)
# Width of the embedding fed to the MLP: two features per level.
input_ch = embed_fn.num_scales * 2
embedding_params = list(embed_fn.parameters())
print(f'embedding_params: {len(embedding_params)}, {embedding_params[0].shape}')

In [None]:
# Small NeRF for Hash embeddings
class NeRFSmall(torch.nn.Module):
    """Small bias-free MLP that maps an embedding to a single density value.

    The density branch (``sigma_net``) has ``num_layers`` linear layers without
    bias; the last one outputs a single channel. A ReLU follows every layer,
    including the last, so the returned density is non-negative.

    ``color_net`` is built as a plain Python list of linear layers. It is not
    used by ``forward`` and, being a list rather than a ``ModuleList``, its
    layers are not registered as sub-modules. ``geo_feat_dim`` is stored but
    not otherwise used by this class.

    Args:
        num_layers: number of layers in the density branch.
        hidden_dim: width of the hidden density layers.
        geo_feat_dim: stored on the instance only.
        num_layers_color: number of layers built for ``color_net``.
        hidden_dim_color: width of the hidden ``color_net`` layers.
        input_ch: number of input channels of the first density layer.
    """

    def __init__(self,
                 num_layers=3,
                 hidden_dim=64,
                 geo_feat_dim=15,
                 num_layers_color=2,
                 hidden_dim_color=16,
                 input_ch=3,
                 ):
        super().__init__()

        self.input_ch = input_ch
        self.rgb = torch.nn.Parameter(torch.tensor([0.0]))

        # Density branch configuration.
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.geo_feat_dim = geo_feat_dim

        density_layers = []
        for depth in range(num_layers):
            fan_in = self.input_ch if depth == 0 else hidden_dim
            # The final layer emits one channel: the density.
            fan_out = 1 if depth == num_layers - 1 else hidden_dim
            density_layers.append(torch.nn.Linear(fan_in, fan_out, bias=False))
        self.sigma_net = torch.nn.ModuleList(density_layers)

        # Colour branch: 1 -> hidden_dim_color -> ... -> 1, kept in a plain list.
        self.color_net = [
            torch.nn.Linear(
                1 if depth == 0 else hidden_dim_color,
                1 if depth == num_layers_color - 1 else hidden_dim_color,
                bias=True,
            )
            for depth in range(num_layers_color)
        ]

    def forward(self, x):
        """Return the density predicted for the embedding ``x``."""
        h = x
        for depth in range(self.num_layers):
            h = torch.nn.functional.relu(self.sigma_net[depth](h), inplace=True)
        return h


# Density MLP fed by the hash-grid embedding (input_ch = encoder output width).
model = NeRFSmall(num_layers=2,
                  hidden_dim=64,
                  geo_feat_dim=15,
                  num_layers_color=2,
                  hidden_dim_color=16,
                  input_ch=input_ch).to(device)
# Trainable parameters of the MLP; the print below expects at least three.
grad_vars = list(model.parameters())
print(f'grad_vars: {len(grad_vars)}, {grad_vars[0].shape}, {grad_vars[1].shape}, {grad_vars[2].shape}')

# Full query: positions -> hash embedding -> density.
network_query_fn = lambda x: model(embed_fn(x))

In [None]:
class RAdam(Optimizer):
    """Rectified Adam optimizer.

    Keeps Adam's first and second moment estimates and rescales the step by a
    rectification term computed from the step count and ``beta2``. While the
    approximated SMA length is below 5 the adaptive update is skipped; in that
    phase the parameters are either left untouched or, when
    ``degenerated_to_sgd`` is True, updated with a momentum-only step.

    Args:
        params: iterable of parameters or of parameter-group dicts.
        lr: learning rate (must be >= 0).
        betas: coefficients of the running averages, each in [0, 1).
        eps: term added to the denominator (must be >= 0).
        weight_decay: decay coefficient applied to the parameters.
        degenerated_to_sgd: use a momentum step while the rectification term
            is unavailable.

    Raises:
        ValueError: if ``lr``, ``eps`` or one of ``betas`` is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False):
        if not lr >= 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not eps >= 0.0:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        for position in (0, 1):
            if not (betas[position] >= 0.0 and betas[position] < 1.0):
                raise ValueError("Invalid beta parameter at index {}: {}".format(position, betas[position]))

        self.degenerated_to_sgd = degenerated_to_sgd

        # Groups whose betas differ from the defaults get a private cache of
        # (step, N_sma, step_size) triples; all other groups share the default.
        grouped = isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict)
        if grouped:
            for group_spec in params:
                if 'betas' not in group_spec:
                    continue
                own_betas = group_spec['betas']
                if own_betas[0] != betas[0] or own_betas[1] != betas[1]:
                    group_spec['buffer'] = [[None, None, None] for _ in range(10)]

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        buffer=[[None, None, None] for _ in range(10)])
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    def _rectification(self, step, beta1, beta2):
        """Return ``(N_sma, step_size)`` for the given step count and betas."""
        beta2_t = beta2 ** step
        N_sma_max = 2 / (1 - beta2) - 1
        N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)

        # more conservative since it's an approximated value
        if N_sma >= 5:
            step_size = math.sqrt(
                (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                        N_sma_max - 2)) / (1 - beta1 ** step)
        elif self.degenerated_to_sgd:
            step_size = 1.0 / (1 - beta1 ** step)
        else:
            # Negative marker: no update is applied for this step.
            step_size = -1
        return N_sma, step_size

    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()
                state = self.state[p]

                if not state:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                first_moment = state['exp_avg']
                second_moment = state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Exponential moving averages of the squared gradient and the gradient.
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                # The rectification term only depends on the step count and the
                # betas, so it is cached in a 10-slot ring keyed by the step.
                slot = group['buffer'][int(state['step'] % 10)]
                if state['step'] == slot[0]:
                    N_sma, step_size = slot[1], slot[2]
                else:
                    slot[0] = state['step']
                    N_sma, step_size = self._rectification(state['step'], beta1, beta2)
                    slot[1] = N_sma
                    slot[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    adaptive = True
                elif step_size > 0:
                    adaptive = False
                else:
                    continue

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                if adaptive:
                    denom = second_moment.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(first_moment, denom, value=-step_size * group['lr'])
                else:
                    p_data_fp32.add_(first_moment, alpha=-step_size * group['lr'])
                p.data.copy_(p_data_fp32)

        return loss


# Two parameter groups: the MLP weights (with weight decay 1e-6) and the hash
# table (with a smaller eps of 1e-15). Both share lr and betas.
optimizer = RAdam([
    {'params': grad_vars, 'weight_decay': 1e-6},
    {'params': embedding_params, 'eps': 1e-15}
], lr=0.01, betas=(0.9, 0.99))
# From here on `grad_vars` lists every trainable parameter (MLP + hash table).
# NOTE(review): this extends, in place, the list that was handed to the
# optimizer above; presumably the optimizer copied it — confirm the first
# group did not silently gain the embedding parameters.
grad_vars += list(embedding_params)

In [None]:
# Report GPU memory usage after the model and optimizer have been created.
# "Cached Memory" is the amount reported by torch.cuda.memory_reserved.
current_device = torch.cuda.current_device()
print(f"Device: {torch.cuda.get_device_name(current_device)}")
print(f"Allocated Memory: {torch.cuda.memory_allocated(current_device) / 1024 ** 2:.2f} MB")
print(f"Cached Memory: {torch.cuda.memory_reserved(current_device) / 1024 ** 2:.2f} MB")

In [None]:
# Move the render poses loaded from the archive onto the GPU as a tensor
# (rebinds the name `render_poses`, which previously held the loaded array).
render_poses = torch.tensor(render_poses, device=device)

In [None]:
# Report GPU memory usage again, after the render poses were moved to the GPU.
current_device = torch.cuda.current_device()
print(f"Device: {torch.cuda.get_device_name(current_device)}")
print(f"Allocated Memory: {torch.cuda.memory_allocated(current_device) / 1024 ** 2:.2f} MB")
print(f"Cached Memory: {torch.cuda.memory_reserved(current_device) / 1024 ** 2:.2f} MB")

In [None]:
def get_rays_np_continuous(H, W, K, c2w):
    """Generate one jittered camera ray per pixel (NumPy version).

    Every pixel coordinate is shifted by an independent uniform offset in
    [0, 1) and clipped to the image, so rays pass through continuous
    sub-pixel locations.

    Args:
        H: image height in pixels.
        W: image width in pixels.
        K: camera intrinsics; entries [0][0], [1][1], [0][2] and [1][2] are read.
        c2w: camera-to-world matrix; its first three rows and the rotation
            block / last column are read.

    Returns:
        Tuple ``(rays_o, rays_d, i, j)``: ray origins and (unnormalized)
        directions of shape (H, W, 3), and the jittered column / row
        coordinates of shape (H, W).
    """
    cols = np.arange(W, dtype=np.float32)
    rows = np.arange(H, dtype=np.float32)
    i, j = np.meshgrid(cols, rows, indexing='xy')

    # Draw the column jitter first, then the row jitter.
    jitter_i = np.random.uniform(0, 1, size=(H, W))
    jitter_j = np.random.uniform(0, 1, size=(H, W))
    i = np.clip(i + jitter_i, 0, W - 1)
    j = np.clip(j + jitter_j, 0, H - 1)

    # Camera-frame directions: the y component is flipped and z is fixed to -1.
    x_cam = (i - K[0][2]) / K[0][0]
    y_cam = -(j - K[1][2]) / K[1][1]
    z_cam = -np.ones_like(i)
    dirs = np.stack([x_cam, y_cam, z_cam], -1)

    # Rotate into the world frame: row-wise dot product, i.e. c2w[:3, :3] @ dir.
    rotation = c2w[:3, :3]
    rays_d = (dirs[..., np.newaxis, :] * rotation).sum(-1)
    # Every ray starts at the camera centre (last column of c2w).
    rays_o = np.broadcast_to(c2w[:3, -1], rays_d.shape)
    return rays_o, rays_d, i, j


def sample_bilinear(img, xy):
    """
    Bilinearly interpolate a stack of images at continuous pixel locations.

    :param img: (T, V, H, W, 3) images
    :param xy: (V, 2, H, W) sample locations; channel 0 is the column (u),
        channel 1 is the row (v). Locations are clipped to the image.
    :return: (T, V, H, W, 3) interpolated values
    """
    _, V, H, W, _ = img.shape

    u = np.clip(xy[:, 0], 0, W - 1)
    v = np.clip(xy[:, 1], 0, H - 1)

    # Integer neighbours around each location.
    u_lo = np.floor(u).astype(int)
    v_lo = np.floor(v).astype(int)
    u_hi = np.ceil(u).astype(int)
    v_hi = np.ceil(v).astype(int)

    # Fractional parts, broadcast over the time and channel axes.
    wu = (u - u_lo)[None, ..., None]
    wv = (v - v_lo)[None, ..., None]

    view_idx = np.arange(V)[:, None, None]

    def gather(row_idx, col_idx):
        # Pick, for every view, the pixel at (row_idx, col_idx) in all frames.
        return img[:, view_idx, row_idx, col_idx]

    # Blend along u on the two neighbouring rows, then along v.
    row_lo = (1 - wu) * gather(v_lo, u_lo) + wu * gather(v_lo, u_hi)
    row_hi = (1 - wu) * gather(v_hi, u_lo) + wu * gather(v_hi, u_hi)

    return (1 - wv) * row_lo + wv * row_hi


import time

# Time the ray generation below.
start = time.time()
# hwf holds (height, width, focal length); H and W are cast to int.
H, W, focal = hwf
H, W = int(H), int(W)
hwf = [H, W, focal]
# Pinhole intrinsics with the principal point at the image centre.
K = np.array([
    [focal, 0, 0.5 * W],
    [0, focal, 0.5 * H],
    [0, 0, 1]
])
# anti-aliasing: rays go through randomly jittered sub-pixel positions, and the
# training images are resampled at those same positions further down.
rays = []
ij = []
for p in poses_train[:, :3, :4]:
    r_o, r_d, i_, j_ = get_rays_np_continuous(H, W, K, p)
    rays.append([r_o, r_d])
    ij.append([i_, j_])
rays = np.stack(rays, 0)  # [V, ro+rd=2, H, W, 3]
ij = np.stack(ij, 0)  # [V, 2, H, W]
print(f'Elapsed time: {time.time() - start:.2f} s')
# Time the bilinear resampling of the training images.
start = time.time()
images_train = sample_bilinear(images_train_, ij)  # [T, V, H, W, 3]
print(f'Elapsed time: {time.time() - start:.2f} s')
rays = np.transpose(rays, [0, 2, 3, 1, 4])  # [V, H, W, ro+rd=2, 3]
rays = np.reshape(rays, [-1, 2, 3])  # [VHW, ro+rd=2, 3]
rays = rays.astype(np.float32)

memory_rays = rays.nbytes / (1024 ** 2)  # convert to MB
memory_images_train = images_train.nbytes / (1024 ** 2)  # convert to MB
print(f'rays: {memory_rays:.2f} MB, shape: {rays.shape}')
print(f'images_train: {memory_images_train:.2f} MB, shape: {images_train.shape}')

In [None]:
# Move the resampled images to the GPU and merge the view/height/width axes so
# that pixels are indexed the same way as the flattened rays.
images_train = torch.tensor(images_train, device=device).flatten(start_dim=1, end_dim=3)  # [T, VHW, 3]
# T: size of the first axis, S: number of pixels per frame summed over views.
T, S, _ = images_train.shape
rays = torch.tensor(rays, device=device)
# Random permutation of the ray indices.
ray_idxs = torch.randperm(rays.shape[0])

In [None]:
# Report GPU memory usage after the training images and rays were uploaded.
current_device = torch.cuda.current_device()
print(f"Device: {torch.cuda.get_device_name(current_device)}")
print(f"Allocated Memory: {torch.cuda.memory_allocated(current_device) / 1024 ** 2:.2f} MB")
print(f"Cached Memory: {torch.cuda.memory_reserved(current_device) / 1024 ** 2:.2f} MB")

In [None]:
class AverageMeter(object):
    """Track the latest value and the running mean of a scalar metric.

    ``val``, ``sum``, ``count`` and ``avg`` describe the window since the last
    ``reset()``; ``tot_count`` keeps accumulating across resets.
    """
    # Class-level defaults; every instance overrides them in __init__/reset.
    val = 0
    avg = 0
    sum = 0
    count = 0
    tot_count = 0

    def __init__(self):
        self.reset()
        self.tot_count = 0

    def reset(self):
        """Clear the current window; ``tot_count`` is left untouched."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations.

        Args:
            val: the new value; anything accepted by ``float()``.
            n: weight (number of observations) that ``val`` stands for.
        """
        val = float(val)
        self.val = val
        self.sum += val * n
        self.count += n
        self.tot_count += n
        # Guard against an empty window (e.g. update(x, n=0) right after a
        # reset), which previously raised ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0.0


# ---- Mutable training-loop state ----
i_batch = 0  # read offset into ray_idxs
start = 1  # first iteration number
global_step = start
loss_meter = AverageMeter()
psnr_meter = AverageMeter()
loss_list, psnr_list = [], []
resample_rays = False  # set once an epoch over all rays is finished

# ---- Sampling sizes ----
N_rand = 256  # rays per training batch
N_time = 1  # time indices drawn per iteration
N_samples = 192  # depth samples per ray


def img2mse(x, y):
    """Return the mean squared difference of two tensors."""
    return torch.mean((x - y) ** 2)


def mse2psnr(x):
    """Turn an MSE tensor into PSNR in dB, i.e. -10 * log10(x)."""
    return -10. * torch.log(x) / torch.log(torch.Tensor([10.]))


def to8b(x):
    """Clip an array to [0, 1] and quantise it to uint8 (0..255)."""
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)


# ---- Optimiser schedule ----
lrate = 0.01  # initial learning rate
lrate_decay = 100000  # number of steps used as the decay period

# ---- Output location ----
basedir = './logs'
expname = 'exp_real/density_256_128'
os.makedirs(os.path.join(basedir, expname), exist_ok=True)

In [None]:
def get_rays(directions, c2w):
    """
    Turn camera-space ray directions into world-space rays for one image.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems

    Inputs:
        directions: (H, W, 3) precomputed ray directions in camera coordinate
        c2w: (3, 4) transformation matrix from camera coordinate to world coordinate

    Outputs:
        rays_o: (H*W, 3), the origin of the rays in world coordinate
        rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
    """
    rotation = c2w[:3, :3]
    camera_center = c2w[:3, -1]

    # Rotate into world space, then rescale every direction to unit length.
    world_dirs = torch.matmul(directions, rotation.T)  # (H, W, 3)
    world_dirs = world_dirs / world_dirs.norm(dim=-1, keepdim=True)

    # Every ray of the image starts at the camera centre.
    world_origins = camera_center.expand(world_dirs.shape)  # (H, W, 3)

    return world_origins.view(-1, 3), world_dirs.view(-1, 3)


def raw2outputs(raw, z_vals, rays_d):
    """Composite per-sample densities along each ray into image-space maps.

    Colour is not predicted per sample: one RGB value is derived from the
    global ``model.rgb`` and shared by every sample.

    Args:
        raw: [num_rays, num_samples, C]. Network output; only the last channel
            is read, as the density.
        z_vals: [num_rays, num_samples]. Sample positions along each ray.
        rays_d: [num_rays, 3]. Direction of each ray.
    Returns:
        rgb_map: [num_rays, 3]. Weighted colour of each ray.
        disp_map: [num_rays]. Inverse of the (unmasked) depth, floored at 1e-10.
        acc_map: [num_rays]. Sum of weights along each ray.
        weights: [num_rays, num_samples]. Weight of each sample.
        depth_map: [num_rays]. Weighted depth, zeroed where acc_map < 0.1.
    """
    # Distance between neighbouring samples; the last sample gets a huge
    # interval so that it absorbs whatever transmittance is left.
    deltas = z_vals[..., 1:] - z_vals[..., :-1]
    tail = torch.Tensor([1e10]).expand(deltas[..., :1].shape)
    deltas = torch.cat([deltas, tail], -1)  # [N_rays, N_samples]
    # Convert from parametric distance to world distance.
    deltas = deltas * torch.norm(rays_d[..., None, :], dim=-1)

    # Single colour in [0.2, 1.0], broadcast to three channels.
    base_rgb = torch.ones(3) * (0.6 + torch.tanh(model.rgb) * 0.4)

    # Opacity of each sample from its rectified density.
    density = torch.nn.functional.relu(raw[..., -1])
    alpha = 1. - torch.exp(-density * deltas)  # [N_rays, N_samples]

    # Transmittance up to (but excluding) each sample.
    leading_ones = torch.ones((alpha.shape[0], 1))
    transmittance = torch.cumprod(
        torch.cat([leading_ones, 1. - alpha + 1e-10], -1), -1)[:, :-1]
    weights = alpha * transmittance  # [N_rays, N_samples]

    rgb_map = (weights[..., None] * base_rgb).sum(-2)  # [N_rays, 3]
    acc_map = weights.sum(-1)

    depth_map = (weights * z_vals).sum(-1) / (acc_map + 1e-10)
    # Disparity is taken before the low-opacity rays are zeroed below.
    disp_map = 1. / torch.maximum(1e-10 * torch.ones_like(depth_map), depth_map)
    depth_map[acc_map < 1e-1] = 0.

    return rgb_map, disp_map, acc_map, weights, depth_map


def render_rays(ray_batch,
                retraw=False,
                perturb=0.):
    """Render a batch of rays by sampling the density network along them.

    Args:
      ray_batch: [N_rays, C] tensor. Columns 0:3 are read as ray origins, 3:6
        as ray directions, 6:8 as the (near, far) bounds and the last column
        as the time step of each ray.
      retraw: bool. If True, also return the raw per-sample network output.
      perturb: float. If > 0, each sample is jittered uniformly inside its
        depth interval (stratified sampling).
    Returns:
      dict with 'rgb_map', 'depth_map' and 'acc_map' as produced by
      raw2outputs, plus 'raw' ([N_rays, N_samples, 1]) when retraw is set.
    """
    n_rays = ray_batch.shape[0]
    origins = ray_batch[:, 0:3]  # [N_rays, 3]
    directions = ray_batch[:, 3:6]  # [N_rays, 3]
    ray_time = ray_batch[:, -1]
    bounds = ray_batch[..., 6:8].reshape(-1, 1, 2)
    z_near, z_far = bounds[..., 0], bounds[..., 1]  # [N_rays, 1] each

    # Evenly spaced depths between the near and far bound of every ray.
    fractions = torch.linspace(0., 1., steps=N_samples)
    z_vals = z_near * (1. - fractions) + z_far * fractions
    z_vals = z_vals.expand([n_rays, N_samples])

    if perturb > 0.:
        # Interval edges around each sample, then one uniform draw per interval.
        midpoints = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([midpoints, z_vals[..., -1:]], -1)
        lower = torch.cat([z_vals[..., :1], midpoints], -1)
        z_vals = lower + (upper - lower) * torch.rand(z_vals.shape)

    # 3D sample positions with the ray's time step appended as a 4th coordinate.
    positions = origins[..., None, :] + directions[..., None, :] * z_vals[..., :, None]  # [N_rays, N_samples, 3]
    time_column = ray_time[..., None, None].expand(-1, positions.shape[1], -1)
    samples_flat = torch.cat([positions, time_column], -1).reshape(-1, 4)

    out_dim = 1
    raw_flat = torch.zeros([n_rays * N_samples, out_dim])

    # Only samples inside the bounding box are sent to the network; the rest
    # keep a raw value of zero.
    inside = bbox_model.insideMask(samples_flat[..., :3], to_float=False)
    if inside.sum() == 0:
        inside[0] = True  # in case zero rays are inside the bbox
    raw_flat[inside] = network_query_fn(samples_flat[inside])

    raw = raw_flat.reshape(n_rays, N_samples, out_dim)
    rgb_map, _disp_map, acc_map, _weights, depth_map = raw2outputs(raw, z_vals, directions)

    result = {'rgb_map': rgb_map, 'depth_map': depth_map, 'acc_map': acc_map}
    if retraw:
        result['raw'] = raw
    return result


def batchify_rays(rays_flat, chunk=1024 * 64):
    """Render rays chunk by chunk to bound peak memory, then merge the results.

    Args:
        rays_flat: tensor whose first dimension indexes rays.
        chunk: maximum number of rays passed to render_rays at once.
    Returns:
        dict mapping every key returned by render_rays to the concatenation
        (along dim 0) of the per-chunk values; empty if there are no rays.
    """
    pieces = {}
    total = rays_flat.shape[0]
    for lo in range(0, total, chunk):
        partial = render_rays(rays_flat[lo:lo + chunk])
        for key, value in partial.items():
            pieces.setdefault(key, []).append(value)

    return {key: torch.cat(parts, 0) for key, parts in pieces.items()}


def render(H, W, K, chunk=1024 * 64, rays=None, c2w=None, near=0., far=1., time_step=None):
    """Render rays (or a full image for a camera pose) at one or more time steps.

    Args:
        H, W, K: image size and intrinsics; only used when c2w is given.
        chunk: number of rays rendered per minibatch (see batchify_rays).
        rays: (rays_o, rays_d) pair, used when c2w is None.
        c2w: camera-to-world matrix; when given, rays come from get_rays.
        near, far: scalar depth bounds attached to every ray.
        time_step: 1-D tensor of N_t time values.
            NOTE(review): the default None is indexed unconditionally below,
            so callers must always pass it.
    Returns:
        [rgb_map, depth_map, acc_map, extras], where extras is a dict holding
        any other key produced by render_rays.
    """
    if c2w is None:
        # use provided ray batch
        rays_o, rays_d = rays
    else:
        # special case to render full image
        # NOTE(review): get_rays as defined in this notebook cell takes
        # (directions, c2w); this four-argument call only works if a
        # (H, W, K, c2w) variant is the one in scope - confirm.
        rays_o, rays_d = get_rays(H, W, K, c2w)

    sh = rays_d.shape  # [..., 3]

    # Flatten to one ray per row.
    rays_o = rays_o.reshape(-1, 3).float()
    rays_d = rays_d.reshape(-1, 3).float()

    # Per-ray columns for the depth bounds.
    unit_column = torch.ones_like(rays_d[..., :1])
    near, far = near * unit_column, far * unit_column
    rays = torch.cat([rays_o, rays_d, near, far], -1)  # [n_rays, 8]

    # Replicate the rays for every time step and append the time as last column.
    time_step = time_step[:, None, None]  # [N_t, 1, 1]
    N_t = time_step.shape[0]
    N_r = rays.shape[0]
    rays = torch.cat([rays[None].expand(N_t, -1, -1),
                      time_step.expand(-1, N_r, -1)], -1)  # [N_t, n_rays, 9]
    rays = rays.flatten(0, 1)  # [N_t * n_rays, 9]

    # Render and reshape
    all_ret = batchify_rays(rays, chunk)
    if N_t == 1:
        # With a single time step the leading shape of the input rays is restored.
        lead_shape = list(sh[:-1])
        all_ret = {key: torch.reshape(value, lead_shape + list(value.shape[1:]))
                   for key, value in all_ret.items()}

    k_extract = ['rgb_map', 'depth_map', 'acc_map']
    extras = {key: value for key, value in all_ret.items() if key not in k_extract}
    return [all_ret[key] for key in k_extract] + [extras]

In [None]:
def render_path(render_poses, hwf, K, chunk=512 * 64, gt_imgs=None, savedir=None, time_steps=None):
    """Render one frame per pose and optionally score and save the results.

    Args:
        render_poses: sequence of camera-to-world matrices; each is sliced to
            [:3, :4] before rendering.
        hwf: (H, W, focal); only H and W are used here.
        K: camera intrinsics, forwarded to render().
        chunk: rays per minibatch, forwarded to render().
        gt_imgs: optional ground-truth frames indexed like render_poses. When
            given, PSNR / SSIM / LPIPS are computed on a fixed crop.
        savedir: optional output directory for PNG frames, the GIF/MP4 made
            from them and, when gt_imgs is given, a JSON file with the
            per-frame PSNR values.
        time_steps: per-pose time values; defaults to 1.0 for every pose.
    Returns:
        (rgbs, depths): rendered colours and normalised depths, stacked along
        a new leading frame axis as NumPy arrays.
    """
    def merge_imgs(save_dir, framerate=30, prefix=''):
        """Assemble '<prefix>NNN.png' frames into a GIF and an MP4 with ffmpeg."""
        os.system(
            'ffmpeg -hide_banner -loglevel error -y -i {0}/{1}%03d.png -vf palettegen {0}/palette.png'.format(save_dir,
                                                                                                              prefix))
        os.system(
            'ffmpeg -hide_banner -loglevel error -y -framerate {0} -i {1}/{2}%03d.png -i {1}/palette.png -lavfi paletteuse {1}/_{2}.gif'.format(
                framerate, save_dir, prefix))
        os.system(
            'ffmpeg -hide_banner -loglevel error -y -framerate {0} -i {1}/{2}%03d.png -i {1}/palette.png -lavfi paletteuse {1}/_{2}.mp4'.format(
                framerate, save_dir, prefix))

    H, W, _ = hwf
    if time_steps is None:
        time_steps = torch.ones(render_poses.shape[0], dtype=torch.float32)

    evaluate = gt_imgs is not None

    rgbs = []
    depths = []
    psnrs = []
    ssims = []
    lpipss = []

    # The LPIPS network is only needed (and only built) when metrics are requested.
    lpips_net = lpips.LPIPS().cuda() if evaluate else None

    for i, c2w in enumerate(tqdm.tqdm(render_poses)):
        # NOTE(review): render() is called without near/far and therefore uses
        # its defaults (0 and 1), while the depth below is normalised with the
        # global near/far - confirm the two are meant to agree.
        rgb, depth, _, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3, :4], time_step=time_steps[i][None])
        rgbs.append(rgb.cpu().numpy())
        # normalize depth to [0,1]
        depth = (depth - near) / (far - near)
        depths.append(depth.cpu().numpy())

        if evaluate:
            gt_img = torch.tensor(gt_imgs[i].squeeze(), dtype=torch.float32)  # [H, W, 3]
            # Metrics use a fixed crop (rows 90:960, cols 45:540); presumably
            # specific to this dataset's resolution - TODO confirm.
            gt_img = gt_img[90:960, 45:540]
            rgb = rgb[90:960, 45:540]
            rgb_np = rgb.detach().cpu().numpy()
            gt_np = gt_img.cpu().numpy()
            lpips_value = lpips_net(rgb.permute(2, 0, 1), gt_img.permute(2, 0, 1), normalize=True).item()
            # Cast to built-in float: a NumPy float32 scalar cannot be
            # serialised by json.dump further down.
            p = float(-10. * np.log10(np.mean(np.square(rgb_np - gt_np))))
            ssim_value = float(structural_similarity(gt_np, rgb_np, data_range=1.0, channel_axis=2))
            lpipss.append(lpips_value)
            psnrs.append(p)
            ssims.append(ssim_value)
            print(f'PSNR: {p:.4g}, SSIM: {ssim_value:.4g}, LPIPS: {lpips_value:.4g}')

        if savedir is not None:
            # save rgb and depth as a figure
            rgb8 = to8b(rgbs[-1])
            imageio.imsave(os.path.join(savedir, 'rgb_{:03d}.png'.format(i)), rgb8)
            colored_depth_map = plt.cm.viridis(depths[-1].squeeze())
            imageio.imwrite(os.path.join(savedir, 'depth_{:03d}.png'.format(i)),
                            (colored_depth_map * 255).astype(np.uint8))

    if savedir is not None:
        merge_imgs(savedir, prefix='rgb_')
        merge_imgs(savedir, prefix='depth_')

    rgbs = np.stack(rgbs, 0)
    depths = np.stack(depths, 0)
    if evaluate:
        avg_psnr = sum(psnrs) / len(psnrs)
        avg_lpips = sum(lpipss) / len(lpipss)
        avg_ssim = sum(ssims) / len(ssims)
        print("Avg PSNR over Test set: ", avg_psnr)
        print("Avg LPIPS over Test set: ", avg_lpips)
        print("Avg SSIM over Test set: ", avg_ssim)
        # Metrics can be requested without an output directory; only write the
        # JSON summary when there is somewhere to put it.
        if savedir is not None:
            summary_name = "test_psnrs_{:0.4f}_lpips_{:0.4f}_ssim_{:0.4f}.json".format(
                avg_psnr, avg_lpips, avg_ssim)
            with open(os.path.join(savedir, summary_name), 'w') as fp:
                json.dump(psnrs, fp)

    return rgbs, depths

In [None]:
# Main optimisation loop: one random ray batch and N_time random (jittered)
# time indices per iteration, with periodic evaluation and logging.
for i in tqdm.trange(start, 10000 + 1):
    # Sample random ray batch
    batch_ray_idx = ray_idxs[i_batch:i_batch + N_rand]
    batch_rays = rays[batch_ray_idx]  # [B, 2, 3]
    batch_rays = torch.transpose(batch_rays, 0, 1)  # [2, B, 3]

    i_batch += N_rand
    # temporal bilinear sampling
    time_idx = torch.randperm(T)[:N_time].float().to(device)  # [N_t]
    # Jitter every frame index uniformly in [-0.5, 0.5). torch.rand (uniform)
    # is what that range requires; the previous torch.randn drew unbounded
    # Gaussian offsets centred on -0.5.
    time_idx += torch.rand(N_time) - 0.5  # -0.5 ~ 0.5
    time_idx_floor = torch.floor(time_idx).long()
    time_idx_ceil = torch.ceil(time_idx).long()
    time_idx_floor = torch.clamp(time_idx_floor, 0, T - 1)
    time_idx_ceil = torch.clamp(time_idx_ceil, 0, T - 1)
    time_idx_residual = time_idx - time_idx_floor.float()
    frames_floor = images_train[time_idx_floor]  # [N_t, VHW, 3]
    frames_ceil = images_train[time_idx_ceil]  # [N_t, VHW, 3]
    frames_interp = frames_floor * (1 - time_idx_residual).unsqueeze(-1) + \
                    frames_ceil * time_idx_residual.unsqueeze(-1)  # [N_t, VHW, 3]
    time_step = time_idx / (T - 1) if T > 1 else torch.zeros_like(time_idx)
    points = frames_interp[:, batch_ray_idx]  # [N_t, B, 3]
    target_s = points.flatten(0, 1)  # [N_t*B, 3]

    if i_batch >= rays.shape[0]:
        print("Shuffle data after an epoch!")
        ray_idxs = torch.randperm(rays.shape[0])
        i_batch = 0
        resample_rays = True

    #####  Core optimization loop  #####
    # NOTE(review): near/far are not passed, so render() falls back to its
    # defaults (0 and 1) - confirm that matches the scene bounds.
    rgb, depth, acc, extras = render(H, W, K, rays=batch_rays, time_step=time_step)

    img_loss = img2mse(rgb, target_s)
    loss = img_loss
    psnr = mse2psnr(img_loss)
    loss_meter.update(loss.item())
    psnr_meter.update(psnr.item())

    for param in grad_vars:  # slightly faster than optimizer.zero_grad()
        param.grad = None
    loss.backward()
    optimizer.step()

    ###   update learning rate   ###
    # Exponential decay: the rate shrinks by decay_rate every decay_steps steps.
    decay_rate = 0.1
    decay_steps = lrate_decay
    new_lrate = lrate * (decay_rate ** (global_step / decay_steps))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lrate
    ################################

    if i % 1000 == 0 and i > 0:
        # Turn on testing mode
        testsavedir = os.path.join(basedir, expname, 'spiral_{:06d}'.format(i))
        os.makedirs(testsavedir, exist_ok=True)
        with torch.no_grad():
            render_path(render_poses, hwf, K, time_steps=render_timesteps, savedir=testsavedir)

        # Evaluate against the held-out frames, all seen from the first test pose.
        testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
        os.makedirs(testsavedir, exist_ok=True)
        with torch.no_grad():
            test_view_pose = torch.tensor(poses_test[0])
            N_timesteps = images_test.shape[0]
            test_timesteps = torch.arange(N_timesteps) / (N_timesteps - 1)
            test_view_poses = test_view_pose.unsqueeze(0).repeat(N_timesteps, 1, 1)
            render_path(test_view_poses, hwf, K, time_steps=test_timesteps, gt_imgs=images_test,
                        savedir=testsavedir)

    if i % 100 == 0:
        tqdm.tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss_meter.avg:.2g}  PSNR: {psnr_meter.avg:.4g}")
        loss_list.append(loss_meter.avg)
        psnr_list.append(psnr_meter.avg)
        loss_psnr = {
            "losses": loss_list,
            "psnr": psnr_list,
        }
        loss_meter.reset()
        psnr_meter.reset()
        # The whole history is rewritten each time so the file is always complete.
        with open(os.path.join(basedir, expname, "loss_vs_time.json"), "w") as fp:
            json.dump(loss_psnr, fp)

    if resample_rays:
        # After each epoch the rays and the matching resampled training images
        # are regenerated from the training poses.
        print("Sampling new rays!")
        rays = []
        ij = []
        for p in poses_train[:, :3, :4]:
            r_o, r_d, i_, j_ = get_rays_np_continuous(H, W, K, p)
            rays.append([r_o, r_d])
            ij.append([i_, j_])
        rays = np.stack(rays, 0)  # [V, ro+rd=2, H, W, 3]
        ij = np.stack(ij, 0)  # [V, 2, H, W]
        images_train = sample_bilinear(images_train_, ij)  # [T, V, H, W, 3]
        rays = np.transpose(rays, [0, 2, 3, 1, 4])  # [V, H, W, ro+rd=2, 3]
        rays = np.reshape(rays, [-1, 2, 3])  # [VHW, ro+rd=2, 3]
        rays = rays.astype(np.float32)

        # Move training data to GPU
        images_train = torch.Tensor(images_train).to(device).flatten(start_dim=1, end_dim=3)  # [T, VHW, 3]
        T, S, _ = images_train.shape
        rays = torch.Tensor(rays).to(device)

        ray_idxs = torch.randperm(rays.shape[0])
        i_batch = 0
        resample_rays = False
    global_step += 1
