In [1]:
import argparse

# Experiment configuration, hard-coded here instead of parsing the `config`
# file, and exposed as an argparse.Namespace so later cells can use attribute
# access (args.datadir, args.half_res, ...).
args = argparse.Namespace(
    config='configs/scalar.txt',
    expname='scalar_test1',
    basedir='./log',
    datadir='./data/ScalarReal',
    net_model='siren',
    netdepth=8,
    netwidth=256,
    netdepth_fine=8,
    netwidth_fine=256,
    N_rand=1024,
    lrate=0.0005,
    lrate_decay=500,
    chunk=32768,
    netchunk=65536,
    no_batching=True,
    no_reload=False,
    ft_path=None,
    fix_seed=42,
    fading_layers=50000,
    tempo_delay=0,
    vel_delay=10000,
    N_iter=600000,
    train_warp=True,
    # The bbox bounds are strings; the BBoxTool cell converts them with float().
    bbox_min='0.05',
    bbox_max='0.9',
    vgg_strides=4,
    ghostW=0.07,
    vggW=0.01,
    overlayW=-0.0,
    d2vW=2.0,
    nseW=0.001,
    vol_output_only=False,
    vol_output_W=128,
    render_only=False,
    render_test=False,
    N_samples=64,
    N_importance=64,
    perturb=1.0,
    use_viewdirs=False,
    i_embed=-1,
    multires=10,
    multires_views=4,
    raw_noise_std=0.0,
    render_factor=0,
    precrop_iters=1000,
    precrop_frac=0.5,
    dataset_type='pinf_data',
    testskip=20,
    shape='greek',
    white_bkgd=[1., 1., 1.],
    half_res='half',
    factor=8,
    no_ndc=False,
    lindisp=False,
    spherify=False,
    llffhold=8,
    i_print=400,
    i_img=2000,
    i_weights=25000,
    i_testset=50000,
    i_video=50000,
)
DEBUG = False

In [2]:
import torch
import torchvision
import cv2
import imageio.v2 as imageio
import json
import os
import numpy as np
import math
from typing import Tuple, Union

# `host` is always the CPU; `device` is CUDA when available, otherwise the CPU.
host = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Tensors created without an explicit device land on `device`. Hard-coding
# 'cuda' here broke every such allocation on CPU-only machines despite the
# fallback computed above.
torch.set_default_device(device)
torch.set_default_dtype(torch.float32)

### Generate spherical camera poses and positional encoding

- pure torch functions
- the pose/matrix builders create every tensor on the device passed as `in_device`

In [3]:
def create_translation_matrix(in_device: torch.device, in_t: float) -> torch.Tensor:
    """Return a 4x4 float32 homogeneous matrix that translates by `in_t` along z.

    Args:
        in_device: device on which the matrix is allocated.
        in_t: translation along the z axis.
    """
    translation = torch.eye(4, dtype=torch.float32, device=in_device)
    translation[2, 3] = in_t
    return translation


def create_rotation_matrix_phi(in_device: torch.device, in_phi: float) -> torch.Tensor:
    """Return a 4x4 float32 homogeneous rotation about the x axis.

    Args:
        in_device: device on which the matrix (and the trig evaluation) lives.
        in_phi: rotation angle in radians.
    """
    angle = torch.tensor(data=in_phi, dtype=torch.float32, device=in_device)
    cos_a, sin_a = torch.cos(input=angle), torch.sin(input=angle)

    rotation = torch.eye(4, dtype=torch.float32, device=in_device)
    # Fill the y/z sub-block: [[cos, -sin], [sin, cos]].
    rotation[1, 1] = cos_a
    rotation[1, 2] = -sin_a
    rotation[2, 1] = sin_a
    rotation[2, 2] = cos_a
    return rotation


