In [6]:
import argparse

# Run configuration, exposed as an argparse.Namespace so later cells can read
# settings as attributes (for example args.datadir, args.half_res, args.testskip).
# Keys and values are listed in their original order; note that bbox_min/bbox_max
# are stored as strings and white_bkgd as a list of three floats.
args = argparse.Namespace(
    config='configs/scalar.txt',
    expname='scalar_test1',
    basedir='./log',
    datadir='./data/ScalarReal',
    net_model='siren',
    netdepth=8,
    netwidth=256,
    netdepth_fine=8,
    netwidth_fine=256,
    N_rand=1024,
    lrate=0.0005,
    lrate_decay=500,
    chunk=32768,
    netchunk=65536,
    no_batching=True,
    no_reload=False,
    ft_path=None,
    fix_seed=42,
    fading_layers=50000,
    tempo_delay=0,
    vel_delay=10000,
    N_iter=600000,
    train_warp=True,
    bbox_min='0.05',
    bbox_max='0.9',
    vgg_strides=4,
    ghostW=0.07,
    vggW=0.01,
    overlayW=-0.0,
    d2vW=2.0,
    nseW=0.001,
    vol_output_only=False,
    vol_output_W=128,
    render_only=False,
    render_test=False,
    N_samples=64,
    N_importance=64,
    perturb=1.0,
    use_viewdirs=False,
    i_embed=-1,
    multires=10,
    multires_views=4,
    raw_noise_std=0.0,
    render_factor=0,
    precrop_iters=1000,
    precrop_frac=0.5,
    dataset_type='pinf_data',
    testskip=20,
    shape='greek',
    white_bkgd=[1., 1., 1.],
    half_res='half',
    factor=8,
    no_ndc=False,
    lindisp=False,
    spherify=False,
    llffhold=8,
    i_print=400,
    i_img=2000,
    i_weights=25000,
    i_testset=50000,
    i_video=50000,
)

# Module-level debug switch (off by default).
DEBUG = False

In [7]:
import torch
import torchvision
import cv2
import imageio.v2 as imageio
import json
import os
import numpy as np
import math
from typing import Tuple, Union

# Select the compute device once: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use the detected device as the default for newly created tensors. Passing the
# detected device (instead of a hard-coded 'cuda') keeps the notebook runnable on
# machines without a GPU, where default allocations on 'cuda' would fail.
torch.set_default_device(device)
torch.set_default_dtype(torch.float32)

### generate pose spherical

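- Pure torch functions: every matrix is built as a `torch.float32` tensor.
- Pure device functions: every tensor is created directly on the `in_device` passed by the caller.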

In [8]:
def create_translation_matrix(in_device: torch.device, in_t: float) -> torch.Tensor:
    """Return a 4x4 homogeneous matrix that translates by ``in_t`` along the z axis.

    Args:
        in_device: device on which the matrix is allocated.
        in_t: translation distance placed in row 2, column 3.

    Returns:
        A float32 tensor of shape (4, 4): the identity with ``in_t`` at [2, 3].
    """
    shift = torch.tensor(data=in_t, dtype=torch.float32, device=in_device)
    matrix = torch.eye(4, dtype=torch.float32, device=in_device)
    matrix[2, 3] = shift  # z entry of the translation column
    return matrix


def create_rotation_matrix_phi(in_device: torch.device, in_phi: float) -> torch.Tensor:
    """Return a 4x4 homogeneous rotation by ``in_phi`` that leaves the x axis fixed.

    Args:
        in_device: device on which the matrix is allocated.
        in_phi: rotation angle in radians (fed directly to torch.cos / torch.sin).

    Returns:
        A float32 tensor of shape (4, 4) whose y/z sub-block is
        [[cos, -sin], [sin, cos]]; every other entry matches the identity.
    """
    angle = torch.tensor(data=in_phi, dtype=torch.float32, device=in_device)
    cosine = torch.cos(input=angle)
    sine = torch.sin(input=angle)

    # Start from the identity and overwrite only the four rotating entries.
    rotation = torch.eye(4, dtype=torch.float32, device=in_device)
    rotation[1, 1] = cosine
    rotation[1, 2] = -sine
    rotation[2, 1] = sine
    rotation[2, 2] = cosine
    return rotation


def create_rotation_matrix_theta(in_device: torch.device, in_theta: float) -> torch.Tensor:
    """Return a 4x4 homogeneous rotation by ``in_theta`` that leaves the y axis fixed.

    Args:
        in_device: device on which the matrix is allocated.
        in_theta: rotation angle in radians (fed directly to torch.cos / torch.sin).

    Returns:
        A float32 tensor of shape (4, 4) with cos at [0, 0] and [2, 2],
        -sin at [0, 2] and sin at [2, 0]; every other entry matches the identity.
    """
    angle = torch.tensor(data=in_theta, dtype=torch.float32, device=in_device)
    cosine = torch.cos(input=angle)
    sine = torch.sin(input=angle)

    # Start from the identity and overwrite only the four rotating entries.
    rotation = torch.eye(4, dtype=torch.float32, device=in_device)
    rotation[0, 0] = cosine
    rotation[0, 2] = -sine
    rotation[2, 0] = sine
    rotation[2, 2] = cosine
    return rotation


def generate_pose_spherical(in_device: torch.device, in_theta: float, in_phi: float, in_radius: float,
                            in_rotZ: bool, in_wx: float, in_wy: float, in_wz: float) -> torch.Tensor:
    """Compose a 4x4 pose matrix from spherical parameters and a world-space offset.

    The transforms are applied right to left: translate by ``in_radius`` along z,
    rotate by ``in_phi``, rotate by ``in_theta``, optionally swap axes when
    ``in_rotZ`` is set, and finally translate by (``in_wx``, ``in_wy``, ``in_wz``).

    Args:
        in_device: device on which all matrices are allocated.
        in_theta: angle in degrees, passed to create_rotation_matrix_theta after conversion.
        in_phi: angle in degrees, passed to create_rotation_matrix_phi after conversion.
        in_radius: distance passed to create_translation_matrix.
        in_rotZ: when true, apply the axis swap described below.
        in_wx, in_wy, in_wz: components of the final translation.

    Returns:
        A float32 tensor of shape (4, 4).
    """
    # Degrees -> radians.
    theta_radians = in_theta * torch.pi / 180.0
    phi_radians = in_phi * torch.pi / 180.0

    radial_shift = create_translation_matrix(in_device=in_device, in_t=in_radius)
    rotate_phi = create_rotation_matrix_phi(in_device=in_device, in_phi=phi_radians)
    rotate_theta = create_rotation_matrix_theta(in_device=in_device, in_theta=theta_radians)
    pose = rotate_theta @ (rotate_phi @ radial_shift)

    if in_rotZ:
        # Exchange y and z; negating x keeps the transform right-handed (determinant +1).
        axis_swap = torch.tensor(data=[
            [-1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1]
        ], dtype=torch.float32, device=in_device)
        pose = axis_swap @ pose

    # Final translation: identity with (wx, wy, wz) written into the last column.
    offset = torch.tensor([in_wx, in_wy, in_wz], dtype=torch.float32, device=in_device)
    world_shift = torch.eye(4, dtype=torch.float32, device=in_device)
    world_shift[:3, 3] = offset

    return world_shift @ pose

### load pinf frame data

- Pure NumPy functions: images, poses and time steps are returned as NumPy arrays (no torch tensors).

In [None]:
def resample_images(in_images: np.ndarray, in_factor: float) -> np.ndarray:
    """Resize every image in a batch by ``in_factor`` using area interpolation.

    Args:
        in_images: array indexed as (image, height, width, channel); the first
            four entries of ``shape`` are read.
        in_factor: scale applied to height and width; each output size is
            ``floor(size * in_factor)``.

    Returns:
        Array of shape (N, floor(H * in_factor), floor(W * in_factor), C) with the
        same dtype as ``in_images``. Keeping the input dtype avoids silently
        promoting float32 frames to float64, which would double memory use.
    """
    num_images = in_images.shape[0]
    out_height = math.floor(in_images.shape[1] * in_factor)
    out_width = math.floor(in_images.shape[2] * in_factor)
    num_channels = in_images.shape[3]

    ret_images = np.zeros((num_images, out_height, out_width, num_channels), dtype=in_images.dtype)
    for image_idx in range(num_images):
        # cv2.resize takes the target size as (width, height).
        resized = cv2.resize(in_images[image_idx], (out_width, out_height),
                             interpolation=cv2.INTER_AREA)
        # cv2.resize drops a trailing single-channel axis; restore it before storing.
        ret_images[image_idx] = resized.reshape(out_height, out_width, num_channels)
    return ret_images


def load_pinf_frame_data(in_type: str, in_skip: int, in_resample: float) -> Tuple[
    np.ndarray, np.ndarray, np.ndarray, dict]:
    """Load frames, poses and time steps for one split listed in ``info.json``.

    The dataset root is taken from the notebook-level ``args.datadir``. The
    ``'<in_type>_videos'`` entry of ``info.json`` lists the videos; every
    ``in_skip``-th frame of each video is decoded.

    Args:
        in_type: split prefix; ``in_type + '_videos'`` must be a key of ``info.json``.
        in_skip: frame stride used when sampling each video.
        in_resample: spatial scale factor forwarded to ``resample_images``.

    Returns:
        A tuple ``(images, poses, time_steps, params)``:
        - images: frames of all videos concatenated along axis 0, scaled to
          [0, 1] by dividing by 255, then passed through ``resample_images``;
        - poses: float32 array holding the video's ``transform_matrix`` once per
          loaded frame;
        - time_steps: float32 array of ``frame_index / frame_num`` per loaded frame;
        - params: dict with ``camera_angle_x`` (list, one float per video),
          ``near``, ``far``, ``phi`` (floats), ``rotZ`` (True when ``rot == 'Z'``),
          ``r_center`` and ``bkg_color`` (float32 arrays).

    Raises:
        ValueError: if the split is missing from ``info.json`` or lists no videos.
    """
    # Parse the metadata first so the JSON file is closed before the slow video decoding.
    with open(os.path.normpath(os.path.join(args.datadir, 'info.json')), 'r') as fp:
        _meta = json.load(fp)

    if (in_type + '_videos') not in _meta:
        raise ValueError(f'No {in_type} videos found in the dataset')

    all_images = []
    all_poses = []
    all_time_steps = []
    ret_params = {'camera_angle_x': []}

    for _video in _meta[in_type + '_videos']:
        # load images, poses and time steps
        _image_array = []
        _pose_array = []
        _time_step_array = []
        _reader = imageio.get_reader(os.path.normpath(os.path.join(args.datadir, _video['file_name'])))
        try:
            for _idx in range(0, _video['frame_num'], in_skip):
                _reader.set_image_index(_idx)
                _image_array.append(_reader.get_next_data())
                # The pose is stored per video, so it is repeated for every frame.
                _pose_array.append(_video['transform_matrix'])
                _time_step_array.append(_idx * 1. / _video['frame_num'])
        finally:
            # Release the video handle even if decoding a frame fails.
            _reader.close()
        all_images.append((np.array(_image_array) / 255.).astype(np.float32))
        all_poses.append(np.array(_pose_array).astype(np.float32))
        all_time_steps.append(np.array(_time_step_array).astype(np.float32))
        ret_params['camera_angle_x'].append(float(_video['camera_angle_x']))

    ret_params['near'] = float(_meta['near'])
    ret_params['far'] = float(_meta['far'])
    ret_params['phi'] = float(_meta['phi'])
    ret_params['rotZ'] = (_meta['rot'] == 'Z')
    ret_params['r_center'] = np.array(_meta['render_center']).astype(np.float32)
    ret_params['bkg_color'] = np.array(_meta['frame_bkg_color']).astype(np.float32)

    if not all_images:
        # np.concatenate would fail with an opaque message on an empty list.
        raise ValueError(f"The '{in_type}_videos' list in info.json is empty")

    ret_images_np = np.concatenate(all_images, 0)
    ret_poses_np = np.concatenate(all_poses, 0)
    ret_time_steps_np = np.concatenate(all_time_steps, 0)

    ret_images_resampled_np = resample_images(in_images=ret_images_np, in_factor=in_resample)
    return ret_images_resampled_np, ret_poses_np, ret_time_steps_np, ret_params


# Spatial scale shared by all three loads: half resolution when args.half_res is
# 'half', full resolution otherwise.
_RESAMPLE_FACTOR = 0.5 if args.half_res == 'half' else 1.0

# Training set: every frame of the 'train' videos.
TRAIN_IMAGES, TRAIN_POSES, TRAIN_TIME_STEPS, TRAIN_PARAMS = load_pinf_frame_data(
    in_type='train', in_skip=1, in_resample=_RESAMPLE_FACTOR)

# NOTE(review): VAR_* and TEST_* are also read from the 'train' split and differ from
# TRAIN_* only by the frame stride (args.testskip) — confirm this is intended.
VAR_IMAGES, VAR_POSES, VAR_TIME_STEPS, VAR_PARAMS = load_pinf_frame_data(
    in_type='train', in_skip=args.testskip, in_resample=_RESAMPLE_FACTOR)
TEST_IMAGES, TEST_POSES, TEST_TIME_STEPS, TEST_PARAMS = load_pinf_frame_data(
    in_type='train', in_skip=args.testskip, in_resample=_RESAMPLE_FACTOR)

In [5]:
# Show the parameter dict returned by load_pinf_frame_data for the training set.
# NOTE(review): the output saved below shows 'camera_angle_x' as a single float, while
# load_pinf_frame_data builds a list with one entry per video; the saved output
# appears to predate the current loader — re-run this cell to refresh it.
print(TRAIN_PARAMS)

{'camera_angle_x': 0.40746459248665245, 'near': 1.1, 'far': 1.5, 'phi': 20.0, 'rotZ': False, 'r_center': array([ 0.338207  ,  0.38795385, -0.26092097], dtype=float32), 'bkg_color': array([0., 0., 0.], dtype=float32)}