def create_rotation_matrix_theta(in_device: torch.device, in_theta: float) -> torch.Tensor:
    """Return a 4x4 float32 homogeneous rotation about the y axis.

    Sign convention: entry [0, 2] is -sin(theta) and entry [2, 0] is +sin(theta).

    Args:
        in_device: device on which the matrix (and the trig evaluation) lives.
        in_theta: rotation angle in radians.
    """
    angle = torch.tensor(data=in_theta, dtype=torch.float32, device=in_device)
    cos_a, sin_a = torch.cos(input=angle), torch.sin(input=angle)

    rotation = torch.eye(4, dtype=torch.float32, device=in_device)
    # Fill the x/z entries; the y row/column stays identity.
    rotation[0, 0] = cos_a
    rotation[0, 2] = -sin_a
    rotation[2, 0] = sin_a
    rotation[2, 2] = cos_a
    return rotation


def generate_pose_spherical(in_device: torch.device, in_theta: float, in_phi: float, in_radius: float,
                            in_rotZ: bool, in_wx: float, in_wy: float, in_wz: float) -> torch.Tensor:
    """Build a 4x4 camera-to-world matrix on a sphere around (in_wx, in_wy, in_wz).

    Starting from a translation of `in_radius` along z, the following matrices
    are left-multiplied in order: rotation about x by `in_phi`, rotation about
    y by `in_theta`, an optional axis swap (when `in_rotZ`), and finally a
    translation by (in_wx, in_wy, in_wz).

    Args:
        in_device: device for every created tensor.
        in_theta: azimuth-like angle in degrees (rotation about y).
        in_phi: elevation-like angle in degrees (rotation about x).
        in_radius: distance along the local z axis before rotation.
        in_rotZ: apply the fixed matrix that negates x and swaps y/z
            (its determinant is +1, so handedness is preserved).
        in_wx, in_wy, in_wz: final world-space translation.
    """
    stages = [
        create_rotation_matrix_phi(in_device=in_device, in_phi=in_phi * torch.pi / 180.0),
        create_rotation_matrix_theta(in_device=in_device, in_theta=in_theta * torch.pi / 180.0),
    ]
    if in_rotZ:
        # Swap y and z and flip x so the coordinate system stays right-handed.
        stages.append(torch.tensor(data=[
            [-1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1]
        ], dtype=torch.float32, device=in_device))
    stages.append(torch.tensor([
        [1, 0, 0, in_wx],
        [0, 1, 0, in_wy],
        [0, 0, 1, in_wz],
        [0, 0, 0, 1]
    ], dtype=torch.float32, device=in_device))

    c2w = create_translation_matrix(in_device=in_device, in_t=in_radius)
    # Left-multiply in the same order as the stages were listed.
    for stage in stages:
        c2w = stage @ c2w
    return c2w


def generate_position_encoding_fn(in_multires: int, in_input_dims: int = 3):
    """Build a NeRF-style sinusoidal positional-encoding function.

    The returned callable maps x with last dimension `in_input_dims` to
    ``cat([x, sin(x*f_0), cos(x*f_0), ..., sin(x*f_{L-1}), cos(x*f_{L-1})], -1)``
    with ``f_k = 2**k`` and ``L = in_multires``.

    Args:
        in_multires: number of frequency bands L (0 yields the identity encoding).
        in_input_dims: size of the last input dimension (defaults to 3).

    Returns:
        (embed_fn, out_dim) where out_dim = in_input_dims * (1 + 2 * in_multires).
    """
    # Exact powers of two as plain Python floats. The previous tensor of
    # frequencies was created on the global default device (CUDA in this
    # notebook), which tied the encoder to that device.
    freq_bands = [2.0 ** k for k in range(in_multires)]

    embed_fns = [lambda x: x]
    for freq in freq_bands:
        for p_fn in (torch.sin, torch.cos):
            # Default arguments bind the current loop values (avoids late binding).
            embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
    ret_out_dim = in_input_dims * (1 + 2 * len(freq_bands))

    def ret_embed_fn(x: torch.Tensor) -> torch.Tensor:
        return torch.cat([fn(x) for fn in embed_fns], -1)

    return ret_embed_fn, ret_out_dim

### Load PINF frame data

- NumPy-based helpers: frames are decoded with imageio, resized with OpenCV and returned as NumPy arrays

In [4]:
def resample_images(in_images: np.ndarray, in_factor: float) -> np.ndarray:
    """Resize a stack of images by `in_factor` using OpenCV area interpolation.

    Args:
        in_images: 4-D array indexed as (N, H, W, C).
        in_factor: scale applied to H and W; output sizes are floor(H * in_factor)
            and floor(W * in_factor).

    Returns:
        float32 array of shape (N, floor(H * in_factor), floor(W * in_factor), C).
    """
    n_images, height, width, channels = in_images.shape
    out_h = math.floor(height * in_factor)
    out_w = math.floor(width * in_factor)
    ret_images = np.zeros((n_images, out_h, out_w, channels), dtype=np.float32)
    for idx in range(n_images):
        # cv2.resize takes the target size as (width, height).
        resized = cv2.resize(in_images[idx], (out_w, out_h), interpolation=cv2.INTER_AREA)
        # cv2.resize drops a trailing singleton channel axis; restore it so
        # single-channel stacks can be written back (no-op for C > 1).
        ret_images[idx] = resized.reshape(out_h, out_w, channels)
    return ret_images


def _read_video_frames(in_path: str, in_frame_num: int, in_skip: int) -> Tuple[np.ndarray, np.ndarray]:
    """Read every `in_skip`-th frame of a video; return (frames / 255 as float32, time steps idx / in_frame_num)."""
    _frames = []
    _time_steps = []
    _reader = imageio.get_reader(in_path)
    try:
        for _idx in range(0, in_frame_num, in_skip):
            _reader.set_image_index(_idx)
            _frames.append(_reader.get_next_data())
            _time_steps.append(_idx * 1. / in_frame_num)
    finally:
        # Release the decoder even when a frame fails to decode.
        _reader.close()
    # NOTE(review): dividing by 255 assumes 8-bit frames — confirm for new datasets.
    return (np.array(_frames, dtype=np.float32) / 255.,
            np.array(_time_steps, dtype=np.float32))


def load_pinf_frame_data(in_type: str, in_skip: int, in_resample: float,
                         in_datadir: Union[str, None] = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Load frames, poses, time steps and scene parameters for one dataset split.

    Reads `<datadir>/info.json`, then for each entry of `<in_type>_videos`
    samples every `in_skip`-th frame and repeats the video's single
    `transform_matrix` once per sampled frame.

    Args:
        in_type: split prefix, e.g. 'train' (looks up the '<in_type>_videos' key).
        in_skip: frame stride within each video.
        in_resample: spatial scale factor passed to `resample_images`.
        in_datadir: dataset directory; defaults to the global `args.datadir`.

    Returns:
        (images, poses, time_steps, params). Images are float32 after
        resampling. Time steps are idx / frame_num, so they lie in [0, 1) within
        each video. params holds camera_angle_x (one per video), near, far, phi,
        rotZ, r_center, bkg_color, voxel_matrix and voxel_scale (broadcast to 3).

    Raises:
        ValueError: if the split has no videos in info.json.
    """
    _datadir = args.datadir if in_datadir is None else in_datadir
    # Parse the metadata and close the file before the (slow) video decoding.
    with open(os.path.normpath(os.path.join(_datadir, 'info.json')), 'r') as fp:
        _meta = json.load(fp)

    _videos = _meta.get(in_type + '_videos')
    if not _videos:
        # Also covers an empty list, which previously failed in np.concatenate.
        raise ValueError(f'No {in_type} videos found in the dataset')

    _all_images = []
    _all_poses = []
    _all_time_steps = []
    ret_params = {'camera_angle_x': []}
    for _video in _videos:
        _images, _time_steps = _read_video_frames(
            in_path=os.path.normpath(os.path.join(_datadir, _video['file_name'])),
            in_frame_num=_video['frame_num'], in_skip=in_skip)
        _all_images.append(_images)
        _all_poses.append(np.array([_video['transform_matrix']] * len(_time_steps), dtype=np.float32))
        _all_time_steps.append(_time_steps)
        ret_params['camera_angle_x'].append(float(_video['camera_angle_x']))

    ret_params['near'] = float(_meta['near'])
    ret_params['far'] = float(_meta['far'])
    ret_params['phi'] = float(_meta['phi'])
    ret_params['rotZ'] = (_meta['rot'] == 'Z')
    ret_params['r_center'] = np.array(_meta['render_center'], dtype=np.float32)
    ret_params['bkg_color'] = np.array(_meta['frame_bkg_color'], dtype=np.float32)
    ret_params['voxel_matrix'] = np.array(_meta['voxel_matrix'], dtype=np.float32)
    ret_params['voxel_scale'] = np.broadcast_to(_meta['voxel_scale'], [3]).astype(np.float32)

    ret_images_np = np.concatenate(_all_images, 0)
    ret_poses_np = np.concatenate(_all_poses, 0)
    ret_time_steps_np = np.concatenate(_all_time_steps, 0)

    ret_images_resampled_np = resample_images(in_images=ret_images_np, in_factor=in_resample)
    return ret_images_resampled_np, ret_poses_np, ret_time_steps_np, ret_params


# Training split: every 20th frame of each training video, at half resolution
# when args.half_res == 'half'.
TRAIN_IMAGES, TRAIN_POSES, TRAIN_TIME_STEPS, TRAIN_PARAMS = load_pinf_frame_data(
    in_type='train',
    in_skip=20,
    in_resample=0.5 if args.half_res == 'half' else 1.0,
)
# NOTE(review): the disabled loaders below also pass in_type='train'; confirm
# the intended split for VAR/TEST before enabling them.
# VAR_IMAGES, VAR_POSES, VAR_TIME_STEPS, VAR_PARAMS = load_pinf_frame_data(in_type='train', in_skip=args.testskip,
#                                                                          in_resample=0.5 if args.half_res == 'half' else 1.0)
# TEST_IMAGES, TEST_POSES, TEST_TIME_STEPS, TEST_PARAMS = load_pinf_frame_data(in_type='train', in_skip=args.testskip,
#                                                                              in_resample=0.5 if args.half_res == 'half' else 1.0)

In [5]:
# Voxel transform with its first and third columns swapped (column order
# 2, 1, 0, 3), kept on the CPU host; its inverse maps world -> voxel space.
VOXEL_TRAN = torch.tensor(data=TRAIN_PARAMS['voxel_matrix'][:, [2, 1, 0, 3]], device=host)
VOXEL_TRAN_INV = torch.linalg.inv(VOXEL_TRAN)
VOXEL_SCALE = torch.tensor(TRAIN_PARAMS['voxel_scale'], device=host)

In [6]:
def pos_world2smoke(points_world: torch.Tensor, w2s: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Map world-space points into the scaled simulation frame.

    The affine part of `w2s` (rotation w2s[:3, :3], translation w2s[:3, -1])
    takes points from the world frame to the target frame. Dividing by `scale`
    then takes them to the simulation frame.

    Args:
        points_world: points with xyz in the last dimension; leading dims are kept.
        w2s: matrix whose top 3x4 block is the world-to-target transform.
        scale: per-axis divisor, broadcast against the last dimension.
    """
    rotation = w2s[:3, :3]
    translation = w2s[:3, -1]
    # Row-wise dot products == rotation @ p for every point.
    rotated = (points_world[..., None, :] * rotation).sum(dim=-1)
    target = rotated + translation.expand(rotated.shape)
    return target / scale


class BBoxTool:
    """Per-axis bounding-box test in the scaled smoke/simulation frame.

    World points are mapped with `pos_world2smoke` (using the stored
    world-to-smoke matrix and scale) and compared against [s_min, s_max] on
    each of the three axes.
    """

    def __init__(self, in_smoke_tran_inv: torch.Tensor, in_smoke_scale: torch.Tensor, in_min: float, in_max: float):
        # clone().detach() keeps the tool independent of the caller's tensors.
        self.s_w2s = in_smoke_tran_inv.clone().detach().expand([4, 4])
        self.s_scale = in_smoke_scale.clone().detach().expand([3])
        # Scalar bounds broadcast to one value per axis.
        self.s_min = torch.tensor(in_min).expand([3])
        self.s_max = torch.tensor(in_max).expand([3])

    def is_inside(self, inputs_points: torch.Tensor) -> torch.Tensor:
        """Return a bool mask over inputs_points.shape[:-1].

        The mask is True where the mapped point satisfies s_min <= p <= s_max
        on all three axes (bounds inclusive).
        """
        _device = inputs_points.device
        # The stored tensors may live on a different device than the queried
        # points (e.g. the transform is built on the CPU host), so align them.
        points_smoke = pos_world2smoke(points_world=inputs_points,
                                       w2s=self.s_w2s.to(_device),
                                       scale=self.s_scale.to(_device))
        above_min = (points_smoke >= self.s_min.to(_device)).all(dim=-1)
        below_max = (points_smoke <= self.s_max.to(_device)).all(dim=-1)
        return torch.logical_and(above_min, below_max)


# Bounding box built from the inverse voxel transform and voxel scale; the
# bounds come from args as strings and are converted here.
BBOX_MODEL = BBoxTool(
    in_smoke_tran_inv=VOXEL_TRAN_INV,
    in_smoke_scale=VOXEL_SCALE,
    in_min=float(args.bbox_min),
    in_max=float(args.bbox_max),
)

In [7]:
class SineLayer(torch.nn.Module):
    """Linear layer followed by sin(omega_0 * .), using SIREN-style weight initialization.

    The first layer draws weights from U(-1/n, 1/n). Later layers draw them
    from U(-sqrt(6/n)/omega_0, sqrt(6/n)/omega_0), where n = in_features_in.
    """

    def __init__(self, in_features_in: int, in_features_out: int, in_bias: bool = True, in_is_first: bool = False,
                 in_omega_0: float = 30.0):
        super().__init__()
        self.omega_0 = in_omega_0
        self.linear = torch.nn.Linear(in_features_in, in_features_out, bias=in_bias)
        with torch.no_grad():
            if in_is_first:
                bound = 1. / in_features_in
            else:
                # Divide by omega_0 so pre-activations keep a usable range after scaling.
                bound = np.sqrt(6 / in_features_in) / in_omega_0
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sin(omega_0 * linear(x))."""
        return torch.sin(self.omega_0 * self.linear(x))

    def forward_with_intermediate(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (sin(z), z) where z = omega_0 * linear(x) is the pre-activation."""
        intermediate = self.omega_0 * self.linear(x)
        return torch.sin(intermediate), intermediate


class SIREN_VEL(torch.nn.Module):
    """Placeholder for the SIREN velocity network — not implemented yet (TODO)."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        # Fail loudly instead of silently returning None, which would only
        # surface later as an obscure error downstream.
        raise NotImplementedError('SIREN_VEL.forward is not implemented yet')


class SIREN_NeRFt(torch.nn.Module):
    """Placeholder for the time-dependent SIREN NeRF network — not implemented yet (TODO)."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        # Fail loudly instead of silently returning None, which would only
        # surface later as an obscure error downstream.
        raise NotImplementedError('SIREN_NeRFt.forward is not implemented yet')

In [8]:
def create_nerf():
    """Placeholder for model construction — not implemented yet (TODO).

    Raises:
        NotImplementedError: always, rather than silently returning None.
    """
    raise NotImplementedError('create_nerf is not implemented yet')