# Removal Log

Features removed in this notebook version:

1. Logging
2. Checkpoints
3. Render-only mode (`render_only`)
4. Volume-output-only mode (`vol_output_only`)

In [1]:
import argparse

# Run configuration. Kept as an argparse.Namespace so code written against the
# command-line version (`args.<option>`) works unchanged; `config` records the config
# file these values presumably came from.
args = argparse.Namespace(
    config='configs/scalar.txt',
    expname='scalar_test1',
    basedir='./log',
    datadir='./data/ScalarReal',
    net_model='siren',
    netdepth=8,
    netwidth=256,
    netdepth_fine=8,
    netwidth_fine=256,
    N_rand=1024,
    lrate=0.0005,
    lrate_decay=500,
    chunk=32768,
    netchunk=65536,
    no_batching=True,
    no_reload=False,
    ft_path=None,
    fix_seed=42,
    fading_layers=50000,
    tempo_delay=0,
    vel_delay=10000,
    N_iter=600000,
    train_warp=True,
    bbox_min='0.05',
    bbox_max='0.9',
    vgg_strides=4,
    ghostW=0.07,
    vggW=0.01,
    overlayW=-0.0,
    d2vW=2.0,
    nseW=0.001,
    vol_output_only=False,
    vol_output_W=128,
    render_only=False,
    render_test=False,
    N_samples=64,
    N_importance=64,
    perturb=1.0,
    use_viewdirs=False,
    i_embed=-1,
    multires=10,
    multires_views=4,
    raw_noise_std=0.0,
    render_factor=0,
    precrop_iters=1000,
    precrop_frac=0.5,
    dataset_type='pinf_data',
    testskip=20,
    shape='greek',
    white_bkgd=[1., 1., 1.],
    half_res='half',
    factor=8,
    no_ndc=False,
    lindisp=False,
    spherify=False,
    llffhold=8,
    i_print=400,
    i_img=2000,
    i_weights=25000,
    i_testset=50000,
    i_video=50000,
)
DEBUG = False

In [2]:
import torch
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Make new float tensors (including the legacy `torch.Tensor(...)` constructor used
# throughout this notebook) default to CUDA float32. Only do this when CUDA is actually
# available: setting the CUDA tensor type unconditionally fails on machines without CUDA
# (at this call or at the first tensor creation), even though `device` already falls
# back to "cpu".
# NOTE(review): `set_default_tensor_type` is deprecated since PyTorch 2.1 (hence the
# warning in the cell output). `torch.set_default_device` / `torch.set_default_dtype` are
# the replacements; confirm they also cover the `torch.Tensor(...)` calls in later cells
# before switching.
if device.type == "cuda":
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

  _C._set_default_tensor_type(t)


In [3]:
import os
import imageio.v2 as imageio
import json
import numpy as np
import cv2
import datetime

def trans_t(t):
    """4x4 homogeneous matrix translating by ``t`` along the z axis."""
    return torch.Tensor([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1]]).float()


def rot_phi(phi):
    """4x4 homogeneous rotation by ``phi`` radians about the x axis."""
    c, s = np.cos(phi), np.sin(phi)
    return torch.Tensor([
        [1, 0, 0, 0],
        [0, c, -s, 0],
        [0, s, c, 0],
        [0, 0, 0, 1]]).float()


def rot_theta(th):
    """4x4 homogeneous rotation by ``th`` radians about the y axis."""
    c, s = np.cos(th), np.sin(th)
    return torch.Tensor([
        [c, 0, -s, 0],
        [0, 1, 0, 0],
        [s, 0, c, 0],
        [0, 0, 0, 1]]).float()


def pose_spherical(theta, phi, radius, rotZ=True, wx=0.0, wy=0.0, wz=0.0):
    """Camera-to-world matrix for a camera placed on a sphere around (wx, wy, wz).

    The camera is moved out by ``radius`` along z, tilted by ``phi`` degrees about x,
    then spun by ``theta`` degrees about y. With ``rotZ=True`` the y and z axes are
    swapped (x is negated to keep the frame right-handed), so ``theta`` spins about z.
    Finally the pose is translated by (wx, wy, wz), normally the scene center.
    """
    c2w = rot_theta(theta / 180. * np.pi) @ (rot_phi(phi / 180. * np.pi) @ trans_t(radius))

    if rotZ:
        # swap y and z, negate x so the frame stays right-handed
        swap_yz = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]))
        c2w = swap_yz @ c2w

    to_center = torch.Tensor([
        [1, 0, 0, wx],
        [0, 1, 0, wy],
        [0, 0, 1, wz],
        [0, 0, 0, 1]]).float()
    return to_center @ c2w


def _rescale_hwf(H, W, focal, half_res):
    """Return (H, W, focal) adjusted for the resolution mode ``half_res``.

    'half' / 'quater' divide by 2 / 4 (integer division, so H and W should be divisible
    for the focal scaling to stay exact), 'double' multiplies by 2, and any other value
    leaves the intrinsics unchanged.
    """
    if half_res == 'half':
        return H // 2, W // 2, focal / 2.
    if half_res == 'quater':
        return H // 4, W // 4, focal / 4.
    if half_res == 'double':
        return H * 2, W * 2, focal * 2.
    return H, W, focal


def load_pinf_frame_data(basedir, half_res='normal', testskip=1):
    """Load the PINF multi-view video dataset described by ``<basedir>/info.json``.

    Args:
        basedir: dataset directory containing ``info.json`` and the videos it lists.
        half_res: resolution mode: 'normal', 'half', 'quater' (quarter) or 'double'.
        testskip: frame stride for the 'val' / 'test' splits (0 means every frame);
            the 'train' split always uses every frame.

    Returns:
        A 13-tuple:
        imgs (n, H, W, C) float32 frames scaled to [0, 1];
        poses (n, 4, 4) float32 per-frame matrices from the JSON;
        time_steps (n, 1) float32, frame_index * delta_t;
        hwfs (n, 3) float32 [H, W, focal] per frame;
        render_poses (40, 4, 4) torch tensor of spherical render poses;
        render_timesteps (40, 1) times sweeping [0, 1];
        i_split [train, val, test] index arrays into the frame axis;
        t_info float32 [min t, max t, mean t, delta_t];
        voxel_tran (4, 4) voxel matrix with its x and z columns swapped;
        voxel_scale (3,); bkg_color; near; far.
        delta_t is 1 / frame_num of the (last) training video and is reused for
        val/test frames.
    """
    all_imgs, all_poses, all_hwf, all_time_steps = [], [], [], []
    counts = [0]  # running frame count after each video
    merge_counts = [0]  # running frame count after each split

    # Only the metadata is read from the JSON; close it before the (slow) video decoding.
    with open(os.path.join(basedir, 'info.json'), 'r') as fp:
        meta = json.load(fp)

    # render settings
    near = float(meta['near'])
    far = float(meta['far'])
    radius = (near + far) * 0.5
    phi = float(meta['phi'])
    rotZ = (meta['rot'] == 'Z')
    r_center = np.float32(meta['render_center'])
    bkg_color = np.float32(meta['frame_bkg_color'])

    # scene data
    voxel_tran = np.float32(meta['voxel_matrix'])
    voxel_tran = np.stack([voxel_tran[:, 2], voxel_tran[:, 1], voxel_tran[:, 0], voxel_tran[:, 3]],
                          axis=1)  # swap_zx
    # np.broadcast_to returns a read-only view; copy so downstream torch conversions don't warn.
    voxel_scale = np.broadcast_to(meta['voxel_scale'], [3]).copy()

    # video frames; all videos should be synchronized (same frame rate and frame count)
    for s in ('train', 'val', 'test'):
        skip = 1 if (s == 'train' or testskip == 0) else testskip
        # splits without their own video list fall back to the first training video
        video_list = meta[s + '_videos'] if (s + '_videos') in meta else meta['train_videos'][0:1]

        for video in video_list:
            imgs, poses, time_steps = [], [], []
            H, W, Focal = 0, 0, 0

            if s == 'train':
                delta_t = 1.0 / video['frame_num']

            f_name = os.path.join(basedir, video['file_name'])
            reader = imageio.get_reader(f_name, "ffmpeg")
            try:
                for frame_i in range(0, video['frame_num'], skip):
                    reader.set_image_index(frame_i)
                    frame = reader.get_next_data()

                    if H == 0:
                        H, W = frame.shape[:2]
                        camera_angle_x = float(video['camera_angle_x'])
                        Focal = .5 * W / np.tan(.5 * camera_angle_x)

                    time_steps.append([frame_i * delta_t])
                    poses.append(np.array(
                        video['transform_matrix_list'][frame_i]
                        if 'transform_matrix_list' in video else video['transform_matrix']
                    ))
                    imgs.append(frame)
            finally:
                # release the ffmpeg reader even if decoding fails
                reader.close()

            imgs = np.float32(imgs) / 255.
            poses = np.array(poses).astype(np.float32)
            time_steps = np.array(time_steps).astype(np.float32)

            if half_res != 'normal':
                # (previously the 'double' branch referenced an undefined `focal` -> NameError)
                H, W, Focal = _rescale_hwf(H, W, Focal, half_res)
                # Keep the frames' float32 dtype; np.zeros would default to float64 and
                # double the memory footprint of the whole dataset.
                resized = np.zeros((imgs.shape[0], H, W, imgs.shape[-1]), dtype=imgs.dtype)
                for i, img in enumerate(imgs):
                    resized[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
                imgs = resized

            counts.append(counts[-1] + imgs.shape[0])
            all_imgs.append(imgs)
            all_poses.append(poses)
            all_time_steps.append(time_steps)
            all_hwf.append(np.float32([[H, W, Focal]] * imgs.shape[0]))
        merge_counts.append(counts[-1])

    t_info = np.float32([0.0, 1.0, 0.5, delta_t])  # min t, max t, mean t, delta_t
    i_split = [np.arange(merge_counts[i], merge_counts[i + 1]) for i in range(3)]

    imgs = np.concatenate(all_imgs, 0)  # n, H, W, C
    poses = np.concatenate(all_poses, 0)  # n, 4, 4
    time_steps = np.concatenate(all_time_steps, 0)  # n, 1
    hwfs = np.concatenate(all_hwf, 0)  # n, 3

    # spherical render path: sp_n evenly spaced azimuths, time sweeping [min t, max t]
    sp_n = 40  # an even number!
    sp_poses = [
        pose_spherical(angle, phi, radius, rotZ, r_center[0], r_center[1], r_center[2])
        for angle in np.linspace(-180, 180, sp_n + 1)[:-1]
    ]
    sp_steps = np.linspace(t_info[0], t_info[1], num=sp_n)  # [ float(ct) ]*sp_n, for testing a frozen t
    render_poses = torch.stack(sp_poses, 0)  # [sp_poses[36]]*sp_n, for testing a single pose
    render_timesteps = np.reshape(sp_steps, (-1, 1))

    return imgs, poses, time_steps, hwfs, render_poses, render_timesteps, i_split, t_info, voxel_tran, voxel_scale, bkg_color, near, far


# Load every split of the dataset (frames, poses, times, intrinsics) plus render/scene settings.
(images, poses, time_steps, hwfs, render_poses, render_timesteps, i_split, t_info,
 voxel_tran, voxel_scale, bkg_color, near, far) = load_pinf_frame_data(
    basedir=args.datadir, half_res=args.half_res, testskip=args.testskip)

In [4]:
# Inverse of the voxel matrix; used as the world-to-smoke transform (`w2s`) by the
# bounding-box tool.
voxel_tran_inv = np.linalg.inv(voxel_tran)
print('Loaded pinf frame data', images.shape, render_poses.shape, hwfs[0], args.datadir)
print('Loaded voxel matrix', voxel_tran, 'voxel scale', voxel_scale)
voxel_tran_inv = torch.Tensor(voxel_tran_inv)
voxel_tran = torch.Tensor(voxel_tran)
# `voxel_scale` may be a read-only np.broadcast_to view; copy it first so PyTorch does not
# warn about wrapping a non-writable NumPy array.
voxel_scale = torch.Tensor(np.array(voxel_scale))

i_train, i_val, i_test = i_split
# The dataset's frame background color replaces the placeholder value in `args`.
args.white_bkgd = torch.Tensor(bkg_color).to(device)
print('Scene has background color', bkg_color, args.white_bkgd)

# Per-frame pinhole intrinsics: focal length on the diagonal, principal point at the image center.
Ks = [
    [
        [hwf[-1], 0, 0.5 * hwf[1]],
        [0, hwf[-1], 0.5 * hwf[0]],
        [0, 0, 1]
    ] for hwf in hwfs
]

if args.render_test:
    # render the held-out test views instead of the spherical path
    render_poses = np.array(poses[i_test])
    render_timesteps = np.array(time_steps[i_test])

Loaded pinf frame data (612, 960, 540, 3) torch.Size([40, 4, 4]) [ 960.      540.     1306.8817] ./data/ScalarReal
Loaded voxel matrix [[ 1.0000000e+00  0.0000000e+00  7.5497901e-08  8.1816666e-02]
 [ 0.0000000e+00  1.0000000e+00  0.0000000e+00 -4.4627272e-02]
 [ 7.5497901e-08  0.0000000e+00 -1.0000000e+00 -4.9089999e-03]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  1.0000000e+00]] voxel scale [0.4909  0.73635 0.4909 ]
Scene has background color [0. 0. 0.] tensor([0., 0., 0.])


  voxel_scale = torch.Tensor(voxel_scale)


In [5]:
def pos_world2smoke(Pworld, w2s, scale_vector):
    """Map world-space points into the normalized simulation ("smoke") frame.

    Args:
        Pworld: points with shape (..., 3).
        w2s: 4x4 world-to-smoke transform; only its upper 3x4 part is used.
        scale_vector: per-axis scale (broadcastable to (3,)) dividing the result.

    Returns:
        Tensor (..., 3): (R @ p + t) / scale_vector for every point p.
    """
    # 4.world -> 3.target: rotate with the 3x3 block, then add the offset column
    pos_rot = torch.sum(Pworld[..., None, :] * (w2s[:3, :3]), -1)
    new_pose = pos_rot + w2s[:3, -1]  # (3,) offset broadcasts over the leading dims
    return new_pose / scale_vector  # 3.target -> 2.simulation


class BBox_Tool(object):
    """Axis-aligned box in the normalized smoke frame, used to mask world-space points.

    Points are mapped with ``pos_world2smoke`` and tested per axis against
    ``[s_min, s_max]`` (bounds inclusive).
    """

    def __init__(self, smoke_tran_inv, smoke_scale, in_min=0.0, in_max=1.0):
        """
        Args:
            smoke_tran_inv: 4x4 world-to-smoke transform (tensor or array).
            smoke_scale: per-axis scale of the smoke domain, broadcastable to (3,).
            in_min, in_max: box bounds in the smoke frame; a scalar or a 1- or
                3-element sequence/tensor, broadcast to (3,).
        """
        self.s_w2s = torch.Tensor(smoke_tran_inv).expand([4, 4])
        self.s_scale = torch.Tensor(smoke_scale).expand([3])
        self.s_min = self._as_vec3(in_min)
        self.s_max = self._as_vec3(in_max)

    @staticmethod
    def _as_vec3(value):
        """Broadcast a scalar or 1/3-element value to a float32 3-vector.

        `torch.Tensor(0.0)` raises a TypeError, so the scalar defaults used to crash;
        `torch.as_tensor` accepts Python scalars, sequences, arrays and tensors.
        """
        return torch.as_tensor(value, dtype=torch.float32).reshape(-1).expand([3])

    def setMinMax(self, in_min=0.0, in_max=1.0):
        """Replace the box bounds (same accepted forms as in the constructor)."""
        self.s_min = self._as_vec3(in_min)
        self.s_max = self._as_vec3(in_max)

    def isInside(self, inputs_pts):
        """Boolean mask of shape ``inputs_pts.shape[:-1]``: True inside the box (inclusive)."""
        target_pts = pos_world2smoke(inputs_pts, self.s_w2s, self.s_scale)
        above = torch.logical_and(target_pts[..., 0] >= self.s_min[0], target_pts[..., 1] >= self.s_min[1])
        above = torch.logical_and(above, target_pts[..., 2] >= self.s_min[2])
        below = torch.logical_and(target_pts[..., 0] <= self.s_max[0], target_pts[..., 1] <= self.s_max[1])
        below = torch.logical_and(below, target_pts[..., 2] <= self.s_max[2])
        return torch.logical_and(below, above)

    def insideMask(self, inputs_pts):
        """Float version of :meth:`isInside` (1.0 inside, 0.0 outside)."""
        return self.isInside(inputs_pts).to(torch.float)


# Create Bbox model. Bounds are comma-separated per-axis values; a single value is
# broadcast to all three axes. An empty `bbox_min` disables the box.
bbox_model = None
if args.bbox_min != "":
    in_min = list(map(float, args.bbox_min.split(",")))
    in_max = list(map(float, args.bbox_max.split(",")))
    bbox_model = BBox_Tool(voxel_tran_inv, voxel_scale, in_min, in_max)

In [6]:
class SineLayer(torch.nn.Module):
    """Linear layer followed by a sine activation: ``sin(omega_0 * (W x + b))`` (SIREN).

    See the SIREN paper, sec. 3.2 (final paragraph) and supplement sec. 1.5, for omega_0.

    * ``is_first=True``: omega_0 is a frequency factor that multiplies the activations
      before the nonlinearity; different signals may need a different first-layer
      omega_0 (a hyperparameter). Weights ~ U(-1/in, 1/in).
    * ``is_first=False``: weights ~ U(-sqrt(6/in)/omega_0, sqrt(6/in)/omega_0), i.e.
      divided by omega_0 so activation magnitudes stay constant while gradients to
      the weight matrix are boosted (supplement sec. 1.5).
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Re-draw the weight matrix from the SIREN uniform distribution (bias untouched)."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward_with_intermediate(self, input):
        """Return ``(sin(z), z)`` with ``z = omega_0 * linear(input)``; for visualizing activations."""
        pre_activation = self.omega_0 * self.linear(input)
        return torch.sin(pre_activation), pre_activation

    def forward(self, input):
        return self.forward_with_intermediate(input)[0]


# Velocity Model
class SIREN_vel(torch.nn.Module):
    """SIREN MLP mapping an ``input_ch``-dim input to an ``output_ch``-dim velocity.

    Args:
        D: number of SineLayers; W: their width.
        input_ch, output_ch: input / output sizes.
        skips: indices of layers whose output is concatenated with the raw input
            before being fed to the next layer.
        fading_fin_step: if > 0, layers are faded in one by one through a sliding
            window and are fully faded in once ``self.fading_step >= fading_fin_step``.
        bbox_model: optional object with ``insideMask(points)``; outputs for points
            outside it are zeroed.
    """

    def __init__(self, D=6, W=128, input_ch=4, output_ch=3, skips=[], fading_fin_step=0, bbox_model=None):
        super(SIREN_vel, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.skips = skips
        self.fading_step = 0
        self.fading_fin_step = fading_fin_step if fading_fin_step > 0 else 0
        self.bbox_model = bbox_model

        first_omega_0, hidden_omega_0 = 30.0, 1.0

        layers = [SineLayer(input_ch, W, omega_0=first_omega_0)]
        for i in range(D - 1):
            # the layer after a skip connection also receives the raw input
            in_dim = W + input_ch if i in self.skips else W
            layers.append(SineLayer(in_dim, W, omega_0=hidden_omega_0))
        self.hid_linears = torch.nn.ModuleList(layers)

        self.vel_linear = torch.nn.Linear(W, output_ch)

    def update_fading_step(self, fading_step):
        """Set the fading step; negative values are ignored.

        Should follow the global step, e.g. ``update_fading_step(global_step - vel_in_step)``.
        """
        if fading_step >= 0:
            self.fading_step = fading_step

    def fading_wei_list(self):
        """Per-layer blending weights (length D) for the current fading step.

        A triangular window of half-width 1 centred at ``ma`` slides from layer 1
        (step 0) to layer D-1 (step >= fading_fin_step).
        """
        step_ratio = np.clip(float(self.fading_step) / float(max(1e-8, self.fading_fin_step)), 0, 1)
        ma = 1 + (self.D - 2) * step_ratio  # in range of 1 to self.D-1
        return [np.clip(1 + ma - m, 0, 1) * np.clip(1 + m - ma, 0, 1) for m in range(self.D)]

    def print_fading(self):
        """Print the non-zero fading weights as ``h<i>:<weight>`` pairs."""
        weights = self.fading_wei_list()
        print("; ".join("h%d:%0.03f" % (i, w) for i, w in enumerate(weights) if w > 1e-8))

    def forward(self, x):
        h = x
        h_layers = []
        for i, layer in enumerate(self.hid_linears):
            h = layer(h)
            h_layers.append(h)
            if i in self.skips:
                h = torch.cat([x, h], -1)

        # while fading is in progress, blend layer outputs with the sliding-window weights
        if self.fading_fin_step > self.fading_step:
            h = 0
            for w, y in zip(self.fading_wei_list(), h_layers):
                if w > 1e-8:
                    h = w * y + h

        vel_out = self.vel_linear(h)

        if self.bbox_model is not None:
            bbox_mask = self.bbox_model.insideMask(x[..., :3])
            vel_out = torch.reshape(bbox_mask, [-1, 1]) * vel_out

        return vel_out


# Create vel model: only when the `nseW` loss weight is non-zero.
# SIREN_vel defaults: D=6, W=128, input_ch=4, output_ch=3, skips=[]
vel_model = (
    SIREN_vel(fading_fin_step=args.fading_layers, bbox_model=bbox_model).to(device)
    if args.nseW > 1e-8 else None
)

In [7]:
# Positional encoding (section 5.1)
class Embedder:
    """Positional encoding ``x -> [x, p(f_1 x), ..., p(f_k x)]`` for each periodic fn ``p``.

    Expected kwargs: input_dims, include_input, max_freq_log2, num_freqs,
    log_sampling, periodic_fns.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build ``self.embed_fns`` and ``self.out_dim`` from ``self.kwargs``."""
        cfg = self.kwargs
        dims = cfg['input_dims']
        fns = [lambda x: x] if cfg['include_input'] else []

        max_freq = cfg['max_freq_log2']
        n_freqs = cfg['num_freqs']
        if cfg['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, steps=n_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=n_freqs)

        # bind p_fn / freq as defaults so every lambda keeps its own values
        fns += [lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)
                for freq in freq_bands
                for p_fn in cfg['periodic_fns']]

        self.embed_fns = fns
        # every function (identity included) contributes `dims` channels
        self.out_dim = dims * len(fns)

    def embed(self, inputs):
        """Concatenate all encodings of ``inputs`` along the last axis."""
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(multires, i=0, dim=3):
    """Return ``(embed_fn, out_dim)`` for ``dim``-dimensional inputs.

    ``i == -1`` disables the encoding (identity module, output size ``dim``); otherwise
    sin/cos encodings at ``multires`` log-spaced frequencies 2^0 .. 2^(multires-1) are
    appended to the raw input.
    """
    if i == -1:
        return torch.nn.Identity(), dim

    embedder = Embedder(include_input=True,
                        input_dims=dim,
                        max_freq_log2=multires - 1,
                        num_freqs=multires,
                        log_sampling=True,
                        periodic_fns=[torch.sin, torch.cos])
    return (lambda x, eo=embedder: eo.embed(x)), embedder.out_dim


# Model
class SIREN_NeRFt(torch.nn.Module):
    """SIREN radiance field: maps an encoded point (and optional view input) to raw
    ``[r, g, b, alpha]`` values (rgb from ``rgb_linear``, alpha from ``alpha_linear``).

    Args:
        D, W: number and width of the SineLayers in the point trunk.
        input_ch: size of the point input; input_ch_views: size of the view input.
        output_ch: accepted for interface compatibility; the output always has 4 channels.
        skips: trunk layers whose output is concatenated with the point input before
            the next layer.
        use_viewdirs: compute rgb through an extra branch that also sees the view input.
        fading_fin_step: if > 0, trunk layers are faded in one by one and are fully
            faded in once ``self.fading_step >= fading_fin_step``.
        bbox_model: optional object with ``insideMask(points)``; outputs for points
            outside it are zeroed.
    """

    def __init__(self, D=8, W=256, input_ch=4, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False,
                 fading_fin_step=0, bbox_model=None):
        super(SIREN_NeRFt, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirs = use_viewdirs
        self.fading_step = 0
        self.fading_fin_step = fading_fin_step if fading_fin_step > 0 else 0
        self.bbox_model = bbox_model

        first_omega_0, hidden_omega_0 = 30.0, 1.0

        trunk = [SineLayer(input_ch, W, omega_0=first_omega_0)]
        for i in range(D - 1):
            # the layer after a skip connection also receives the point input
            in_dim = W + input_ch if i in self.skips else W
            trunk.append(SineLayer(in_dim, W, omega_0=hidden_omega_0))
        self.pts_linears = torch.nn.ModuleList(trunk)

        self.alpha_linear = torch.nn.Linear(W, 1)

        if use_viewdirs:
            # point features (W // 2) and view features (W // 2) are concatenated
            self.views_linear = SineLayer(input_ch_views, W // 2, omega_0=first_omega_0)
            self.feature_linear = SineLayer(W, W // 2, omega_0=hidden_omega_0)
            self.feature_view_linears = torch.nn.ModuleList([SineLayer(W, W, omega_0=hidden_omega_0)])

        self.rgb_linear = torch.nn.Linear(W, 3)

    def update_fading_step(self, fading_step):
        """Set the fading step; negative values are ignored.

        Should follow the global step, e.g. ``update_fading_step(global_step - radiance_in_step)``.
        """
        if fading_step >= 0:
            self.fading_step = fading_step

    def fading_wei_list(self):
        """Per-layer blending weights (length D) for the current fading step.

        A triangular window of half-width 1 centred at ``ma`` slides from layer 1
        (step 0) to layer D-1 (step >= fading_fin_step).
        """
        step_ratio = np.clip(float(self.fading_step) / float(max(1e-8, self.fading_fin_step)), 0, 1)
        ma = 1 + (self.D - 2) * step_ratio  # in range of 1 to self.D-1
        return [np.clip(1 + ma - m, 0, 1) * np.clip(1 + m - ma, 0, 1) for m in range(self.D)]

    def print_fading(self):
        """Print the non-zero fading weights as ``h<i>:<weight>`` pairs."""
        weights = self.fading_wei_list()
        print("; ".join("h%d:%0.03f" % (i, w) for i, w in enumerate(weights) if w > 1e-8))

    def forward(self, x):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)

        h = input_pts
        h_layers = []
        for i, layer in enumerate(self.pts_linears):
            h = layer(h)
            h_layers.append(h)
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        # while fading is in progress, blend trunk outputs with the sliding-window weights
        if self.fading_fin_step > self.fading_step:
            h = 0
            for w, y in zip(self.fading_wei_list(), h_layers):
                if w > 1e-8:
                    h = w * y + h

        alpha = self.alpha_linear(h)

        if self.use_viewdirs:
            h = torch.cat([self.feature_linear(h), self.views_linear(input_views)], -1)
            for layer in self.feature_view_linears:
                h = layer(h)

        outputs = torch.cat([self.rgb_linear(h), alpha], -1)

        if self.bbox_model is not None:
            bbox_mask = self.bbox_model.insideMask(input_pts[:, :3])
            outputs = torch.reshape(bbox_mask, [-1, 1]) * outputs

        return outputs


def batchify(fn, chunk):
    """Wrap ``fn`` so it runs on consecutive slices of at most ``chunk`` rows (dim 0)
    and concatenates the results; ``chunk=None`` returns ``fn`` itself."""
    if chunk is None:
        return fn

    def apply_in_chunks(inputs):
        pieces = []
        for begin in range(0, inputs.shape[0], chunk):
            pieces.append(fn(inputs[begin:begin + chunk]))
        return torch.cat(pieces, 0)

    return apply_in_chunks


def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024 * 64):
    """Embed ``inputs`` (and ``viewdirs`` if given) and evaluate ``fn`` in chunks.

    inputs: (..., C) points, flattened to (N, C) before embedding.
    viewdirs: None, or one vector per entry of dim 0 that is broadcast along dim 1 of
        ``inputs`` (via ``viewdirs[:, None].expand(inputs.shape)``).
    Returns fn's output reshaped to ``inputs.shape[:-1] + (out_dim,)``.
    """
    flat = inputs.reshape(-1, inputs.shape[-1])
    embedded = embed_fn(flat)

    if viewdirs is not None:
        dirs = viewdirs[:, None].expand(inputs.shape)
        embedded_dirs = embeddirs_fn(dirs.reshape(-1, dirs.shape[-1]))
        embedded = torch.cat([embedded, embedded_dirs], -1)

    out_flat = batchify(fn, netchunk)(embedded)
    return out_flat.reshape(list(inputs.shape[:-1]) + [out_flat.shape[-1]])


def create_nerf(args, vel_model=None, bbox_model=None, ndim=3):
    """Instantiate the coarse/fine SIREN NeRF models, their optimizers and render kwargs.

    Args:
        args: run configuration (reads net_model, multires, multires_views, i_embed,
            use_viewdirs, fading_layers, netdepth/netwidth(_fine), N_importance,
            netchunk, lrate, perturb, N_samples, raw_noise_std, dataset_type, no_ndc,
            lindisp).
        vel_model: optional velocity network; it gets its own Adam optimizer.
        bbox_model: optional bounding box passed to the radiance networks.
        ndim: dimensionality of the raw network input before positional encoding.

    Returns:
        (render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer, vel_optimizer)

    Raises:
        NotImplementedError: if ``args.net_model == "hybrid"``.
    """
    if args.net_model == "hybrid":
        # The hybrid path passed `fading_fin_step_static/_dynamic` and called `toDevice`,
        # neither of which SIREN_NeRFt supports, so it always crashed with an unrelated
        # TypeError/AttributeError. Fail early with a clear message instead.
        raise NotImplementedError("net_model 'hybrid' is not supported here; use 'siren'")

    embed_fn, input_ch = get_embedder(args.multires, args.i_embed, ndim)

    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed, dim=ndim)
    output_ch = 4  # rgb + alpha
    skips = [4]

    model_args = {}
    if args.fading_layers > 0 and args.net_model == "siren":
        model_args["fading_fin_step"] = args.fading_layers
    if bbox_model is not None:
        model_args["bbox_model"] = bbox_model

    def build_model(depth, width):
        # coarse and fine networks share every hyper-parameter except depth / width
        return SIREN_NeRFt(D=depth, W=width,
                           input_ch=input_ch, output_ch=output_ch, skips=skips,
                           input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs,
                           **model_args).to(device)

    model = build_model(args.netdepth, args.netwidth)
    grad_vars = list(model.parameters())

    model_fine = None
    if args.N_importance > 0:
        model_fine = build_model(args.netdepth_fine, args.netwidth_fine)
        grad_vars += list(model_fine.parameters())

    network_query_fn = lambda inputs, viewdirs, network_fn: run_network(inputs, viewdirs, network_fn,
                                                                        embed_fn=embed_fn,
                                                                        embeddirs_fn=embeddirs_fn,
                                                                        netchunk=args.netchunk)

    # Create optimizers
    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
    vel_optimizer = None
    if vel_model is not None:
        vel_optimizer = torch.optim.Adam(params=list(vel_model.parameters()), lr=args.lrate, betas=(0.9, 0.999))

    start = 0

    render_kwargs_train = {
        'network_query_fn': network_query_fn,
        'perturb': args.perturb,
        'N_importance': args.N_importance,
        'network_fine': model_fine,
        'N_samples': args.N_samples,
        'network_fn': model,
        'use_viewdirs': args.use_viewdirs,
        'raw_noise_std': args.raw_noise_std,
    }

    # NDC only good for LLFF-style forward facing data
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    # test renders: same settings, but deterministic sampling and no density noise
    render_kwargs_test = dict(render_kwargs_train)
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer, vel_optimizer


# Create nerf model (ndim=4: the raw network input has 4 components before encoding)
(render_kwargs_train, render_kwargs_test, start,
 grad_vars, optimizer, vel_optimizer) = create_nerf(args, vel_model=vel_model,
                                                    bbox_model=bbox_model, ndim=4)

Not ndc!


In [8]:
def model_fading_update(models, global_step, tempoDelay, velDelay, isHybrid):
    """Advance the layer-fading schedule of every model in ``models``.

    models: dict name -> model; None entries are skipped.
    The "vel_model" entry is updated with ``global_step - tempoDelay - velDelay``;
    ``tempoDelay`` only counts when ``isHybrid``. Hybrid models receive two steps
    (``global_step`` and ``global_step - tempoDelay``); others receive ``global_step``.
    """
    tempo = tempoDelay if isHybrid else 0
    for name, model in models.items():
        if model is None:
            continue
        if name == "vel_model":
            model.update_fading_step(global_step - tempo - velDelay)
        elif isHybrid:
            model.update_fading_step(global_step, global_step - tempo)
        else:
            model.update_fading_step(global_step)


global_step = start

# Scene bounds and the time flag apply to both training and test renders.
update_dict = dict(near=near, far=far, has_t=True)
render_kwargs_train.update(update_dict)
render_kwargs_test.update(update_dict)
render_kwargs_train['vel_model'] = vel_model  # only the training renders get the velocity model
render_kwargs_test['remove99'] = True

# Models driven by the fading schedule, and (presumably) the checkpoint key of each.
all_models = {"vel_model": vel_model,
              "coarse": render_kwargs_train['network_fn'],
              "fine": render_kwargs_train['network_fine']}
save_dic_keys = {"vel_model": "network_vel_state_dict",
                 "coarse": "network_fn_state_dict",
                 "fine": "network_fine_state_dict"}

tempoInStep = max(0, args.tempo_delay) if args.net_model == "hybrid" else 0
velInStep = max(0, args.vel_delay) if args.nseW > 1e-8 else 0  # after tempoInStep
if args.net_model != "nerf":
    model_fading_update(all_models, start, tempoInStep, velInStep, args.net_model == "hybrid")

# Move testing data to GPU
render_poses = torch.Tensor(render_poses).to(device)
render_timesteps = torch.Tensor(render_timesteps).to(device)
test_bkg_color = np.float32([0.0, 0.0, 0.3])

# Prepare raybatch tensor if batching random rays; only per-image random rays are supported.
N_rand = args.N_rand
use_batching = not args.no_batching
if use_batching or N_rand is None:
    print('Not supported!')
    raise NotImplementedError

In [9]:
class Scale(torch.nn.Module):
    """Wrap ``module`` and multiply its output by a constant ``scale`` (kept as a buffer)."""

    def __init__(self, module, scale):
        super().__init__()
        self.module = module
        self.register_buffer('scale', torch.tensor(scale))

    def extra_repr(self):
        return '(scale): {:g}'.format(self.scale.item())

    def forward(self, *args, **kwargs):
        out = self.module(*args, **kwargs)
        return out * self.scale


# VGG Tool, https://github.com/crowsonkb/style-transfer-pytorch/
class VGGFeatures(torch.nn.Module):
    """Extract intermediate feature maps from a pre-trained VGG-19.

    Args:
        layers: indices into ``vgg19().features`` whose outputs are returned.
        pooling: 'max' (default), 'average' or 'l2'. Non-max pooling replaces every
            MaxPool2d and rescales its output by ``pooling_scales[pooling]``.
    """
    poolings = {
        'max': torch.nn.MaxPool2d,
        'average': torch.nn.AvgPool2d,
        # LPPool2d(norm_type=2, kernel_size). 'l2' used to have a scale entry but no pooling
        # entry, so pooling='l2' failed with a KeyError.
        'l2': lambda kernel_size: torch.nn.LPPool2d(2, kernel_size),
    }
    pooling_scales = {'max': 1., 'average': 2., 'l2': 0.78}

    def __init__(self, layers, pooling='max'):
        super().__init__()
        self.layers = sorted(set(layers))

        # The PyTorch pre-trained VGG-19 expects sRGB inputs in the range [0, 1] which are then
        # normalized according to this transform, unlike Simonyan et al.'s original model.
        self.normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                          std=[0.229, 0.224, 0.225])

        # The PyTorch pre-trained VGG-19 has different parameters from Simonyan et al.'s original
        # model. `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1 is the
        # checkpoint it maps to, so the loaded weights are the same.
        if hasattr(torchvision.models, 'VGG19_Weights'):
            vgg = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1)
        else:  # torchvision < 0.13
            vgg = torchvision.models.vgg19(pretrained=True)
        self.model = vgg.features[:self.layers[-1] + 1]
        self.devices = [torch.device('cpu')] * len(self.model)

        # Reduces edge artifacts.
        self.model[0] = self._change_padding_mode(self.model[0], 'replicate')

        pool_scale = self.pooling_scales[pooling]
        for i, layer in enumerate(self.model):
            if pooling != 'max' and isinstance(layer, torch.nn.MaxPool2d):
                # Changing the pooling type from max results in the scale of activations
                # changing, so rescale them. Gatys et al. (2015) do not do this.
                self.model[i] = Scale(self.poolings[pooling](2), pool_scale)

        self.model.eval()
        self.model.requires_grad_(False)

    @staticmethod
    def _change_padding_mode(conv, padding_mode):
        """Return a copy of Conv2d ``conv`` (same weights) using ``padding_mode``."""
        new_conv = torch.nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                                   stride=conv.stride, padding=conv.padding,
                                   padding_mode=padding_mode)
        with torch.no_grad():
            new_conv.weight.copy_(conv.weight)
            new_conv.bias.copy_(conv.bias)
        return new_conv

    @staticmethod
    def _get_min_size(layers):
        """Smallest input side length: doubled for every pooling stage (layer indices
        4, 9, 18, 27, 36) at or before the deepest requested layer."""
        last_layer = max(layers)
        min_size = 1
        for layer in [4, 9, 18, 27, 36]:
            if last_layer < layer:
                break
            min_size *= 2
        return min_size

    def distribute_layers(self, devices):
        """Place layers on devices according to ``devices`` = {first_layer_index: device}.

        Each layer goes to the device of the nearest key at or before its index. Layers
        before the first key keep their current placement (previously a plan without
        key 0 raised UnboundLocalError).
        """
        current = None
        for i, layer in enumerate(self.model):
            if i in devices:
                current = torch.device(devices[i])
            if current is not None:
                self.model[i] = layer.to(current)
                self.devices[i] = current

    def forward(self, input, layers=None):
        """Return ``{'input': input, layer_index: feature_map, ...}`` for the requested layers.

        input: (b, 3, h, w), expected in [0, 1]; h and w must be at least
        ``_get_min_size(layers)``. Raises ValueError otherwise.
        """
        layers = self.layers if layers is None else sorted(set(layers))
        h, w = input.shape[2:4]
        min_size = self._get_min_size(layers)
        if min(h, w) < min_size:
            raise ValueError(f'Input is {h}x{w} but must be at least {min_size}x{min_size}')
        feats = {'input': input}
        # Normalize handles batched (b, c, h, w) tensors directly; same per-image result.
        norm_in = self.normalize(input)
        for i in range(max(layers) + 1):
            norm_in = self.model[i](norm_in.to(self.devices[i]))
            if i in layers:
                feats[i] = norm_in
        return feats


# VGG Loss Tool
class VGGlossTool(object):
    """Perceptual loss based on cosine similarity of VGG19 feature vectors.

    Image and reference are run through ``VGGFeatures``. For every selected layer,
    the per-pixel feature vectors (taken along the channel axis) of both images
    are L2-normalised and compared with cosine similarity.
    """

    def __init__(self, device, pooling='max'):
        # The default content and style layers in Gatys et al. (2015):
        #   content_layers = [22], 'relu4_2'
        #   style_layers = [1, 6, 11, 20, 29], relu layers: [ 'relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1']
        # We use [5, 10, 19, 28], conv layers before relu: [ 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']
        self.layer_list = [5, 10, 19, 28]
        self.layer_names = [
            "block2_conv1",
            "block3_conv1",
            "block4_conv1",
            "block5_conv1",
        ]
        self.device = device

        # Build a VGG19 model loaded with pre-trained ImageNet weights
        self.vggmodel = VGGFeatures(self.layer_list, pooling=pooling)
        device_plan = {0: device}
        self.vggmodel.distribute_layers(device_plan)

    def feature_norm(self, feature):
        """L2-normalise ``feature`` along its last axis.

        The last axis is expected to be the channel axis, i.e. (..., h, w, c).
        """
        feature_len = torch.sqrt(torch.sum(torch.square(feature), dim=-1, keepdim=True) + 1e-12)
        norm = feature / feature_len
        return norm

    def cos_sim(self, a, b):
        """Return ``1 - mean(sum(a * b, dim=-1))``.

        For unit-length vectors along the last axis this is 1 minus the mean
        cosine similarity: 0 for identical directions, up to 2 for opposite ones.
        """
        cos_sim_ab = torch.sum(a * b, dim=-1)
        # cosine similarity, -1~1, 1 best
        cos_sim_ab_score = 1.0 - torch.mean(cos_sim_ab)  # 0 ~ 2, 0 best
        return cos_sim_ab_score

    def compute_cos_loss(self, img, ref):
        """Compute one cosine feature loss per layer in ``self.layer_list``.

        Args:
            img, ref: images of shape (h, w, c) with values in [0, 1].

        Returns:
            list of scalar tensors, one per layer, each in [0, 2] (0 best).
        """
        input_tensor = torch.stack([ref, img], dim=0)
        # (2, h, w, c) -> (2, c, h, w) for the convolutional network
        input_tensor = input_tensor.permute((0, 3, 1, 2))
        _feats = self.vggmodel(input_tensor, layers=self.layer_list)

        loss = []
        for layer_i, _layer_name in zip(self.layer_list, self.layer_names):
            # VGG activations are channels-first (2, C, H, W). Move channels last so
            # feature_norm / cos_sim, which reduce over dim=-1, compare per-pixel
            # feature vectors along the channel axis instead of along the width.
            cur_feature = _feats[layer_i].permute(0, 2, 3, 1)
            reference_features = self.feature_norm(cur_feature[0, ...])
            img_features = self.feature_norm(cur_feature[1, ...])

            feature_metric = self.cos_sim(reference_features, img_features)
            loss += [feature_metric]
        return loss


def pos_smoke2world(Psmoke, s2w, scale_vector):
    """Map simulation-space positions to world space.

    Each (..., 3) position is scaled per axis by ``scale_vector``
    (2.simulation -> 3.target). It is then rotated by ``s2w[:3, :3]`` and
    shifted by ``s2w[:3, -1]`` (3.target -> 4.world).
    """
    rotation = s2w[:3, :3]
    offset = s2w[:3, -1]
    scaled = Psmoke * (scale_vector)
    # component i of the result is sum_j rotation[i, j] * scaled[j]
    rotated = (scaled[..., None, :] * rotation).sum(-1)
    return rotated + offset.expand(rotated.shape)


def get_voxel_pts(H, W, D, s2w, scale_vector, n_jitter=0, r_jitter=0.8):
    """Get world-space positions of the cell centres of a D x H x W voxel grid.

    Args:
        H, W, D: grid resolution along y, x and z.
        s2w: 4x4 transform passed to ``pos_smoke2world`` (rotation in [:3, :3],
            translation in [:3, -1]).
        scale_vector: per-axis scale applied to the [0, 1] grid coordinates.
        n_jitter: number of random jitter passes added to the cell centres.
        r_jitter: jitter amplitude per pass, in cells.

    Returns:
        (D, H, W, 3) tensor of world-space (x, y, z) positions.
    """
    # indexing='ij' is the historical default ordering; passing it explicitly
    # silences the torch.meshgrid deprecation warning without changing results.
    i, j, k = torch.meshgrid(torch.linspace(0, D - 1, D),
                             torch.linspace(0, H - 1, H),
                             torch.linspace(0, W - 1, W), indexing='ij')
    pts = torch.stack([(k + 0.5) / W, (j + 0.5) / H, (i + 0.5) / D], -1)
    # shape D*H*W*3, value [(x,y,z)], range [0,1]

    jitter_r = torch.Tensor([r_jitter / W, r_jitter / H, r_jitter / D]).float().expand(pts.shape)
    for _ in range(n_jitter):
        # uniform offsets in [-0.5, 0.5) per axis, scaled to r_jitter cells
        off_i = torch.rand(pts.shape, dtype=torch.float) - 0.5
        pts = pts + off_i * jitter_r

    return pos_smoke2world(pts, s2w, scale_vector)


# from FFJORD github code
def _get_minibatch_jacobian(y, x):
    """Computes the Jacobian of y wrt x assuming minibatch-mode.
    Args:
      y: (N, ...) with a total of D_y elements in ...
      x: (N, ...) with a total of D_x elements in ...
    Returns:
      The minibatch Jacobian matrix of shape (N, D_y, D_x)
    """
    assert y.shape[0] == x.shape[0]
    y = y.view(y.shape[0], -1)

    # Compute Jacobian row by row.
    jac = []
    for j in range(y.shape[1]):
        dy_j_dx = torch.autograd.grad(
            y[:, j],
            x,
            torch.ones_like(y[:, j], device=y.get_device()),
            retain_graph=True,
            create_graph=True,
        )[0].view(x.shape[0], -1)
        jac.append(torch.unsqueeze(dy_j_dx, 1))
    jac = torch.cat(jac, 1)
    return jac


def vel_world2smoke(Vworld, w2s, scale_vector, st_factor):
    """Convert world-space velocities to simulation-space velocities.

    The (..., 3) vectors are rotated by ``w2s[:3, :3]`` (4.world -> 3.target).
    They are then divided by ``scale_vector`` (3.target -> 2.simulation) and
    multiplied by the per-axis ``st_factor`` (3 values, converted with
    ``torch.Tensor``).
    """
    rotation = w2s[:3, :3]
    # component i of the result is sum_j rotation[i, j] * v[j]
    in_target = (Vworld[..., None, :] * rotation).sum(-1)
    factor = torch.Tensor(st_factor).expand((3,))
    return in_target / (scale_vector) * factor


def den_scalar2rgb(den, scale=160, is3D=False, logv=False, mix=True):
    """Convert a scalar (e.g. density) field into a greyscale uint8 image.

    Args:
        den: a np.float32 array, in shape of (?=b,) d,h,w,1 for 3D and (?=b,)h,w,1
            for 2D. A missing trailing channel axis is added.
        scale: scale content to 0~255; something between 100-255 is usually good.
            Content is normalized by its maximum if scale is None.
        is3D: if True, show three orthogonal views side by side: [yx | yz | zx].
        logv: visualize log10(value + 1) instead of the value.
        mix: for 3D, average along each view axis if True, else show the
            middle slice.

    Returns:
        uint8 array of shape (?=b,) H, W, 1 with the image rows (y) flipped.
    """
    ori_shape = list(den.shape)
    if ori_shape[-1] != 1:
        ori_shape.append(1)
        den = np.reshape(den, ori_shape)

    if is3D:
        # reorder (..., d, h, w, 1) -> (..., h, d, w, 1)
        new_range = list(range(len(ori_shape)))
        z_new_range = new_range[:]
        z_new_range[-4] = new_range[-3]
        z_new_range[-3] = new_range[-4]
        YZXden = np.transpose(den, z_new_range)

        if not mix:
            _yz = YZXden[..., (ori_shape[-2] - 1) // 2, :]
            _yx = YZXden[..., (ori_shape[-4] - 1) // 2, :, :]
            _zx = YZXden[..., (ori_shape[-3] - 1) // 2, :, :, :]
        else:
            _yz = np.average(YZXden, axis=-2)
            _yx = np.average(YZXden, axis=-3)
            _zx = np.average(YZXden, axis=-4)

        # in case resolution is not a cube, (res,res,res)
        _yxz = np.concatenate([_yx, _yz], axis=-2)  # (?=b,),h,w+zdim,1

        if ori_shape[-3] < ori_shape[-4]:
            # fewer rows (h) than depth: pad the yx|yz strip to d rows
            pad_shape = list(_yxz.shape)  # (?=b,),h,w+zdim,1
            pad_shape[-3] = ori_shape[-4] - ori_shape[-3]
            _pad = np.zeros(pad_shape, dtype=np.float32)
            _yxz = np.concatenate([_yxz, _pad], axis=-3)
        elif ori_shape[-3] > ori_shape[-4]:
            # more rows (h) than depth: pad the zx view to h rows
            pad_shape = list(_zx.shape)
            pad_shape[-3] = ori_shape[-3] - ori_shape[-4]
            _zx = np.concatenate(
                [_zx, np.zeros(pad_shape, dtype=np.float32)], axis=-3)

        midDen = np.concatenate([_yxz, _zx], axis=-2)  # (?=b,),h,w*3,1
    else:
        midDen = den

    if logv:
        midDen = np.log10(midDen + 1)
    if scale is None:
        midDen = midDen / max(midDen.max(), 1e-6) * 255.0
    else:
        midDen = midDen * scale
    grey = np.clip(midDen, 0, 255).astype(np.uint8)

    # flip y: reverse the image-row axis (-3) so that a leading batch axis is left
    # untouched (the former [::-1] reversed the batch order for batched input).
    y_axis = -3 if grey.ndim >= 3 else 0
    return np.flip(grey, axis=y_axis)


def cubecenter(cube, axis, half=0):
    """Magnitude-weighted centre index of a vector field along one spatial axis.

    Args:
        cube: (b,)d,h,w,c array; the batch axis is optional.
        axis: 1 (z), 2 (y), 3 (x). Interpreted relative to the trailing spatial
            axes, so batched and unbatched inputs are handled the same way.
        half: 0 -> whole axis, 1 -> only the first half, 2 -> only the second half.

    Returns:
        int32 index (an array with one entry per batch item when batched).
    """
    spatial_axis = axis - 5  # 1 -> -4 (z), 2 -> -3 (y), 3 -> -2 (x)
    # Average over the two other spatial axes. Counting from the end keeps an
    # unbatched (d,h,w,c) input from averaging over its channel axis.
    reduce_axis = tuple(a - 5 for a in (1, 2, 3) if a != axis)
    pack = np.mean(cube, axis=reduce_axis)  # (b,)h,c
    pack = np.sqrt(np.sum(np.square(pack), axis=-1) + 1e-6)  # (b,)h, per-slice magnitude

    length = cube.shape[spatial_axis]  # h
    weights = np.arange(0.5 / length, 1.0, 1.0 / length)  # normalised cell centres
    # Build the half-mask from the untouched weights. Previously the pack mask
    # was derived from the already-zeroed weights, so with half == 1 the pack
    # was never masked.
    if half == 1:  # first half
        keep = weights < 0.5
    elif half == 2:  # second half
        keep = weights > 0.5
    else:
        keep = None
    if keep is not None:
        weights = np.where(keep, weights, np.zeros_like(weights))
        pack = np.where(keep, pack, np.zeros_like(pack))

    weighted = pack * weights  # (b,)h
    weiAxis = np.sum(weighted, axis=-1) / np.sum(pack, axis=-1) * length  # (b,)

    return weiAxis.astype(np.int32)  # truncated toward zero


def vel2hsv(velin, is3D, logv, scale=None):
    """Encode velocity vectors as OpenCV-style uint8 HSV pixels.

    Hue encodes the direction in the x/y plane (2D) or x/z plane (3D), mapped
    to 0-180. For 3D, saturation decreases with the elevation of |v_y| out of
    the x/z plane; for 2D it is always 255. Value is the (optionally log10)
    magnitude, multiplied by ``scale`` and capped at 255, or normalized by its
    maximum when ``scale`` is None.
    Returns a uint8 array of shape velin.shape[:-1] + (3,).
    """
    vx = velin[..., 0]
    vy = velin[..., 1]
    out_shape = list(velin.shape[:-1]) + [3]
    elevation = None
    if not is3D:
        magnitude = np.sqrt(vx * vx + vy * vy)
        hue_angle = np.arctan2(vy, vx) + np.pi
    else:
        vz = velin[..., 2]
        hue_angle = np.arctan2(vz, vx) + np.pi  # angle in the x/z plane
        planar_sq = vx * vx + vz * vz
        elevation = np.arctan2(np.abs(vy), np.sqrt(planar_sq))
        magnitude = np.sqrt(planar_sq + vy * vy)

    if logv:
        magnitude = np.log10(magnitude + 1)

    hsv = np.zeros(out_shape, np.uint8)
    hsv[..., 0] = hue_angle * (180 / np.pi / 2)
    hsv[..., 1] = 255 if elevation is None else 255 - elevation * (240 / np.pi * 2)
    if scale is None:
        hsv[..., 2] = magnitude / max(magnitude.max(), 1e-6) * 255.0
    else:
        hsv[..., 2] = np.minimum(magnitude * scale, 255)
    return hsv


def velLegendHSV(hsvin, is3D, lw=-1, constV=255):
    """Paint an HSV direction legend into ``hsvin`` in place.

    Each painted pixel gets a hue equal to its angle around the image centre
    (OpenCV hue range 0-180) and value ``constV``. Saturation is 255 for 2D
    (or lw == 1); for 3D it falls off toward the legend's inner edge, encoding
    elevation.

    Args:
        hsvin: uint8 HSV image(s) of shape (b), h, w, 3; modified in place.
        is3D: whether to encode elevation in the saturation channel.
        lw: border width in pixels; lw <= 0 fills the whole image instead.
        constV: value channel written to every painted pixel.
    """
    # hsvin: (b), h, w, 3
    # always overwrite hsvin borders [lw], please pad hsvin before hand
    # or fill whole hsvin (lw <= 0)
    ih, iw = hsvin.shape[-3:-1]
    if lw <= 0:  # fill whole
        a_list, b_list = [range(ih)], [range(iw)]
    else:  # fill border: left, top, right and bottom strips (as row/column ranges)
        a_list = [range(ih), range(lw), range(ih), range(ih - lw, ih)]
        b_list = [range(lw), range(iw), range(iw - lw, iw), range(iw)]
    for a, b in zip(a_list, b_list):
        for _fty in a:
            for _ftx in b:
                # pixel offset from the image centre
                fty = _fty - ih // 2
                ftx = _ftx - iw // 2
                ftang = np.arctan2(fty, ftx) + np.pi
                ftang = ftang * (180 / np.pi / 2)
                # print("ftang,min,max,mean", ftang.min(), ftang.max(), ftang.mean())
                # ftang,min,max,mean 0.7031249999999849 180.0 90.3515625
                hsvin[..., _fty, _ftx, 0] = np.expand_dims(ftang, axis=-1)  # OpenCV hue, 0-180
                # hsvin[...,_fty,_ftx,0] = ftang
                hsvin[..., _fty, _ftx, 2] = constV
                if (not is3D) or (lw == 1):
                    hsvin[..., _fty, _ftx, 1] = 255
                else:
                    # normalised distance to the inner edge of the border (or to
                    # the centre when filling the whole image), larger of y/x
                    thetaY1 = 1.0 - ((ih // 2) - abs(fty)) / float(lw if (lw > 1) else (ih // 2))
                    thetaY2 = 1.0 - ((iw // 2) - abs(ftx)) / float(lw if (lw > 1) else (iw // 2))
                    fthetaY = max(thetaY1, thetaY2) * (0.5 * np.pi)
                    ftxY, ftyY = np.cos(fthetaY), np.sin(fthetaY)
                    fangY = np.arctan2(ftyY, ftxY)
                    fangY = fangY * (240 / np.pi * 2)  # 0-240 for angles in [0, pi/2]
                    hsvin[..., _fty, _ftx, 1] = 255 - fangY
                    # print("fangY,min,max,mean", fangY.min(), fangY.max(), fangY.mean())
    # finished velLegendHSV.


def vel_uv2hsv(vel, scale=160, is3D=False, logv=False, mix=False):
    """Visualize a velocity field as a BGR uint8 image (hue = direction, value = magnitude).

    Args:
        vel: a np.float32 array, in shape of (?=b,) d,h,w,3 for 3D and
            (?=b,)h,w, 2 or 3 for 2D.
        scale: scale content to 0~255; something between 100-255 is usually good.
            Content will be normalized if scale is None.
        is3D: if True, show three orthogonal slices side by side ([yx | yz | zx])
            and draw an HSV direction legend on the image border.
        logv: visualize value with log.
        mix: blend slices near the magnitude-weighted centres (via cubecenter)
            with the middle slice; slower. NOTE(review): this path expects
            cubecenter to return scalars, i.e. unbatched input.

    Returns:
        uint8 BGR image of shape (?=b,) H, W, 3 with the image rows (y) flipped.
    """
    ori_shape = list(vel.shape[:-1]) + [3]  # (?=b,) d,h,w,3
    legend_ref_w = None
    if is3D:
        # reorder (..., d, h, w, c) -> (..., h, d, w, c)
        z_new_range = list(range(len(ori_shape)))
        z_new_range[-4], z_new_range[-3] = z_new_range[-3], z_new_range[-4]
        YZXvel = np.transpose(vel, z_new_range)

        _xm, _ym, _zm = (ori_shape[-2] - 1) // 2, (ori_shape[-3] - 1) // 2, (ori_shape[-4] - 1) // 2

        if mix:
            _xlist = [cubecenter(vel, 3, 1), _xm, cubecenter(vel, 3, 2)]
            _ylist = [cubecenter(vel, 2, 1), _ym, cubecenter(vel, 2, 2)]
            _zlist = [cubecenter(vel, 1, 1), _zm, cubecenter(vel, 1, 2)]
        else:
            _xlist, _ylist, _zlist = [_xm], [_ym], [_zm]

        # np.clip's upper bound is inclusive: clamp to the last valid index along
        # x, y and z (the axis size itself would be out of range).
        max_index = [ori_shape[-2] - 1, ori_shape[-3] - 1, ori_shape[-4] - 1]
        hsv = []
        for _x, _y, _z in zip(_xlist, _ylist, _zlist):
            _x, _y, _z = np.clip([_x, _y, _z], 0, max_index)
            _yz = YZXvel[..., _x, :]
            _yz = np.stack([_yz[..., 2], _yz[..., 0], _yz[..., 1]], axis=-1)
            _yx = YZXvel[..., _z, :, :]
            _yx = np.stack([_yx[..., 0], _yx[..., 2], _yx[..., 1]], axis=-1)
            _zx = YZXvel[..., _y, :, :, :]
            _zx = np.stack([_zx[..., 0], _zx[..., 1], _zx[..., 2]], axis=-1)

            # in case resolution is not a cube, (res,res,res)
            _yxz = np.concatenate([_yx, _yz], axis=-2)  # (?=b,),h,w+zdim,3

            if ori_shape[-3] < ori_shape[-4]:
                pad_shape = list(_yxz.shape)  # (?=b,),h,w+zdim,3
                pad_shape[-3] = ori_shape[-4] - ori_shape[-3]
                _pad = np.zeros(pad_shape, dtype=np.float32)
                _yxz = np.concatenate([_yxz, _pad], axis=-3)
            elif ori_shape[-3] > ori_shape[-4]:
                pad_shape = list(_zx.shape)
                pad_shape[-3] = ori_shape[-3] - ori_shape[-4]
                _zx = np.concatenate(
                    [_zx, np.zeros(pad_shape, dtype=np.float32)], axis=-3)

            midVel = np.concatenate([_yxz, _zx], axis=-2)  # (?=b,),h,w*3,3
            hsv += [vel2hsv(midVel, True, logv, scale)]
        # legend border width is derived from three times the grid width
        legend_ref_w = 3 * ori_shape[-2]
    else:
        hsv = [vel2hsv(vel, False, logv, scale)]

    bgr = []
    for _hsv in hsv:
        img_shape = _hsv.shape
        if len(img_shape) > 3:
            # cv2 works on single (h, w, 3) images: stack batch entries vertically,
            # using the real image shape (non-cubic grids are not 3w wide).
            _hsv = _hsv.reshape((-1,) + tuple(img_shape[-2:]))
        if is3D:
            velLegendHSV(_hsv, is3D, lw=max(1, min(6, int(0.025 * legend_ref_w))), constV=255)
        _hsv = cv2.cvtColor(_hsv, cv2.COLOR_HSV2BGR)
        if len(img_shape) > 3:
            _hsv = _hsv.reshape(img_shape)
        bgr += [_hsv]
    if len(bgr) == 1:
        bgr = bgr[0]
    else:
        bgr = bgr[0] * 0.2 + bgr[1] * 0.6 + bgr[2] * 0.2
    # flip Y: reverse the image-row axis, not a leading batch axis
    return np.flip(bgr.astype(np.uint8), axis=-3)


def jacobian3D_np(x):
    """Finite-difference Jacobian and curl of a 3D vector field (NumPy).

    Args:
        x: (b,)d,h,w,ch array whose channels 0, 1, 2 are (u, v, w); an
            unbatched 4D input gets a leading batch axis of size 1.

    Returns:
        (j, c) where
        j: (b,d,h,w,9) ordered [du/dx, du/dy, du/dz, dv/dx, dv/dy, dv/dz, dw/dx, dw/dy, dw/dz]
        c: (b,d,h,w,3) curl = (dw/dy - dv/dz, du/dz - dw/dx, dv/dx - du/dy)

    Derivatives are forward differences. The last difference along each axis
    is repeated so the outputs keep the input's spatial size.
    """
    if len(x.shape) < 5:
        x = np.expand_dims(x, axis=0)

    def _forward_diff(channel, axis):
        # axis: 1 = z (depth), 2 = y (height), 3 = x (width) of the (b,d,h,w) component
        comp = x[:, :, :, :, channel]
        step = np.diff(comp, axis=axis)
        return np.concatenate((step, np.take(step, [-1], axis=axis)), axis=axis)

    dudx, dvdx, dwdx = (_forward_diff(ch, 3) for ch in range(3))
    dudy, dvdy, dwdy = (_forward_diff(ch, 2) for ch in range(3))
    dudz, dvdz, dwdz = (_forward_diff(ch, 1) for ch in range(3))

    jac = np.stack([dudx, dudy, dudz, dvdx, dvdy, dvdz, dwdx, dwdy, dwdz], axis=-1)
    curl = np.stack([dwdy - dvdz, dudz - dwdx, dvdx - dudy], axis=-1)
    return jac, curl


class Voxel_Tool(object):
    """Sample a regular D x H x W voxel grid in world space and evaluate the
    density / velocity networks on it (summaries, training priors, export).

    ``self.pts`` holds the world-space grid positions with shape (D, H, W, 3),
    built by ``get_voxel_pts``. If ``middleView`` is given, three axis-aligned
    slabs through the grid centre are precomputed (``self.pts_mid``), so
    visualizations can query only those slabs. ``"mid"`` gives 1-voxel-thick
    slices; any other value (e.g. ``"mid3"``) gives 3-voxel-thick slabs.
    """

    def __get_tri_slice(self, _xm, _ym, _zm, _n=1):
        """Gather the points of an x-, a y- and a z-slab of thickness ``_n``.

        Returns:
            pts_mid: (D*H*_n + D*W*_n + H*W*_n, 3) positions, ordered as the
                x-slab (from _xm), y-slab (from _ym), then z-slab (from _zm).
            mask: (D, H, W, 1) float tensor counting how many slabs cover each
                voxel, clipped to [1e-6, 3] so it can be used as a divisor.
        """
        _yz = torch.reshape(self.pts[..., _xm:_xm + _n, :], (-1, 3))
        _zx = torch.reshape(self.pts[:, _ym:_ym + _n, ...], (-1, 3))
        _xy = torch.reshape(self.pts[_zm:_zm + _n, ...], (-1, 3))

        pts_mid = torch.cat([_yz, _zx, _xy], dim=0)
        npMaskXYZ = [np.zeros([self.D, self.H, self.W, 1], dtype=np.float32) for _ in range(3)]
        npMaskXYZ[0][..., _xm:_xm + _n, :] = 1.0
        npMaskXYZ[1][:, _ym:_ym + _n, ...] = 1.0
        npMaskXYZ[2][_zm:_zm + _n, ...] = 1.0
        return pts_mid, torch.tensor(np.clip(npMaskXYZ[0] + npMaskXYZ[1] + npMaskXYZ[2], 1e-6, 3.0))

    def __pad_slice_to_volume(self, _slice, _n, mode=0):
        """Place a flattened slab back into a zero-filled (D, H, W, C) volume.

        mode: 0 -> x slab (width _n), 1 -> y slab (height _n), 2 -> z slab (depth _n).
        The slab is centred: front padding is (size - _n) // 2, the same offset
        ``__init__`` uses when building the slabs.
        """
        tar_shape = [self.D, self.H, self.W]
        in_shape = tar_shape[:]
        in_shape[-1 - mode] = _n
        fron_shape = tar_shape[:]
        fron_shape[-1 - mode] = (tar_shape[-1 - mode] - _n) // 2
        back_shape = tar_shape[:]
        back_shape[-1 - mode] = (tar_shape[-1 - mode] - _n - fron_shape[-1 - mode])

        cur_slice = _slice.view(in_shape + [-1])
        # Allocate the padding on the slice's device so this also works when the
        # network output does not live on the default (CPU) device.
        front_0 = torch.zeros(fron_shape + [cur_slice.shape[-1]], device=cur_slice.device)
        back_0 = torch.zeros(back_shape + [cur_slice.shape[-1]], device=cur_slice.device)

        volume = torch.cat([front_0, cur_slice, back_0], dim=-2 - mode)
        return volume

    def _slice_thickness(self):
        """Slab thickness used for middle-slice views: 1 for "mid", otherwise 3."""
        return 1 if self.middleView == "mid" else 3

    def _slices_to_volume(self, flat_vals):
        """Scatter values queried at ``self.pts_mid`` back into a (D, H, W, C) volume.

        Voxels covered by several slabs get the average of their values; voxels
        outside every slab are 0.
        """
        D, H, W = self.D, self.H, self.W
        _n = self._slice_thickness()
        _yzV, _zxV, _xyV = torch.split(flat_vals, [D * H * _n, D * W * _n, H * W * _n], dim=0)
        mixV = (self.__pad_slice_to_volume(_yzV, _n, 0)
                + self.__pad_slice_to_volume(_zxV, _n, 1)
                + self.__pad_slice_to_volume(_xyV, _n, 2))
        # the coverage mask is built on the CPU; bring it to the values' device
        return mixV / self.npMaskXYZ.to(mixV.device)

    def __init__(self, smoke_tran, smoke_tran_inv, smoke_scale, D, H, W, middleView=None):
        self.s_s2w = torch.Tensor(smoke_tran).expand([4, 4])
        self.s_w2s = torch.Tensor(smoke_tran_inv).expand([4, 4])
        self.s_scale = torch.Tensor(smoke_scale).expand([3])
        self.D = D
        self.H = H
        self.W = W
        self.pts = get_voxel_pts(H, W, D, self.s_s2w, self.s_scale)
        self.pts_mid = None
        self.npMaskXYZ = None
        self.middleView = middleView
        if middleView is not None:
            _n = self._slice_thickness()
            _xm, _ym, _zm = (W - _n) // 2, (H - _n) // 2, (D - _n) // 2
            self.pts_mid, self.npMaskXYZ = self.__get_tri_slice(_xm, _ym, _zm, _n)

    def get_raw_at_pts(self, cur_pts, chunk=1024 * 32, use_viewdirs=False,
                       network_query_fn=None, network_fn=None):
        """Query ``network_query_fn`` on ``cur_pts`` in chunks.

        cur_pts: (..., 4) tensor of (x, y, z, t) inputs.
        Returns the raw outputs reshaped to cur_pts.shape[:-1] + (C,).
        """
        input_shape = list(cur_pts.shape[0:-1])

        pts_flat = cur_pts.reshape(-1, 4)
        pts_N = pts_flat.shape[0]
        # Evaluate model
        all_raw = []
        viewdir_zeros = torch.zeros([chunk, 3], dtype=torch.float, device=cur_pts.device) if use_viewdirs else None
        for i in range(0, pts_N, chunk):
            pts_i = pts_flat[i:i + chunk]
            viewdir_i = viewdir_zeros[:pts_i.shape[0]] if use_viewdirs else None

            raw_i = network_query_fn(pts_i, viewdir_i, network_fn)
            all_raw.append(raw_i)

        raw = torch.cat(all_raw, 0).view(input_shape + [-1])
        return raw

    def get_density_flat(self, cur_pts, chunk=1024 * 32, use_viewdirs=False,
                         network_query_fn=None, network_fn=None, getStatic=True):
        """Return ``[density]`` or ``[density, static_density]`` at ``cur_pts``.

        density is ReLU of the last raw channel. When ``getStatic`` is set and
        the network outputs more than 4 channels, ReLU of channel 3 is also
        returned as the static density.
        """
        flat_raw = self.get_raw_at_pts(cur_pts, chunk, use_viewdirs, network_query_fn, network_fn)
        den_raw = torch.nn.functional.relu(flat_raw[..., -1:])
        returnStatic = getStatic and (flat_raw.shape[-1] > 4)
        if returnStatic:
            static_raw = torch.nn.functional.relu(flat_raw[..., 3:4])
            return [den_raw, static_raw]
        return [den_raw]

    def get_velocity_flat(self, cur_pts, batchify_fn, chunk=1024 * 32,
                          vel_model=None):
        """Evaluate ``vel_model`` on ``cur_pts`` (N, 4) in chunks; returns (N, 3)-like output."""
        pts_N = cur_pts.shape[0]
        world_v = []
        for i in range(0, pts_N, chunk):
            input_i = cur_pts[i:i + chunk]
            vel_i = batchify_fn(vel_model, chunk)(input_i)
            world_v.append(vel_i)
        world_v = torch.cat(world_v, 0)
        return world_v

    def get_density_and_derivatives(self, cur_pts, chunk=1024 * 32, use_viewdirs=False,
                                    network_query_fn=None, network_fn=None):
        """Density at ``cur_pts`` (N, 4) and its partial derivatives w.r.t. x, y, z, t.

        ``cur_pts`` must be part of the autograd graph (e.g. requires_grad=True).
        """
        _den = self.get_density_flat(cur_pts, chunk, use_viewdirs, network_query_fn, network_fn, False)[0]
        # requires 1 backward pass
        # The minibatch Jacobian matrix of shape (N, D_y=1, D_x=4)
        jac = _get_minibatch_jacobian(_den, cur_pts)
        _d_x, _d_y, _d_z, _d_t = [torch.squeeze(_, -1) for _ in jac.split(1, dim=-1)]  # (N,1)
        return _den, _d_x, _d_y, _d_z, _d_t

    def get_velocity_and_derivatives(self, cur_pts, chunk=1024 * 32, batchify_fn=None, vel_model=None):
        """Velocity at ``cur_pts`` (N, 4) and its partial derivatives w.r.t. x, y, z, t."""
        _vel = self.get_velocity_flat(cur_pts, batchify_fn, chunk, vel_model)
        # requires 3 backward passes
        # The minibatch Jacobian matrix of shape (N, D_y=3, D_x=4)
        jac = _get_minibatch_jacobian(_vel, cur_pts)
        _u_x, _u_y, _u_z, _u_t = [torch.squeeze(_, -1) for _ in jac.split(1, dim=-1)]  # (N,3)
        return _vel, _u_x, _u_y, _u_z, _u_t

    def get_voxel_density_list(self, t=None, chunk=1024 * 32, use_viewdirs=False,
                               network_query_fn=None, network_fn=None, middle_slice=False):
        """Density volume(s) of shape (D, H, W, 1) at time ``t``.

        With ``middle_slice`` only the precomputed slabs are queried (fast
        visualization) and all other voxels are 0.
        """
        D, H, W = self.D, self.H, self.W
        pts_flat = self.pts_mid if middle_slice else self.pts.view(-1, 3)
        pts_N = pts_flat.shape[0]
        if t is not None:
            input_t = torch.full([pts_N, 1], float(t), dtype=pts_flat.dtype, device=pts_flat.device)
            pts_flat = torch.cat([pts_flat, input_t], dim=-1)
        # NOTE(review): get_raw_at_pts reshapes to (-1, 4), so t=None (3 columns)
        # looks unsupported -- confirm callers always pass t.

        den_list = self.get_density_flat(pts_flat, chunk, use_viewdirs, network_query_fn, network_fn)

        return_list = []
        for den_raw in den_list:
            if middle_slice:
                return_list.append(self._slices_to_volume(den_raw))
            else:
                return_list.append(den_raw.view(D, H, W, 1))
        return return_list

    def get_voxel_velocity(self, deltaT, t, batchify_fn, chunk=1024 * 32,
                           vel_model=None, middle_slice=False):
        """Simulation-space velocity volume (D, H, W, 3) at time ``t``.

        World velocities are converted with ``vel_world2smoke``, using per-axis
        factors (W, H, D) * deltaT. With ``middle_slice`` only the slabs are
        queried. NOTE(review): s_w2s / s_scale live on the CPU; confirm the
        velocity model output is on the same device.
        """
        D, H, W = self.D, self.H, self.W
        pts_flat = self.pts_mid if middle_slice else self.pts.view(-1, 3)
        pts_N = pts_flat.shape[0]
        if t is not None:
            input_t = torch.full([pts_N, 1], float(t), dtype=pts_flat.dtype, device=pts_flat.device)
            pts_flat = torch.cat([pts_flat, input_t], dim=-1)

        world_v = self.get_velocity_flat(pts_flat, batchify_fn, chunk, vel_model)
        reso_scale = [self.W * deltaT, self.H * deltaT, self.D * deltaT]
        target_v = vel_world2smoke(world_v, self.s_w2s, self.s_scale, reso_scale)

        if middle_slice:
            target_v = self._slices_to_volume(target_v)
        else:
            target_v = target_v.view(D, H, W, 3)

        return target_v

    def save_voxel_den_npz(self, den_path, t, use_viewdirs=False, network_query_fn=None, network_fn=None,
                           chunk=1024 * 32, save_npz=True, save_jpg=False, jpg_mix=True, noStatic=False):
        """Save the density volume (and the static density, if any) as .npz and/or .jpg.

        Files sit next to ``den_path``. The static one is prefixed with "static_".
        Only the slabs are queried when neither ``jpg_mix`` nor ``save_npz`` is set.
        """
        voxel_den_list = self.get_voxel_density_list(t, chunk, use_viewdirs, network_query_fn,
                                                     network_fn, middle_slice=not (jpg_mix or save_npz))
        head_tail = os.path.split(den_path)
        namepre = ["", "static_"]
        for voxel_den, npre in zip(voxel_den_list, namepre):
            voxel_den = voxel_den.detach().cpu().numpy()
            if save_jpg:
                jpg_path = os.path.join(head_tail[0], npre + os.path.splitext(head_tail[1])[0] + ".jpg")
                imageio.imwrite(jpg_path, den_scalar2rgb(voxel_den, scale=None, is3D=True, logv=False, mix=jpg_mix))
            if save_npz:
                # to save some space
                npz_path = os.path.join(head_tail[0], npre + os.path.splitext(head_tail[1])[0] + ".npz")
                voxel_den = np.float16(voxel_den)
                # NOTE: stored under the key "vel" although this is a density
                # field; presumably kept for existing readers -- confirm before changing.
                np.savez_compressed(npz_path, vel=voxel_den)
            if noStatic:
                break

    def save_voxel_vel_npz(self, vel_path, deltaT, t, batchify_fn, chunk=1024 * 32, vel_model=None, save_npz=True,
                           save_jpg=False, save_vort=False):
        """Save the velocity volume as float16 .npz and/or an HSV .jpg (optionally a vorticity .jpg).

        The vorticity image is only written when save_npz, save_jpg and save_vort are all set.
        """
        vel_scale = 160
        voxel_vel = self.get_voxel_velocity(deltaT, t, batchify_fn, chunk, vel_model,
                                            middle_slice=not save_npz).detach().cpu().numpy()

        if save_jpg:
            jpg_path = os.path.splitext(vel_path)[0] + ".jpg"
            imageio.imwrite(jpg_path, vel_uv2hsv(voxel_vel, scale=vel_scale, is3D=True, logv=False))
        if save_npz:
            if save_vort and save_jpg:
                _, NETw = jacobian3D_np(voxel_vel)
                head_tail = os.path.split(vel_path)
                imageio.imwrite(os.path.join(head_tail[0], "vort" + os.path.splitext(head_tail[1])[0] + ".jpg"),
                                vel_uv2hsv(NETw[0], scale=vel_scale * 5.0, is3D=True))
            # to save some space
            voxel_vel = np.float16(voxel_vel)
            np.savez_compressed(vel_path, vel=voxel_vel)


# Prepare Loss Tools (VGG, Den2Vel)
# NOTE(review): this cell relies on names from earlier cells (poses, time_steps,
# i_train/i_test/i_val, voxel_tran, voxel_tran_inv, voxel_scale, t_info, start).
###############################################
vggTool = VGGlossTool(device)

# Move to GPU, except images
poses = torch.Tensor(poses).to(device)
time_steps = torch.Tensor(time_steps).to(device)

N_iters = args.N_iter + 1

print('Begin')
print('TRAIN views are', i_train)
print('TEST views are', i_test)
print('VAL views are', i_val)

# Prepare Voxel Sampling Tools for Image Summary (voxel_writer), Physical Priors (training_voxel), Data Priors Represented by D2V (den_p_all)
# voxel_writer: to sample low resolution data for image summary
# (resY / resZ keep the aspect ratio of voxel_scale relative to the x axis)
resX = 64  # complexity O(N^3)
resY = int(resX * float(voxel_scale[1]) / voxel_scale[0] + 0.5)
resZ = int(resX * float(voxel_scale[2]) / voxel_scale[0] + 0.5)
voxel_writer = Voxel_Tool(voxel_tran, voxel_tran_inv, voxel_scale, resZ, resY, resX, middleView='mid3')

# training_voxel: to sample data for velocity NSE training
# training_voxel should have a larger resolution than voxel_writer
# note that training voxel is also used for visualization in testing
min_ratio = float(64 + 4 * 2) / min(voxel_scale[0], voxel_scale[1], voxel_scale[2])
minX = int(min_ratio * voxel_scale[0] + 0.5)
trainX = max(args.vol_output_W, minX)  # at least 64 + 2*4 = 72 cells along the shortest axis (up to rounding)
trainY = int(trainX * float(voxel_scale[1]) / voxel_scale[0] + 0.5)
trainZ = int(trainX * float(voxel_scale[2]) / voxel_scale[0] + 0.5)
training_voxel = Voxel_Tool(voxel_tran, voxel_tran_inv, voxel_scale, trainZ, trainY, trainX, middleView='mid3')
training_pts = torch.reshape(training_voxel.pts, (-1, 3))

# prepare grid positions for velocity d2v training, its shortest spatial dim has denTW+d2v_min_border*2 cells
# (get_voxel_pts takes the resolution in H, W, D order)
denTW = 64  # 256 will be 30 times slower
d2v_min_border = 2
denRatio = float(denTW + 2 * d2v_min_border) / min(trainX, trainY, trainZ)
den_p_all = get_voxel_pts(int(trainY * denRatio + 1e-6), int(trainX * denRatio + 1e-6), int(trainZ * denRatio + 1e-6),
                          voxel_tran, voxel_scale)
train_reso_scale = torch.Tensor([256 * t_info[-1], 256 * t_info[-1], 256 * t_info[-1]])
# train_reso_scale: d2v model is trained on simulation data with resolution of 256^3

split_nse_wei = [2.0, 1e-3, 1e-3, 1e-3, 1e-3, 5e-3]
# NOTE(review): not idempotent -- re-running this cell increments `start` again.
start = start + 1

basedir = args.basedir
expname = args.expname
date_str = datetime.datetime.now().strftime("%m%d-%H%M%S")
# e.g. 'log_train0312-153000' (no separator between the mode and the timestamp)
logdir = 'log_' + ('train' if not (args.vol_output_only or args.render_only) else 'test')
logdir += date_str
testimgdir = os.path.join(basedir, expname, "imgs_" + logdir)
os.makedirs(testimgdir, exist_ok=True)
# some loss terms, initialised to None before training
ghost_loss, overlay_loss, nseloss_fine, d2v_error = None, None, None, None



Begin
TRAIN views are [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
 234 235 236 237 238 239 240 

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [10]:
import requests
import collections


def curl2D(x, data_format='NHWC'):
    """Velocity field of a 2D stream function: (ds/dy, -ds/dx).

    x: (N, H, W, C) tensor; channel 0 is the stream function s.
    Uses forward differences, repeating the last row/column so the output is
    (N, H, W, 2).
    """
    assert data_format == 'NHWC'
    stream = x[:, :, :, 0]
    ds_dy = stream[:, 1:, :] - stream[:, :-1, :]
    neg_ds_dx = stream[:, :, :-1] - stream[:, :, 1:]
    ds_dy = torch.cat([ds_dy, ds_dy[:, -1:, :]], dim=1)
    neg_ds_dx = torch.cat([neg_ds_dx, neg_ds_dx[:, :, -1:]], dim=2)
    return torch.stack([ds_dy, neg_ds_dx], dim=-1)


def curl3D(x, data_format='NHWC'):
    """Curl of a 3D vector field given as a (b, z, y, x, c) tensor.

    Channels 0, 1, 2 are (u, v, w). Returns (b, z, y, x, 3) containing
    (dw/dy - dv/dz, du/dz - dw/dx, dv/dx - du/dy). It uses forward differences,
    repeating the last difference along each axis to keep the size.
    """
    assert data_format == 'NHWC'

    def _forward_diff(channel, dim):
        # dim: 1 = z, 2 = y, 3 = x of the (b, z, y, x) component
        comp = x[:, :, :, :, channel]
        upper = [slice(None)] * comp.dim()
        lower = [slice(None)] * comp.dim()
        upper[dim] = slice(1, None)
        lower[dim] = slice(None, -1)
        step = comp[tuple(upper)] - comp[tuple(lower)]
        last = [slice(None)] * step.dim()
        last[dim] = slice(-1, None)
        return torch.cat((step, step[tuple(last)]), dim=dim)

    dvdx, dwdx = _forward_diff(1, 3), _forward_diff(2, 3)
    dudy, dwdy = _forward_diff(0, 2), _forward_diff(2, 2)
    dudz, dvdz = _forward_diff(0, 1), _forward_diff(1, 1)

    return torch.stack([dwdy - dvdz, dudz - dwdx, dvdx - dudy], dim=-1)


def download_file(filename, url, timeout=None):
    """
    Download an URL to a file.

    The body is streamed into ``<filename>.part`` and renamed to ``filename``
    only after the whole response has been written. An HTTP error or an
    interrupted transfer therefore never leaves an empty or truncated file
    under the final name.

    Args:
        filename: destination path (str or os.PathLike).
        url: URL to fetch.
        timeout: optional ``requests`` timeout in seconds; None (the default)
            waits indefinitely, as before.

    Raises:
        requests.HTTPError: for 4xx/5xx responses (via ``raise_for_status``).
    """
    import os

    part_path = os.fspath(filename) + '.part'
    try:
        # the context manager releases the streamed connection when done
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(part_path, 'wb') as fout:
                # Write response data to file
                for block in response.iter_content(4096):
                    fout.write(block)
        os.replace(part_path, filename)
    except BaseException:
        # remove the temporary file, then propagate the original error
        if os.path.exists(part_path):
            os.remove(part_path)
        raise


class MyFCLayer(torch.nn.Module):
    """Bias-free fully connected layer ``input @ fc_weight`` with an optional ReLU.

    The weight is fixed (not a trainable parameter). It is converted with
    ``torch.Tensor`` (float32) and registered as a non-persistent buffer, so
    ``Module.to()`` / ``.cuda()`` move it along with the model while
    ``state_dict()`` stays unchanged.
    """

    def __init__(self, fc_weight, act):
        super(MyFCLayer, self).__init__()
        self.register_buffer('wei', torch.Tensor(fc_weight), persistent=False)
        self.act = act

    def forward(self, input):
        out = torch.matmul(input, self.wei)
        if self.act == 'relu':
            out = torch.nn.functional.relu(out)
        return out


class MyNormLayer(torch.nn.Module):
    """Instance normalisation with fixed per-channel affine parameters.

    Input is channels-first: (B, C, H, W) when ``is2D``, else (B, C, H, W, D).
    Mean and biased variance are taken over the spatial axes of each sample and
    channel. The result is ``scale * normalized + shift``. ``shift``/``scale``
    are registered as non-persistent buffers, so ``Module.to()`` / ``.cuda()``
    move them with the model while ``state_dict()`` stays unchanged.
    """

    def __init__(self, is2D, shift, scale):
        super(MyNormLayer, self).__init__()
        self.is2D = is2D
        tar_shape = [1, -1, 1, 1]
        if not is2D:
            tar_shape.append(1)
        self.register_buffer('shift', torch.Tensor(np.reshape(shift, tar_shape)), persistent=False)
        self.register_buffer('scale', torch.Tensor(np.reshape(scale, tar_shape)), persistent=False)

    def forward(self, input):
        epsilon = 1e-5
        axis_ = [2, 3]
        if not self.is2D:
            axis_ += [4]
        # input shape, Batch,channel,H,W(,D)

        sigma, mean = torch.var_mean(input, axis_, unbiased=False, keepdim=True)
        normalized = (input - mean) / torch.sqrt(sigma + epsilon)
        out = self.scale * normalized + self.shift

        return out


def _padding_size_2d(in_height, in_width, filter_size, stride):
    if in_height % stride == 0:
        pad_along_height = max(filter_size - stride, 0)
    else:
        pad_along_height = max(filter_size - (in_height % stride), 0)
    if in_width % stride == 0:
        pad_along_width = max(filter_size - stride, 0)
    else:
        pad_along_width = max(filter_size - (in_width % stride), 0)

    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top
    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left

    return pad_left, pad_right, pad_top, pad_bottom


def _padding_size_3d(in_depth, in_height, in_width, filter_size, stride):
    if in_depth % stride == 0:
        pad_along_depth = max(filter_size - stride, 0)
    else:
        pad_along_depth = max(filter_size - (in_depth % stride), 0)
    if in_height % stride == 0:
        pad_along_height = max(filter_size - stride, 0)
    else:
        pad_along_height = max(filter_size - (in_height % stride), 0)
    if in_width % stride == 0:
        pad_along_width = max(filter_size - stride, 0)
    else:
        pad_along_width = max(filter_size - (in_width % stride), 0)
    pad_front = pad_along_depth // 2
    pad_back = pad_along_depth - pad_front
    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top
    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left

    return pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back


def _my_pad_fn(input, padding, mode, is2D, name):
    ori_shape = list(input.shape)
    if mode == 'same' or is2D:
        out = torch.nn.functional.pad(input, padding, 'constant')
    else:
        input2D = torch.reshape(input, [ori_shape[0], ori_shape[1] * ori_shape[2], ori_shape[3], ori_shape[4]])
        # b, c*d, h,w
        pad_2D = torch.nn.functional.pad(input2D, (padding[0], padding[1], padding[2], padding[3]), 'reflect')
        pad_2D = pad_2D.view(ori_shape[0], ori_shape[1], ori_shape[2], pad_2D.shape[-2], pad_2D.shape[-1])

        pad_flip = torch.flip(pad_2D, (2,))
        pad_front = pad_flip[:, :, -padding[4] - 1:-1, ...]  # pad_2D[:,:, padding[4]:0:-1, ...]
        pad_back = pad_flip[:, :, 1:1 + padding[5], ...]  # pad_2D[:,:, -2:-2-padding[5]:-1, ...]
        out = torch.cat([pad_front, pad_2D, pad_back], dim=2)

    # assert out.shape[0] == ori_shape[0]
    # assert out.shape[1] == ori_shape[1]
    # assert out.shape[2] == ori_shape[2] + padding[4] + padding[5]
    # assert out.shape[3] == ori_shape[3] + padding[2] + padding[3]
    # assert out.shape[4] == ori_shape[4] + padding[0] + padding[1]
    return out


class FixLayer(object):
    """Builds frozen torch layers from a flat, ordered list of pretrained arrays.

    Base helper for the d2v generator blocks. ``get_variable`` hands out the
    entries of ``var_list`` strictly in order, so layers must be requested in
    exactly the order the arrays were stored. The ``name``/``shape`` arguments
    are only used for the optional debug print.

    Attributes:
        var_list: pretrained arrays (weights must support ``.transpose``,
            biases are passed through ``np.reshape``).
        var_name: names parallel to ``var_list`` (debug printing only).
        var_idx: index of the next variable to hand out.
        _print: set to True to print every variable as it is consumed.
    """

    def __init__(self, _var_list, _var_name):
        self.var_list = _var_list
        self.var_name = _var_name
        self.var_idx = 0
        # set to True to check if d2v variables are correctly loaded
        self._print = False

    def get_variable(self, name, shape):
        """Return the next pretrained array and advance the cursor.

        Args:
            name: descriptive name of the requested variable (debug only).
            shape: expected torch-layout shape (debug only, not checked).
        Raises:
            IndexError: if every variable has already been consumed.
        """
        if self.var_idx >= len(self.var_list):
            raise IndexError("[loading d2v variables] %s requested but all %d variables are already consumed"
                             % (name, len(self.var_list)))
        test_var = self.var_list[self.var_idx]
        if self._print:
            print("[loading d2v variables]", name, shape, "get var %d" % self.var_idx, self.var_name[self.var_idx],
                  test_var.shape)
        self.var_idx += 1
        return test_var

    def _norm(self, name, is2D, channel):
        """Build a MyNormLayer from the next two variables (shift, then scale)."""
        # input shape, Batch,channel,H,W,D
        shift = self.get_variable(name + '.norm_shift', shape=[channel])
        scale = self.get_variable(name + '.norm_scale', shape=[channel])
        return MyNormLayer(is2D, shift, scale)

    def _load_kernel(self, layer, prefix, filter_shape, bias_shape, is2D, hasbias):
        """Copy the next kernel (and optional bias) into ``layer`` and freeze it.

        Stored kernels keep their two channel axes last (TensorFlow-style, as the
        checkpoint variable names suggest). They are transposed so those two axes
        come first (last one first), followed by the spatial axes in order:
            conv:   (k.., ch_in, ch_out) -> (ch_out, ch_in, k..)
            deconv: (k.., ch_out, ch_in) -> (ch_in, ch_out, k..)
        Without ``hasbias`` the bias is set to zeros.
        """
        transpose_order = [-1, -2] + list(range(2 if is2D else 3))
        with torch.no_grad():
            np_wei = self.get_variable(prefix + '_w', filter_shape)
            np_wei = np_wei.transpose(transpose_order)
            layer.weight = torch.nn.Parameter(torch.Tensor(np_wei))
            if hasbias:
                np_bias = self.get_variable(prefix + '_b', bias_shape)
                np_bias = np.reshape(np_bias, bias_shape)
            else:
                np_bias = np.zeros(bias_shape)
            layer.bias = torch.nn.Parameter(torch.Tensor(np_bias))
        layer.requires_grad_(False)
        return layer

    def _conv(self, ch_in, num_filters, name, k_size, stride, is2D=True, hasbias=True):
        """Frozen Conv2d/Conv3d (no built-in padding) loaded from the next variable(s)."""
        conv_fn = torch.nn.Conv2d if is2D else torch.nn.Conv3d
        # torch Conv weight layout: [ch_out, ch_in, k, k(, k)]
        filter_shape = [num_filters, ch_in] + [k_size] * (2 if is2D else 3)
        cur_layer = conv_fn(ch_in, num_filters, k_size, stride)
        return self._load_kernel(cur_layer, name + '.conv', filter_shape, [num_filters], is2D, hasbias)

    def _deconv(self, ch_in, num_filters, name, k_size, stride, is2D=True, hasbias=True):
        """Frozen ConvTranspose2d/3d (no built-in padding) loaded from the next variable(s)."""
        deconv_fn = torch.nn.ConvTranspose2d if is2D else torch.nn.ConvTranspose3d
        # torch ConvTranspose weight layout: [ch_in, ch_out, k, k(, k)]
        filter_shape = [ch_in, num_filters] + [k_size] * (2 if is2D else 3)
        cur_layer = deconv_fn(ch_in, num_filters, k_size, stride)
        return self._load_kernel(cur_layer, name + '.deconv', filter_shape, [num_filters], is2D, hasbias)

    def _fully_connected(self, ch_in, out_n, act, name):
        """Bias-free fully connected layer (MyFCLayer) from the next variable."""
        fc_shape = [ch_in, out_n]
        np_wei = self.get_variable(name + '.fc_w', fc_shape)
        return MyFCLayer(np_wei, act)

    def _residual(self, num_filters, name, is2D=True, norm='instance'):
        """Layers of one residual block.

        Returns (conv1, conv2, norm1, norm2). Both convs are 3x3(x3), stride 1,
        bias-free; the norms are None unless ``norm == 'instance'``.
        Consumption order: conv1, norm1, conv2, norm2.
        """
        layer_norm1, layer_norm2 = None, None
        layer_res1 = self._conv(num_filters, num_filters, name + '.res1', 3, 1, is2D, hasbias=False)
        if norm == 'instance':
            layer_norm1 = self._norm(name + '.norm1', is2D, num_filters)
        layer_res2 = self._conv(num_filters, num_filters, name + '.res2', 3, 1, is2D, hasbias=False)
        if norm == 'instance':
            layer_norm2 = self._norm(name + '.norm2', is2D, num_filters)
        return layer_res1, layer_res2, layer_norm1, layer_norm2


class MyConvBlock(torch.nn.Module):
    """Frozen conv layer with optional TF-style padding, instance norm and ReLU.

    Forward order: pad ('reflect'/'same' only; other values mean no padding)
    -> conv -> norm -> activation.
    """

    def __init__(self, fix_layer_tool, ch_in, num_filters, name, k_size, stride,
                 pad='reflect', is2D=True, norm='instance', activation='relu'):
        assert norm in ['instance', None]
        assert activation in ['relu', None]
        super(MyConvBlock, self).__init__()
        # The conv is requested before the norm; FixLayer hands variables out in order.
        self.conv_layer = fix_layer_tool._conv(ch_in, num_filters, name + '.CB',
                                               k_size, stride, is2D)
        self.norm_layer = (fix_layer_tool._norm(name + '.CB', is2D, num_filters)
                           if norm == 'instance' else None)
        self.act = activation
        self.is2D = is2D
        self.k_size = k_size
        self.stride = stride
        self.pad = pad
        self.name = name

    def forward(self, input):
        x = input
        if self.pad in ('reflect', 'same'):
            if self.is2D:
                pads = _padding_size_2d(*x.shape[-2:], self.k_size, self.stride)
            else:
                pads = _padding_size_3d(*x.shape[-3:], self.k_size, self.stride)
            x = _my_pad_fn(x, pads, self.pad, self.is2D, self.name)
        x = self.conv_layer(x)
        if self.norm_layer is not None:
            x = self.norm_layer(x)
        if self.act == 'relu':
            x = torch.nn.functional.relu(x)
        return x


class MyDeconvBlock(torch.nn.Module):
    """Frozen transposed conv with TF-'SAME'-style cropping, optional norm and ReLU.

    Forward order: deconv -> crop the 'SAME' padding computed for an output of
    ``input_size * stride`` -> norm -> activation.
    """

    def __init__(self, fix_layer_tool, ch_in, num_filters, name, k_size, stride,
                 is2D=True, norm='instance', activation='relu'):
        assert norm in ['instance', None]
        assert activation in ['relu', None]
        super(MyDeconvBlock, self).__init__()
        # The deconv is requested before the norm; FixLayer hands variables out in order.
        self.deconv_layer = fix_layer_tool._deconv(ch_in, num_filters, name + '.DCB',
                                                   k_size, stride, is2D)
        self.norm_layer = (fix_layer_tool._norm(name + '.DCB', is2D, num_filters)
                           if norm == 'instance' else None)
        self.act = activation
        self.is2D = is2D
        self.k_size = k_size
        self.stride = stride

    def forward(self, input):
        out = self.deconv_layer(input)
        s = self.stride
        if self.is2D:
            left, right, top, bottom = _padding_size_2d(
                input.shape[-2] * s, input.shape[-1] * s, self.k_size, s)
            out = out[:, :, top:out.shape[-2] - bottom, left:out.shape[-1] - right]
        else:
            left, right, top, bottom, front, back = _padding_size_3d(
                input.shape[-3] * s, input.shape[-2] * s, input.shape[-1] * s, self.k_size, s)
            out = out[:, :,
                      front:out.shape[-3] - back,
                      top:out.shape[-2] - bottom,
                      left:out.shape[-1] - right]
        if self.norm_layer is not None:
            out = self.norm_layer(out)
        if self.act == 'relu':
            out = torch.nn.functional.relu(out)
        return out


class MyResBlock(torch.nn.Module):
    """Frozen residual block built from FixLayer._residual.

    out = relu(x + N2(C2(pad(relu(N1(C1(pad(x))))))))
    Padding is applied only for pad in ('reflect', 'same'); the norms are
    skipped when None.
    """

    def __init__(self, fix_layer_tool, num_filters, name,
                 pad='reflect', is2D=True, norm='instance'):
        assert norm in ['instance', None]
        super(MyResBlock, self).__init__()
        # (conv1, conv2, norm1, norm2), stored as a tuple attribute.
        self.res_layers = fix_layer_tool._residual(num_filters, name + '.RES', is2D, norm=norm)
        self.is2D = is2D
        self.pad = pad
        self.name = name

    def _maybe_pad(self, x):
        # 3x3(x3), stride-1 'SAME' padding, or no padding for other pad modes.
        if self.pad not in ['reflect', 'same']:
            return x
        if self.is2D:
            amounts = _padding_size_2d(x.shape[-2], x.shape[-1], 3, 1)
        else:
            amounts = _padding_size_3d(x.shape[-3], x.shape[-2], x.shape[-1], 3, 1)
        return _my_pad_fn(x, amounts, self.pad, self.is2D, self.name)

    def forward(self, input):
        conv1, conv2, norm1, norm2 = self.res_layers
        hidden = conv1(self._maybe_pad(input))
        if norm1 is not None:
            hidden = norm1(hidden)
        hidden = torch.nn.functional.relu(hidden)
        hidden = conv2(self._maybe_pad(hidden))
        if norm2 is not None:
            hidden = norm2(hidden)
        return torch.nn.functional.relu(input + hidden)  # residuals


class DenToVel(torch.nn.Module):
    """Frozen pretrained density-to-velocity (d2v) network.

    All weights are read once from the ``ckpt`` .npz file and wrapped into
    fixed (requires_grad=False) layers through FixLayer, so nothing here is
    trained. ``forward`` predicts a vector field from a density volume and
    returns its curl.

    NOTE(review): sub-layers are kept in plain Python lists (``Cin_list``,
    ``Res_list``, ...), so they are not registered as submodules;
    ``.to(device)``, ``.parameters()`` and ``state_dict()`` do not see them.
    """

    def init_configuration(self, obsFlag=False, withZoom=True):
        """Return the fixed hyper-parameters of the pretrained d2v model.

        Args:
            obsFlag: value used for both ``obsFlags`` and ``obsMoving``.
            withZoom: ``zoom_factor`` is 4.0 if True, otherwise -1.0.
        Returns:
            A namedtuple ``Params`` with one field per setting.
        """
        # set all hyper-parameters for den2vel model
        updateDict = {
            'OpenBounds': True,
            'adv_order': 2,
            'buoy': 2.0,
            'is2D': False,
            'useEnergy': True,
            'usePhy': True,
            'encPhy': True,
            'useVortEnd': True,
            'obsFlags': obsFlag,
            'obsMoving': obsFlag,
            'batch_size': 1,  # 1 step
            'crop_size': 64,  # 3D
            'mid_ch': 16,
            'phy_len': 2,
            'num_resblock': 6,
            'zoom_factor': 4.0 if withZoom else -1.0,
            'blend_st': -1,
            'Dst_Flag': 0,
            'selfPhy': False,
            'withRef': False,
            'mode': 'inference'
        }
        namelist = list(updateDict)
        valuelist = list(map(updateDict.get, namelist))

        Params = collections.namedtuple('Params', ",".join(namelist))
        tmpFLAGS = Params._make(valuelist)
        return tmpFLAGS

    def encode_param_layer(self, FLAGS, ch_in):
        """Build the layers that encode physical parameters from mid-level features.

        Sets ``self.PHY_list`` = [conv 7, conv 3, fully connected -> phy_len] and
        ``self.KE_list`` (two stride-2 convs plus a 1-channel conv when
        ``FLAGS.useEnergy``, else empty). Creation order must match the
        checkpoint variable order because FixLayer consumes variables sequentially.
        """
        # mid_in FLAGS.batch_size x m, 8x8x16
        name = 'encPhy.'
        mid_ch = FLAGS.mid_ch
        dim = 2 if FLAGS.is2D else 3
        mid_shape = [mid_ch] + [8] * dim

        M3 = MyConvBlock(self.fix_layer_tool, ch_in, mid_ch, name + 'd16c7', 7, 1,
                         pad='reflect', is2D=FLAGS.is2D, norm=None, activation=self._activation)

        if FLAGS.useEnergy:
            E2 = MyConvBlock(self.fix_layer_tool, mid_ch, mid_ch // 2, name + 'd8', 3, 2,
                             pad='same', is2D=FLAGS.is2D, norm=self._norm, activation=self._activation)

            E1 = MyConvBlock(self.fix_layer_tool, mid_ch // 2, mid_ch // 4, name + 'd4', 3, 2,
                             pad='same', is2D=FLAGS.is2D, norm=self._norm, activation=self._activation)

            tar_depth = 1
            E0 = MyConvBlock(self.fix_layer_tool, mid_ch // 4, tar_depth, name + 'd3s1-2', 3, 1,
                             pad='same', is2D=FLAGS.is2D, norm=None, activation=None)

            self.KE_list = [E2, E1, E0]
        else:
            self.KE_list = []

        M1 = MyConvBlock(self.fix_layer_tool, mid_ch, mid_ch, name + 'd16c3', 3, 1,
                         pad='same', is2D=FLAGS.is2D, norm=None, activation=None)
        M0 = self.fix_layer_tool._fully_connected(np.prod(mid_shape), FLAGS.phy_len, act=None, name=name + 'm1fc2')
        self.PHY_list = [M3, M1, M0]

    def build_param_layer(self, FLAGS):
        """Build the layers that decode physical parameters back into feature maps.

        Sets ``self.PHYin_list`` = [fully connected (ReLU), conv 7] and
        ``self.KEin_list`` (two 3x3 convs when ``FLAGS.useEnergy``, else empty).
        """
        param_n = FLAGS.phy_len
        dim = 2 if FLAGS.is2D else 3
        mid_ch = FLAGS.mid_ch
        mid_shape = [mid_ch] + [8] * dim
        name = 'buildPhy.'

        M1 = self.fix_layer_tool._fully_connected(param_n * 2, np.prod(mid_shape), 'relu',
                                                  name + 'm1fc16')
        M2 = MyConvBlock(self.fix_layer_tool, mid_ch, mid_ch, name + 'd16', 7, 1,
                         pad='reflect', is2D=FLAGS.is2D, norm=None, activation=None)
        self.PHYin_list = [M1, M2]

        if FLAGS.useEnergy:
            energy_c1 = MyConvBlock(self.fix_layer_tool, 2, 2, name + 'energy-d2', 3, 1,
                                    pad='reflect', is2D=FLAGS.is2D, norm=None, activation=None)
            energy_conv = MyConvBlock(self.fix_layer_tool, 2, 1, name + 'energy-d1', 3, 1,
                                      pad='reflect', is2D=FLAGS.is2D, norm=None, activation=None)
            self.KEin_list = [energy_c1, energy_conv]
        else:
            # Previously the energy layers were referenced unconditionally, which
            # raised NameError whenever useEnergy was False.
            self.KEin_list = []

    def __init__(self, ckpt='./data/d2v_3Dmodel.npz', obsFlag=False, withZoom=False):
        """Load the checkpoint (downloading it if missing) and build all frozen layers.

        Args:
            ckpt: path of the .npz checkpoint; fetched from the PINF project
                page when it does not exist.
            obsFlag: forwarded to ``init_configuration``.
            withZoom: keep the zoom_in/zoom_out variables (4 at each end of the list).
        """
        super(DenToVel, self).__init__()
        FLAGS = self.init_configuration(obsFlag, withZoom)
        # 110 variables, load all weights as constant/fixed variable
        namelist = [
            "generator/zoom_in/LS-8/w", "generator/zoom_in/LS-8/b", "generator/zoom_in/LS-out/w",
            "generator/zoom_in/LS-out/b", "generator/c7s1-32/w", "generator/c7s1-32/b",
            "generator/c7s1-32/instance_norm/shift", "generator/c7s1-32/instance_norm/scale", "generator/d64/w",
            "generator/d64/b", "generator/d64/instance_norm/shift", "generator/d64/instance_norm/scale",
            "generator/d128/w", "generator/d128/b", "generator/d128/instance_norm/shift",
            "generator/d128/instance_norm/scale", "generator/R128_0/res1/w",
            "generator/R128_0/res1/instance_norm/shift", "generator/R128_0/res1/instance_norm/scale",
            "generator/R128_0/res2/w", "generator/R128_0/res2/instance_norm/shift",
            "generator/R128_0/res2/instance_norm/scale", "generator/R128_1/res1/w",
            "generator/R128_1/res1/instance_norm/shift", "generator/R128_1/res1/instance_norm/scale",
            "generator/R128_1/res2/w", "generator/R128_1/res2/instance_norm/shift",
            "generator/R128_1/res2/instance_norm/scale", "generator/R128_2/res1/w",
            "generator/R128_2/res1/instance_norm/shift", "generator/R128_2/res1/instance_norm/scale",
            "generator/R128_2/res2/w", "generator/R128_2/res2/instance_norm/shift",
            "generator/R128_2/res2/instance_norm/scale", "generator/R128_3/res1/w",
            "generator/R128_3/res1/instance_norm/shift", "generator/R128_3/res1/instance_norm/scale",
            "generator/R128_3/res2/w", "generator/R128_3/res2/instance_norm/shift",
            "generator/R128_3/res2/instance_norm/scale", "generator/encPhy/d16c7/w", "generator/encPhy/d16c7/b",
            "generator/encPhy/d8/w", "generator/encPhy/d8/b", "generator/encPhy/d8/instance_norm/shift",
            "generator/encPhy/d8/instance_norm/scale", "generator/encPhy/d4/w", "generator/encPhy/d4/b",
            "generator/encPhy/d4/instance_norm/shift", "generator/encPhy/d4/instance_norm/scale",
            "generator/encPhy/d3s1-2/w", "generator/encPhy/d3s1-2/b", "generator/encPhy/d16c3/w",
            "generator/encPhy/d16c3/b", "generator/encPhy/m1fc2/fc", "generator/buildPhy/m1fc16/fc",
            "generator/buildPhy/d16/w", "generator/buildPhy/d16/b", "generator/buildPhy/energy-d2/w",
            "generator/buildPhy/energy-d2/b", "generator/buildPhy/energy-d1/w", "generator/buildPhy/energy-d1/b",
            "generator/R128_4/res1/w", "generator/R128_4/res1/instance_norm/shift",
            "generator/R128_4/res1/instance_norm/scale", "generator/R128_4/res2/w",
            "generator/R128_4/res2/instance_norm/shift", "generator/R128_4/res2/instance_norm/scale",
            "generator/R128_5/res1/w", "generator/R128_5/res1/instance_norm/shift",
            "generator/R128_5/res1/instance_norm/scale", "generator/R128_5/res2/w",
            "generator/R128_5/res2/instance_norm/shift", "generator/R128_5/res2/instance_norm/scale",
            "generator/vort_u32/w", "generator/vort_u32/b", "generator/vort_u32/instance_norm/shift",
            "generator/vort_u32/instance_norm/scale", "generator/vort_u16/w", "generator/vort_u16/b",
            "generator/vort_u16/instance_norm/shift", "generator/vort_u16/instance_norm/scale", "generator/vort_v1/w",
            "generator/vort_v1/b", "generator/vein_16/w", "generator/vein_16/b",
            "generator/vein_16/instance_norm/shift", "generator/vein_16/instance_norm/scale", "generator/vein_24/w",
            "generator/vein_24/b", "generator/vein_24/instance_norm/shift", "generator/vein_24/instance_norm/scale",
            "generator/vein_32/w", "generator/vein_32/b", "generator/vein_32/instance_norm/shift",
            "generator/vein_32/instance_norm/scale", "generator/u64/w", "generator/u64/b",
            "generator/u64/instance_norm/shift", "generator/u64/instance_norm/scale", "generator/u32/w",
            "generator/u32/b", "generator/u32/instance_norm/shift", "generator/u32/instance_norm/scale",
            "generator/c7s1-3/w", "generator/c7s1-3/b", "generator/zoom_out/HS-8/w", "generator/zoom_out/HS-8/b",
            "generator/zoom_out/HS-out/w", "generator/zoom_out/HS-out/b"
        ]
        if not withZoom: namelist = namelist[4:-4]

        if not os.path.exists(ckpt):
            print(ckpt, "does not exist. Try to download d2v model for training...")
            download_file(ckpt, "https://rachelcmy.github.io/pinf_smoke/data/d2v_3Dmodel.npz")

        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects
        # (code execution risk); only use checkpoints from a trusted source.
        # The context manager closes the underlying NpzFile handle.
        with np.load(ckpt, allow_pickle=True) as ckpt_data:
            self.var_dict = ckpt_data["arr_0"].item()
        var_list = [self.var_dict[k] for k in namelist]
        self.fix_layer_tool = FixLayer(var_list, namelist)

        self.dim = 2 if FLAGS.is2D else 3
        tarsize = FLAGS.crop_size * int(max(FLAGS.zoom_factor, 1))
        self.input_shape = [1] + [1] + [256 if withZoom else 64] * 3
        self.output_shape = [tarsize] * (2 if FLAGS.is2D else 3)
        self._norm = 'instance'
        self._activation = 'relu'

        # start to build the model. Layer creation order must follow `namelist`.
        C1 = MyConvBlock(self.fix_layer_tool, 1, 32, 'c7s1-32', 7, 1,
                         pad='reflect', is2D=FLAGS.is2D, norm=self._norm, activation=self._activation)
        C2 = MyConvBlock(self.fix_layer_tool, 32, 64, 'd64', 3, 2,
                         pad='same', is2D=FLAGS.is2D, norm=self._norm, activation=self._activation)
        C3 = MyConvBlock(self.fix_layer_tool, 64, 128, 'd128', 3, 2,
                         pad='same', is2D=FLAGS.is2D, norm=self._norm, activation=self._activation)

        self.Cin_list = [C1, C2, C3]
        res_ch = 128
        self.Res_list = [
            MyResBlock(self.fix_layer_tool, res_ch, 'R128_{}'.format(i),
                       pad='reflect', is2D=FLAGS.is2D, norm=self._norm)
            for i in range(FLAGS.num_resblock // 2 + 1)
        ]
        self.encode_param_layer(FLAGS, res_ch)
        self.build_param_layer(FLAGS)
        # Later res blocks also see the decoded physics features (+ energy channel).
        if FLAGS.useEnergy:
            res_ch = res_ch + 1
        res_ch = res_ch + FLAGS.mid_ch
        self.Res_list = self.Res_list + [
            MyResBlock(self.fix_layer_tool, res_ch, 'R128_{}'.format(i),
                       pad='reflect', is2D=FLAGS.is2D, norm=self._norm)
            for i in range(FLAGS.num_resblock // 2 + 1, FLAGS.num_resblock)
        ]
        if FLAGS.useVortEnd:
            vort_lvl1 = MyDeconvBlock(self.fix_layer_tool, res_ch, 32, 'vort_u32', 3, 2,
                                      is2D=FLAGS.is2D, norm=self._norm, activation=self._activation)
            vort_lvl2 = MyDeconvBlock(self.fix_layer_tool, 32, 16, 'vort_u16', 3, 2,
                                      is2D=FLAGS.is2D, norm=self._norm, activation=self._activation)
            enc_vortEnd = MyConvBlock(self.fix_layer_tool, 16, 1 if FLAGS.is2D else 3, 'vort_v1', 7, 1,
                                      pad='reflect', is2D=FLAGS.is2D, norm=None, activation=None)
            self.VORT_list = [vort_lvl1, vort_lvl2, enc_vortEnd]
            vein_lvl2 = MyConvBlock(self.fix_layer_tool, 2 if FLAGS.is2D else 6, 16, 'vein_16', 7, 1,
                                    pad='reflect', is2D=FLAGS.is2D, norm=self._norm, activation=self._activation)
            vein_lvl1 = MyConvBlock(self.fix_layer_tool, 16, 24, 'vein_24', 3, 2,
                                    pad='same', is2D=FLAGS.is2D, norm=self._norm, activation=self._activation)
            vein_lvl = MyConvBlock(self.fix_layer_tool, 24, 32, 'vein_32', 3, 2,
                                   pad='same', is2D=FLAGS.is2D, norm=self._norm, activation=self._activation)
            self.VORTin_list = [vein_lvl2, vein_lvl1, vein_lvl]
            res_ch = res_ch + 32
        Out1 = MyDeconvBlock(self.fix_layer_tool, res_ch, 64, 'u64', 3, 2,
                             is2D=FLAGS.is2D, norm=self._norm, activation=self._activation)
        Out2 = MyDeconvBlock(self.fix_layer_tool, 64, 32, 'u32', 3, 2,
                             is2D=FLAGS.is2D, norm=self._norm, activation=self._activation)
        out_ch = 1 if FLAGS.is2D else 3
        Out3 = MyConvBlock(self.fix_layer_tool, 32, out_ch, 'c7s1-3', 7, 1,
                           pad='reflect', is2D=FLAGS.is2D, norm=None, activation=None)
        self.Out_list = [Out1, Out2, Out3]

        self.FLAGS = FLAGS

    def forward(self, den_input):
        """Predict a velocity field from a density volume.

        Args:
            den_input: density tensor, presumably shaped like ``self.input_shape``
                (1, 1, D, H, W) (TODO confirm against callers). The batch must
                equal ``FLAGS.batch_size`` because of the reshape below.
        Returns:
            curl3D (or curl2D) of the network output permuted to channels-last.
        NOTE(review): the permutes below are hard-coded for 5-D (3-D) tensors.
        """
        dim = 2 if self.FLAGS.is2D else 3
        pool_k = [self.FLAGS.crop_size // 4 // 8] * dim
        pool_fn = torch.nn.AvgPool2d if self.FLAGS.is2D else torch.nn.AvgPool3d
        param_n = self.FLAGS.phy_len
        mid_ch = self.FLAGS.mid_ch
        mid_shape = [mid_ch] + [8] * dim
        permute_mid_shape = [8] * dim + [mid_ch]
        mid_repeat = self.FLAGS.crop_size // 4 // 8

        # Created on the input's device so it never mixes devices in torch.cat below.
        phy_input = torch.tensor([[self.FLAGS.buoy, float(self.FLAGS.OpenBounds)]],
                                 dtype=torch.get_default_dtype(), device=den_input.device)

        g = den_input
        for Cin in self.Cin_list:
            g = Cin(g)

        for i in range(self.FLAGS.num_resblock):
            g = self.Res_list[i](g)
            if (i == self.FLAGS.num_resblock // 2):
                # Encode physics parameters (and a kinetic-energy map) from mid features.
                phy_g = self.PHY_list[0](g)
                if self.FLAGS.useEnergy:
                    KE_g = phy_g
                    for KE_fn in self.KE_list:
                        KE_g = KE_fn(KE_g)
                # mid_in FLAGS.batch_size x m, 8x8x16
                phy_g = pool_fn(pool_k, stride=pool_k, padding=0)(phy_g)
                phy_g = self.PHY_list[1](phy_g)
                phy_g = phy_g.permute([0, 2, 3, 4, 1])
                phy_g = torch.reshape(phy_g, (self.FLAGS.batch_size, -1))
                phy_g = self.PHY_list[2](phy_g)

                # Concatenate the predicted parameters with the known ones.
                phy_in = phy_input.expand(phy_g.shape)
                phy_in = torch.cat([phy_g, phy_in], dim=1)

                if self.FLAGS.useEnergy:
                    KE_in = torch.ones_like(KE_g) * -1.0
                    KE_in = torch.cat([KE_g, KE_in], dim=1)

                # Decode parameters back into a feature volume and upsample it.
                phy_g = self.PHYin_list[0](phy_in)
                phy_g = phy_g.view([-1] + permute_mid_shape)  # bx16x8x8(x8)
                phy_g = phy_g.permute([0, 4, 1, 2, 3])
                phy_g = self.PHYin_list[1](phy_g)
                phy_g = torch.repeat_interleave(phy_g, mid_repeat, dim=2)
                phy_g = torch.repeat_interleave(phy_g, mid_repeat, dim=3)
                if dim == 3:
                    phy_g = torch.repeat_interleave(phy_g, mid_repeat, dim=4)
                if self.FLAGS.useEnergy:
                    KE_g = self.KEin_list[0](KE_in)
                    KE_g = self.KEin_list[1](KE_g)
                    KE_g = torch.repeat_interleave(KE_g, 4, dim=2)
                    KE_g = torch.repeat_interleave(KE_g, 4, dim=3)
                    if dim == 3:
                        KE_g = torch.repeat_interleave(KE_g, 4, dim=4)
                    phy_g = torch.cat([phy_g, KE_g], dim=1)
                g = torch.cat([g, phy_g], dim=1)
        # vort
        if self.FLAGS.useVortEnd:
            v_g = g
            for v_fn in self.VORT_list:
                v_g = v_fn(v_g)
            ve_in = torch.ones_like(v_g) * -10.0
            v_g = torch.cat([v_g, ve_in], dim=1)

            for v_fn in self.VORTin_list:
                v_g = v_fn(v_g)
            g = torch.cat([g, v_g], dim=1)
        for out_fn in self.Out_list:
            g = out_fn(g)
        # from NCHW to NHWC
        permute_order = [0, 2, 3, 1]
        if not self.FLAGS.is2D:
            permute_order = [0, 2, 3, 4, 1]
        g = g.permute(permute_order)
        # curl in NHWC mode, return NHW3 as velocity
        g = (curl2D(g) if dim == 2 else curl3D(g))
        return g


# Build (and, if the checkpoint is missing, download) the frozen d2v model only
# when its loss weight is effectively non-zero.
if args.d2vW > 1e-8:
    d2v_model = DenToVel()
else:
    d2v_model = None

In [11]:
from tqdm import trange

def img2mse(x, y):
    """Mean squared error between two tensors (0-dim tensor)."""
    return torch.mean((x - y) ** 2)


def mse2psnr(x):
    """Convert an MSE tensor to PSNR in dB, i.e. ``-10 * log10(x)``.

    The log(10) denominator is a 1-element tensor created on ``x``'s device.
    The old CPU-only ``torch.Tensor([10.])`` failed for CUDA inputs. The
    result broadcasts against that 1-element tensor (a scalar input gives
    shape [1]).
    """
    return -10. * torch.log(x) / torch.log(torch.tensor([10.], device=x.device))


def fade_in_weight(step, start, duration):
    """Linear ramp weight, clamped to [0, 1].

    Returns 0 before ``start``, rises linearly to 1 over ``duration`` steps
    and stays at 1 afterwards. A zero ``duration`` acts as an instant switch
    at ``start`` instead of raising ZeroDivisionError.
    """
    if duration == 0:
        return 1.0 if float(step) >= start else 0.0
    return min(max((float(step) - start) / duration, 0.0), 1.0)


def jacobian3D(x):
    """Forward-difference Jacobian and curl of a 3-component vector field.

    Args:
        x: 5-D tensor indexed as (b, d, h, w, ch), with ch[0..2] = (u, v, w).
           x/y/z derivatives are taken along dims 3/2/1 respectively.
    Returns:
        (j, c): ``j`` stacks 9 entries on the last axis in the order
        [du/dx, du/dy, du/dz, dv/dx, dv/dy, dv/dz, dw/dx, dw/dy, dw/dz];
        ``c`` is the curl (dw/dy - dv/dz, du/dz - dw/dx, dv/dx - du/dy).
        Spatial size is preserved: along each axis the last forward
        difference is repeated to fill the final slice.
    """
    def _diff(component, axis):
        field = x[..., component]
        n = field.shape[axis]
        forward = field.narrow(axis, 1, n - 1) - field.narrow(axis, 0, n - 1)
        # repeat the last difference so the output keeps length n
        return torch.cat((forward, forward.narrow(axis, n - 2, 1)), axis)

    # per component (u, v, w): derivatives along x (dim 3), y (dim 2), z (dim 1)
    (dudx, dudy, dudz), (dvdx, dvdy, dvdz), (dwdx, dwdy, dwdz) = [
        [_diff(comp, axis) for axis in (3, 2, 1)] for comp in range(3)
    ]

    jac = torch.stack([dudx, dudy, dudz, dvdx, dvdy, dvdz, dwdx, dwdy, dwdz], -1)
    curl = torch.stack([dwdy - dvdz, dudz - dwdx, dvdx - dudy], -1)
    return jac, curl


def vel_smoke2world(Vsmoke, s2w, scale_vector, st_factor):
    """Map velocities from the simulation frame to the world frame.

    Each vector is multiplied by ``scale_vector``, divided by ``st_factor``
    (converted with ``torch.Tensor`` and expanded to 3 components), then
    rotated by the upper-left 3x3 block of ``s2w``.
    """
    st = torch.Tensor(st_factor).expand((3,))
    scaled = Vsmoke * (scale_vector) / st  # 2.simulation to 3.target
    rotation = s2w[:3, :3]
    return (scaled[..., None, :] * rotation).sum(-1)  # 3.target to 4.world


def off_smoke2world(Offsmoke, s2w, scale_vector):
    """Map offsets from the simulation frame to the world frame.

    Offsets are scaled per axis by ``scale_vector`` and rotated by the
    upper-left 3x3 block of ``s2w`` (no translation is applied).
    """
    rotation = s2w[:3, :3]
    scaled = Offsmoke * (scale_vector)  # 2.simulation to 3.target
    return torch.sum(scaled.unsqueeze(-2) * rotation, dim=-1)  # 3.target to 4.world


def PDE_EQs(D_t, D_x, D_y, D_z, U, U_t=None, U_x=None, U_y=None, U_z=None):
    """Residuals of simplified transport / Navier-Stokes equations.

    Args:
        D_t, D_x, D_y, D_z: partial derivatives of density, each (N, 1).
        U: velocity with 3 components on the last axis (u, v, w).
        U_t, U_x, U_y, U_z: optional velocity derivatives, split along the last
            axis into (u, v, w) parts. Velocity residuals are only added when
            all four are provided.
    Returns:
        List of residual tensors, each of which should be 0:
        [density transport] + ([3 momentum transport, divergence] if all
        velocity derivatives are given) + [scale regularisation 0.1*|U|^2].

    Transport and NSE equations:
        e1 = d_t + (u*d_x + v*d_y + w*d_z) - PecInv*(c_xx + c_yy + c_zz)
        e2 = u_t + (u*u_x + v*u_y + w*u_z) + p_x - ReyInv*(u_xx + u_yy + u_zz)
        e3 = v_t + (u*v_x + v*v_y + w*v_z) + p_y - ReyInv*(v_xx + v_yy + v_zz)
        e4 = w_t + (u*w_x + v*w_y + w*w_z) + p_z - ReyInv*(w_xx + w_yy + w_zz)
        e5 = u_x + v_y + w_z
    For simplification, we assume PecInv = 0.0, ReyInv = 0.0, pressure p = (0,0,0).
    """
    # Identity checks: `None in [tensor, ...]` would rely on Tensor.__eq__(None).
    has_vel_grads = all(g is not None for g in (U_t, U_x, U_y, U_z))

    dts, dxs, dys, dzs = [D_t], [D_x], [D_y], [D_z]
    if has_vel_grads:
        dts += U_t.split(1, dim=-1)  # [d_t, u_t, v_t, w_t] # (N,1)
        dxs += U_x.split(1, dim=-1)  # [d_x, u_x, v_x, w_x]
        dys += U_y.split(1, dim=-1)  # [d_y, u_y, v_y, w_y]
        dzs += U_z.split(1, dim=-1)  # [d_z, u_z, v_z, w_z]

    u, v, w = U.split(1, dim=-1)  # (N,1)
    eqs = [dt + (u * dx + v * dy + w * dz) for dt, dx, dy, dz in zip(dts, dxs, dys, dzs)]

    if has_vel_grads:
        eqs.append(dxs[1] + dys[2] + dzs[3])  # divergence: u_x + v_y + w_z

    eqs.append((u * u + v * v + w * w) * 1e-1)  # scale regularization
    return eqs


def mean_squared_error(pred, exact):
    """Mean of squared differences.

    Uses numpy when ``pred`` is exactly an ``np.ndarray`` (not a subclass);
    otherwise torch is used.
    """
    residual = pred - exact
    if type(pred) is np.ndarray:
        return np.mean(np.square(residual))
    return torch.mean(torch.square(residual))


def get_rays(H, W, K, c2w):
    """Generate one ray per pixel for a pinhole camera.

    Args:
        H, W: image height and width in pixels.
        K: intrinsics indexed K[row][col]; pixel offsets K[0][2], K[1][2] and
           scales K[0][0], K[1][1] (principal point / focal lengths).
        c2w: camera-to-world tensor with rotation c2w[:3, :3] and translation
             c2w[:3, -1].
    Returns:
        rays_o, rays_d: (H, W, 3) ray origins and (unnormalised) directions.
    """
    # Pixel grids by broadcasting: i[h, w] = w (column), j[h, w] = h (row).
    # Same values as the transposed default-'ij' meshgrid, but without
    # meshgrid's missing-`indexing` deprecation warning, and on c2w's device.
    cols = torch.linspace(0, W - 1, W, device=c2w.device)
    rows = torch.linspace(0, H - 1, H, device=c2w.device)
    i = cols.expand(H, W)
    j = rows[:, None].expand(H, W)
    # Camera-space directions: x right, y up, looking down -z.
    dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = torch.sum(dirs[..., None, :] * c2w[:3, :3], -1)  # equals [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d


def ndc_rays(H, W, focal, near, rays_o, rays_d):
    """Map rays into normalized device coordinates (forward-facing scenes).

    Each origin is first slid along its ray onto the plane ``z = -near``. The
    origin and direction are then projected with the pinhole scale factors
    derived from ``W``, ``H`` and ``focal``.

    Returns:
        (rays_o_ndc, rays_d_ndc), each with the same leading shape as the inputs
        and a trailing dimension of 3.
    """
    # Parametric distance that brings every origin onto the near plane.
    t_near = -(near + rays_o[..., 2]) / rays_d[..., 2]
    origin = rays_o + t_near[..., None] * rays_d

    ox, oy, oz = origin[..., 0], origin[..., 1], origin[..., 2]
    dx, dy, dz = rays_d[..., 0], rays_d[..., 1], rays_d[..., 2]
    scale_x = -1. / (W / (2. * focal))
    scale_y = -1. / (H / (2. * focal))

    ndc_o = torch.stack([
        scale_x * ox / oz,
        scale_y * oy / oz,
        1. + 2. * near / oz,
    ], -1)
    ndc_d = torch.stack([
        scale_x * (dx / dz - ox / oz),
        scale_y * (dy / dz - oy / oz),
        -2. * near / oz,
    ], -1)
    return ndc_o, ndc_d


def raw2outputs(raw_list, z_vals, rays_d, raw_noise_std=0, pytest=False, remove99=False):
    """Composite raw per-sample network outputs into per-ray renderings.

    Several raw fields (e.g. a dynamic and a static one) are rendered jointly;
    ``None`` entries in ``raw_list`` are skipped. Every field is attenuated by
    the transmittance through *all* fields for the merged maps. For the
    ``*_stack`` outputs, each field uses only its own transmittance.

    Args:
        raw_list: list of tensors shaped [num_rays, num_samples, 4] (or None).
            Channels 0-2 are pre-sigmoid RGB; channel 3 is pre-ReLU density.
        z_vals: [num_rays, num_samples]. Sample positions along each ray.
        rays_d: [num_rays, 3]. Ray directions (need not be unit length).
        raw_noise_std: std of Gaussian noise added to density before the ReLU.
        pytest: if True, use seeded NumPy uniform noise for reproducible tests.
        remove99: if True, zero out per-sample opacities above 0.99.
    Returns:
        rgb_map: [num_rays, 3]. Composited colour summed over all fields.
        disp_map: [num_rays]. Disparity 1 / max(1e-10, depth / acc); may be
            NaN where acc_map is 0.
        acc_map: [num_rays]. Sum of weights along each ray (summed over fields).
        weights: [num_rays, num_samples]. Combined opacity times transmittance.
        depth_map: [num_rays]. Weighted sum of z_vals.
        Ti: [num_rays, num_samples]. Exclusive transmittance through all fields.
        rgb_map_stack: [num_rays, 3, N_raws]. Per-field colour (own transmittance).
        acc_map_stack: [num_rays, N_raws]. Per-field accumulated opacity.
    """
    raw2alpha = lambda raw, dists, act_fn=torch.nn.functional.relu: 1. - torch.exp(-act_fn(raw) * dists)

    dists = z_vals[..., 1:] - z_vals[..., :-1]
    # The last sample gets an effectively infinite interval. full_like keeps the
    # pad on dists' device/dtype (a bare torch.Tensor constant is CPU float32).
    dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], -1)  # [N_rays, N_samples]

    # Scale parametric distances by |d| to get world-space distances.
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    noise = 0.
    alpha_list = []
    color_list = []
    for raw in raw_list:
        if raw is None:
            continue
        if raw_noise_std > 0.:
            noise = torch.randn(raw[..., 3].shape, device=raw.device) * raw_noise_std

            # Overwrite randomly sampled data if pytest
            if pytest:
                np.random.seed(42)
                noise = np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std
                noise = torch.Tensor(noise).to(raw.device)

        alpha = raw2alpha(raw[..., 3] + noise, dists)  # [N_rays, N_samples]
        if remove99:
            alpha = torch.where(alpha > 0.99, torch.zeros_like(alpha), alpha)
        rgb = torch.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]

        alpha_list += [alpha]
        color_list += [rgb]

    densTiStack = torch.stack([1. - alpha for alpha in alpha_list], dim=-1)
    # [N_rays, N_samples, N_raws]
    densTi = torch.prod(densTiStack, dim=-1, keepdim=True)
    # [N_rays, N_samples, 1]: per-sample transparency through all fields
    densTi_all = torch.cat([densTiStack, densTi], dim=-1)
    # [N_rays, N_samples, N_raws + 1]
    Ti_all = torch.cumprod(densTi_all + 1e-10, dim=-2)  # accumulate along samples (inclusive)
    # Dividing out the current sample turns the inclusive product into an exclusive one.
    Ti_all = Ti_all / (densTi_all + 1e-10)
    # [N_rays, N_samples, N_raws + 1], exclusive
    weights_list = [alpha * Ti_all[..., -1] for alpha in alpha_list]  # a list of [N_rays, N_samples]
    self_weights_list = [alpha_list[alpha_i] * Ti_all[..., alpha_i] for alpha_i in
                         range(len(alpha_list))]  # a list of [N_rays, N_samples]

    def weighted_sum_of_samples(wei_list, content_list=None, content=None):
        # Returns (sum over fields, per-field stack) of the weighted sample sums.
        content_map_list = []
        if content_list is not None:
            content_map_list = [
                torch.sum(weights[..., None] * ct, dim=-2)
                # [N_rays, N_content], weighted sum along samples
                for weights, ct in zip(wei_list, content_list)
            ]
        elif content is not None:
            content_map_list = [
                torch.sum(weights * content, dim=-1)
                # [N_rays], weighted sum along samples
                for weights in wei_list
            ]
        content_map = torch.stack(content_map_list, dim=-1)
        # [N_rays, (N_contentlist,) N_raws]
        content_sum = torch.sum(content_map, dim=-1)
        # [N_rays, (N_contentlist,)]
        return content_sum, content_map

    rgb_map, _ = weighted_sum_of_samples(weights_list, color_list)  # [N_rays, 3]
    # Sum of weights along each ray. This value is in [0, 1] up to numerical error.
    acc_map, _ = weighted_sum_of_samples(weights_list, None, 1)  # [N_rays]

    _, rgb_map_stack = weighted_sum_of_samples(self_weights_list, color_list)
    _, acc_map_stack = weighted_sum_of_samples(self_weights_list, None, 1)

    # Estimated depth map is expected distance.
    # Disparity map is inverse depth.
    depth_map, _ = weighted_sum_of_samples(weights_list, None, z_vals)  # [N_rays]
    disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map / acc_map)
    # Combined opacity (1 - prod of field transparencies) times transmittance.
    weights = (1. - densTi)[..., 0] * Ti_all[..., -1]  # [N_rays, N_samples]

    return rgb_map, disp_map, acc_map, weights, depth_map, Ti_all[..., -1], rgb_map_stack, acc_map_stack


# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    """Draw ``N_samples`` positions per row by inverting a piecewise-linear CDF.

    Args:
        bins: [..., M + 1] bin edges per row (assumed ascending).
        weights: [..., M] non-negative weight per bin; 1e-5 is added to avoid NaNs.
        N_samples: number of samples to draw per row.
        det: if True, use evenly spaced u in [0, 1] instead of uniform random u.
        pytest: if True, override u with seeded NumPy values for reproducibility.
    Returns:
        samples: [..., N_samples] positions interpolated within the bins.
    """
    # weights.device, not weights.get_device(): the latter returns -1 for CPU
    # tensors, which is not a valid device argument and crashed on CPU.
    device = weights.device
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # [..., len(bins)]

    # Take uniform samples
    sample_shape = list(cdf.shape[:-1]) + [N_samples]
    if det:
        u = torch.linspace(0.0, 1.0, steps=N_samples, device=device)
        u = u.expand(sample_shape)
    else:
        u = torch.rand(sample_shape, device=device)

    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        if det:
            u = np.linspace(0.0, 1.0, N_samples)
            u = np.broadcast_to(u, sample_shape)
        else:
            u = np.random.rand(*sample_shape)
        u = torch.Tensor(u).to(device)

    # Invert CDF: locate the CDF interval that brackets each u.
    u = u.contiguous()
    inds = torch.searchsorted(cdf.detach(), u, right=True)

    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], -1)  # [..., N_samples, 2]

    # Gather bracketing CDF values and bin edges; works for any number of leading dims.
    matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), -1, inds_g)
    bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), -1, inds_g)

    # Degenerate (near-zero width) CDF intervals fall back to the lower edge.
    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples


def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                raw_noise_std=0.,
                verbose=False,
                pytest=False,
                has_t=False,
                vel_model=None,
                netchunk=1024 * 64,
                warp_fading_dt=None,
                warp_mod="rand",
                remove99=False):
    """Volumetric rendering of a batch of rays.

    Args:
      ray_batch: [batch_size, C] tensor. Columns 0:3 are ray origins, 3:6 ray
        directions, 6:8 (near, far). If has_t, the last column is time and
        columns -4:-1 are view directions when C > 9; otherwise columns -3:
        are view directions when C > 8.
      network_fn: function. Model for predicting RGB and density at each point
        in space.
      network_query_fn: function used for passing queries to network_fn,
        called as network_query_fn(pts, viewdirs, fn).
      N_samples: int. Number of stratified samples along each ray.
      retraw: bool. If True, include the model's raw, unprocessed predictions.
      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time.
      N_importance: int. Number of additional importance samples along each ray.
        These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn.
      raw_noise_std: std of noise added to raw density (see raw2outputs).
      verbose: bool. Currently unused.
      pytest: bool. Use seeded NumPy randomness for reproducible tests.
      has_t: bool. Whether ray_batch carries a time column (4D queries).
      vel_model: optional velocity network; enables flow-based warping when
        has_t and warp_fading_dt are also set.
      netchunk: chunk size used when querying vel_model.
      warp_fading_dt: fading * delta_t scale used to warp samples to nearby frames.
      warp_mod: "rand", "forw", "back" or "none"; how the warp offset is chosen.
      remove99: bool. Forwarded to raw2outputs.
    Returns:
      dict with:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray (fine model if N_importance > 0).
      disp_map: [num_rays]. Disparity map. 1 / depth.
      acc_map: [num_rays]. Accumulated opacity along each ray.
      raw / raw_static: raw dynamic / static predictions (only if retraw).
      rgb0, disp0, acc0: coarse-model outputs (only if N_importance > 0).
      z_std: [num_rays]. Std of the importance sample positions.
      rgbh1/acch1, rgbh2/acch2: static / dynamic per-field maps (hybrid mode only),
        plus their coarse versions rgbh10/acch10/rgbh20/acch20 and the 50/50 mix rgbM.
    """
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [N_rays, 3] each
    rays_t, viewdirs = None, None
    if has_t:
        rays_t = ray_batch[:, -1:]  # [N_rays, 1]
        viewdirs = ray_batch[:, -4:-1] if ray_batch.shape[-1] > 9 else None
    elif ray_batch.shape[-1] > 8:
        viewdirs = ray_batch[:, -3:]

    bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])
    near, far = bounds[..., 0], bounds[..., 1]  # [N_rays, 1] each

    # Sample parameters live on the ray batch's device to avoid device mismatches.
    t_vals = torch.linspace(0., 1., steps=N_samples, device=ray_batch.device)
    if not lindisp:
        z_vals = near * (1. - t_vals) + far * (t_vals)
    else:
        z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals))

    z_vals = z_vals.expand([N_rays, N_samples])

    if perturb > 0.:
        # get intervals between samples
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], -1)
        lower = torch.cat([z_vals[..., :1], mids], -1)
        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape, device=z_vals.device)

        # Pytest, overwrite u with numpy's fixed random numbers
        if pytest:
            np.random.seed(42)
            t_rand = np.random.rand(*list(z_vals.shape))
            t_rand = torch.Tensor(t_rand).to(z_vals.device)

        z_vals = lower + (upper - lower) * t_rand

    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]  # [N_rays, N_samples, 3]
    if rays_t is not None:
        rays_t_bc = torch.reshape(rays_t, [-1, 1, 1]).expand([N_rays, N_samples, 1])
        pts = torch.cat([pts, rays_t_bc], dim=-1)

    def warp_raw_random(orig_pts, orig_dir, fading, fn, mod="rand", has_t=has_t):
        # Query fn either at the original points, or at points advected along
        # vel_model's velocity to a shifted time (mod: "rand", "forw", "back", "none").
        if (not has_t) or (mod == "none") or (vel_model is None):
            orig_raw = network_query_fn(orig_pts, orig_dir, fn)  # [N_rays, N_samples, 4]
            return orig_raw

        orig_pos, orig_t = torch.split(orig_pts, [3, 1], -1)

        _vel = batchify(vel_model, netchunk)(orig_pts.view(-1, 4))
        _vel = torch.reshape(_vel, [N_rays, -1, 3])
        # _vel.shape, [N_rays, N_samples(+N_importance), 3]
        if mod == "rand":
            # time offset drawn uniformly from [-3, 3] (before fading scale)
            random_warpT = torch.rand(orig_t.shape, device=orig_t.device) * 6.0 - 3.0
        else:
            random_warpT = 1.0 if mod == "back" else (-1.0)
        random_warpT = random_warpT * fading
        # random_warpT may be a tensor, a Python float ("forw"/"back") or a NumPy
        # scalar; torch.Tensor(float) raises, so convert with as_tensor instead.
        random_warpT = torch.as_tensor(random_warpT, dtype=orig_t.dtype, device=orig_t.device)

        warp_t = orig_t + random_warpT
        warp_pos = orig_pos + _vel * random_warpT
        warp_pts = torch.cat([warp_pos, warp_t], dim=-1)
        warp_pts = warp_pts.detach()  # stop gradient

        warped_raw = network_query_fn(warp_pts, orig_dir, fn)  # [N_rays, N_samples, 4]

        return warped_raw

    def get_raw(fn, staticpts, staticdirs, has_t=has_t):
        # Returns (smoke_raw, static_raw); static_raw is None unless the network
        # outputs more than 4 channels with has_t (hybrid mode).
        static_raw, smoke_raw = None, None
        smoke_warp_mod = warp_mod
        if (None in [vel_model, warp_fading_dt]) or (not has_t):
            smoke_warp_mod = "none"

        smoke_raw = warp_raw_random(staticpts, staticdirs, warp_fading_dt, fn, mod=smoke_warp_mod, has_t=has_t)
        if has_t and (smoke_raw.shape[-1] > 4):  # hybrid mode
            if smoke_warp_mod == "none":
                static_raw = smoke_raw
            else:
                static_raw = warp_raw_random(staticpts, staticdirs, warp_fading_dt, fn, mod="none", has_t=True)

            static_raw = static_raw[..., :4]
            smoke_raw = smoke_raw[..., -4:]

        return smoke_raw, static_raw  # [N_rays, N_samples, 4], [N_rays, N_samples, 4]

    # Coarse pass.
    C_smokeRaw, C_staticRaw = get_raw(network_fn, pts, viewdirs)
    raw = [C_smokeRaw, C_staticRaw]
    rgb_map, disp_map, acc_map, weights, depth_map, ti_map, rgb_map_stack, acc_map_stack = raw2outputs(raw, z_vals,
                                                                                                       rays_d,
                                                                                                       raw_noise_std,
                                                                                                       pytest=pytest,
                                                                                                       remove99=remove99)

    if raw[-1] is not None:
        rgbh2_map = rgb_map_stack[..., 0]  # dynamic
        acch2_map = acc_map_stack[..., 0]  # dynamic
        rgbh1_map = rgb_map_stack[..., 1]  # static
        acch1_map = acc_map_stack[..., 1]  # static

    if N_importance > 0:
        # Fine pass: importance-sample where the coarse weights are large.
        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

        z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], N_importance, det=(perturb == 0.), pytest=pytest)
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :,
                                                            None]  # [N_rays, N_samples + N_importance, 3]

        if rays_t is not None:
            rays_t_bc = torch.reshape(rays_t, [-1, 1, 1]).expand([N_rays, N_samples + N_importance, 1])
            pts = torch.cat([pts, rays_t_bc], dim=-1)

        run_fn = network_fn if network_fine is None else network_fine
        F_smokeRaw, F_staticRaw = get_raw(run_fn, pts, viewdirs)
        raw = [F_smokeRaw, F_staticRaw]

        rgb_map, disp_map, acc_map, weights, depth_map, ti_map, rgb_map_stack, acc_map_stack = raw2outputs(raw, z_vals,
                                                                                                           rays_d,
                                                                                                           raw_noise_std,
                                                                                                           pytest=pytest,
                                                                                                           remove99=remove99)

        if raw[-1] is not None:
            # Keep the coarse per-field maps, then replace with the fine ones.
            rgbh20_map = rgbh2_map
            acch20_map = acch2_map
            rgbh10_map = rgbh1_map
            acch10_map = acch1_map
            rgbh2_map = rgb_map_stack[..., 0]
            acch2_map = acc_map_stack[..., 0]
            rgbh1_map = rgb_map_stack[..., 1]
            acch1_map = acc_map_stack[..., 1]

    ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
    if retraw:
        ret['raw'] = raw[0]
        if raw[1] is not None:
            ret['raw_static'] = raw[1]
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]

    if raw[-1] is not None:
        ret['rgbh1'] = rgbh1_map
        ret['acch1'] = acch1_map
        ret['rgbh2'] = rgbh2_map
        ret['acch2'] = acch2_map
        if N_importance > 0:
            ret['rgbh10'] = rgbh10_map
            ret['acch10'] = acch10_map
            ret['rgbh20'] = rgbh20_map
            ret['acch20'] = acch20_map
        ret['rgbM'] = rgbh1_map * 0.5 + rgbh2_map * 0.5

    # Check DEBUG first so the (device-syncing) NaN/Inf scan is skipped in normal runs.
    if DEBUG:
        for k in ret:
            if torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any():
                print(f"! [Numerical Error] {k} contains nan or inf.")

    return ret


def batchify_rays(rays_flat, chunk=1024 * 32, **kwargs):
    """Render ``rays_flat`` in slices of at most ``chunk`` rays to bound peak memory.

    Each slice is passed to ``render_rays(slice, **kwargs)``. The per-slice output
    dicts are then concatenated key-by-key along dim 0, with keys kept in
    first-seen order.
    """
    pieces = {}
    n_total = rays_flat.shape[0]
    for start_idx in range(0, n_total, chunk):
        chunk_out = render_rays(rays_flat[start_idx:start_idx + chunk], **kwargs)
        for key, value in chunk_out.items():
            pieces.setdefault(key, []).append(value)
    return {key: torch.cat(parts, 0) for key, parts in pieces.items()}


def render(H, W, K, chunk=1024 * 32, rays=None, c2w=None, ndc=True,
           near=0., far=1.,
           use_viewdirs=False, c2w_staticcam=None,
           time_step=None, bkgd_color=None,
           **kwargs):
    """Render rays
    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      K: 3x3 camera intrinsics; K[0][0] is used as the focal length for NDC.
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: pair (rays_o, rays_d), e.g. an array of shape [2, batch_size, 3].
        Ray origin and direction for each example in batch. Used when c2w is None.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix; if
        given, a full H x W image of rays is generated.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float, or a tensor/array with one value per ray (flattened ray order).
      far: float, or a tensor/array with one value per ray (flattened ray order).
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
       camera while using other c2w argument for viewing directions.
      time_step: optional tensor expandable to [N_rays, 1]; appended as the
        last ray column.
      bkgd_color: optional RGB colour composited behind every rgb map as
        rgb + bkgd * (1 - acc).
      **kwargs: forwarded to render_rays via batchify_rays.
    Returns:
      [rgb_map, disp_map, acc_map, extras] where
      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
      disp_map: [batch_size]. Disparity map. Inverse of depth.
      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
      extras: dict with everything else returned by render_rays().
    """
    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = get_rays(H, W, K, c2w)
    else:
        # use provided ray batch
        rays_o, rays_d = rays

    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # special case to visualize effect of viewdirs
            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1, 3]).float()

    sh = rays_d.shape  # [..., 3]
    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

    # Create ray batch
    rays_o = torch.reshape(rays_o, [-1, 3]).float()
    rays_d = torch.reshape(rays_d, [-1, 3]).float()

    def _per_ray_bound(bound):
        # Scalars broadcast directly. Arrays/tensors hold one value per ray and
        # must become a column; a flat [N] would otherwise broadcast to [N, N].
        if torch.is_tensor(bound) or isinstance(bound, np.ndarray):
            bound = torch.as_tensor(bound, dtype=rays_d.dtype, device=rays_d.device).reshape(-1, 1)
        return bound * torch.ones_like(rays_d[..., :1])

    near, far = _per_ray_bound(near), _per_ray_bound(far)
    rays = torch.cat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)

    if time_step is not None:
        time_step = time_step.expand(list(rays.shape[0:-1]) + [1])
        # (ray origin, ray direction, min dist, max dist, normalized viewing direction, t)
        rays = torch.cat([rays, time_step], dim=-1)
    # Render and reshape
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)

    if bkgd_color is not None:
        # Match the rendered maps' device/dtype rather than a global device.
        rgb_ref = all_ret['rgb_map']
        torch_bkgd_color = torch.as_tensor(bkgd_color, dtype=rgb_ref.dtype, device=rgb_ref.device)
        # rgb map for model: fine, coarse, merged, dynamic_fine, dynamic_coarse
        for _i in ['_map', '0', 'h1', 'h10', 'h2',
                   'h20']:  # add background for synthetic scenes, for image-based supervision
            rgb_i, acc_i = 'rgb' + _i, 'acc' + _i
            if (rgb_i in all_ret) and (acc_i in all_ret):
                all_ret[rgb_i] = all_ret[rgb_i] + torch_bkgd_color * (1. - all_ret[acc_i][..., None])

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]


for i in trange(start, N_iters):
    # time0 = time.time()
    if args.net_model != "nerf":
        model_fading_update(all_models, global_step, tempoInStep, velInStep, args.net_model == "hybrid")

    # train radiance all the time, train vel less, train with d2v even less.
    trainImg = True
    trainVGG = (args.vggW > 0.0) and (i % 4 == 0)  # less vgg training
    trainVel = (global_step >= (tempoInStep + velInStep)) and (vel_model is not None) and (i % 10 == 0)
    trainD2V = (args.d2vW > 0.0) and (global_step >= (tempoInStep + velInStep * 2)) and trainVel and (i % 20 == 0)

    # fading in for networks
    tempo_fading = fade_in_weight(global_step, tempoInStep, 10000)
    vel_fading = fade_in_weight(global_step, tempoInStep + velInStep, 10000)
    warp_fading = fade_in_weight(global_step, tempoInStep + velInStep + 10000, 20000)
    # fading in for losses
    vgg_fading = [fade_in_weight(global_step, (vgg_i - 1) * 10000, 10000) for vgg_i in
                  range(len(vggTool.layer_list), 0, -1)]
    ghost_fading = fade_in_weight(global_step, tempoInStep + 2000, 20000)
    d2v_fading = fade_in_weight(global_step, tempoInStep + velInStep * 2, 20000)
    ###########################################################

    # Random from one frame
    img_i = np.random.choice(i_train)
    target = images[img_i]
    target = torch.Tensor(target).to(device)
    pose = poses[img_i, :3, :4]
    time_locate = torch.Tensor(time_steps[img_i]).to(device)

    # Cast intrinsics to right types
    H, W, focal = hwfs[img_i]
    H, W = int(H), int(W)
    focal = float(focal)
    hwf = [H, W, focal]
    K = np.array([
        [focal, 0, 0.5 * W],
        [0, focal, 0.5 * H],
        [0, 0, 1]
    ])

    if trainVel:
        # take a mini_batch 32*32*32
        if trainD2V:  # a cropped 32^3 as a mini_batch
            offset_w = np.int32(
                np.random.uniform(d2v_min_border, int(trainX * denRatio + 1e-6) - denTW - d2v_min_border, []))
            offset_h = np.int32(
                np.random.uniform(d2v_min_border, int(trainY * denRatio + 1e-6) - denTW - d2v_min_border, []))
            offset_d = np.int32(
                np.random.uniform(d2v_min_border, int(trainZ * denRatio + 1e-6) - denTW - d2v_min_border, []))
            den_p_crop = den_p_all[offset_d:offset_d + denTW:2, offset_h:offset_h + denTW:2,
                         offset_w:offset_w + denTW:2, :]
            training_samples = torch.reshape(den_p_crop, (-1, 3))
            # training_samples = get_voxel_pts_offset(32, 32, 32, voxel_tran, voxel_scale, r_offset=4.0)
            # training_samples = torch.reshape(training_samples, (-1,3))
        else:  # a random mini_batch
            train_random = np.random.choice(trainZ * trainY * trainX, 32 * 32 * 32)
            training_samples = training_pts[train_random]

        training_samples = training_samples.view(-1, 3)
        training_t = torch.ones([training_samples.shape[0], 1]) * time_locate
        training_samples = torch.cat([training_samples, training_t], dim=-1)

        #####  core velocity optimization loop  #####
        # allows to take derivative w.r.t. training_samples
        training_samples = training_samples.clone().detach().requires_grad_(True)
        _vel, _u_x, _u_y, _u_z, _u_t = training_voxel.get_velocity_and_derivatives(training_samples, chunk=args.chunk,
                                                                                   batchify_fn=batchify,
                                                                                   vel_model=vel_model)
        _den, _d_x, _d_y, _d_z, _d_t = training_voxel.get_density_and_derivatives(training_samples, chunk=args.chunk,
                                                                                  use_viewdirs=False,
                                                                                  network_query_fn=render_kwargs_test[
                                                                                      'network_query_fn'],
                                                                                  network_fn=render_kwargs_test[
                                                                                      'network_fine' if args.N_importance > 0 else 'network_fn'])

        # get vorticity in 32^3 smoke resolution coord for training
        if trainD2V:  # all data are in a cropped 32^3 grid as a mini_batch
            with torch.no_grad():
                # in trainZ, trainY, trainX resolution space
                _den_voxel = _den.detach().view(1, 1, 32, 32, 32)  # b,c,DHW
                den_mask = torch.where(_den_voxel > 1e-6, torch.ones_like(_den_voxel),
                                       torch.zeros_like(_den_voxel)).view(32, 32, 32, 1)

                if torch.mean(den_mask) > 0.05:  # at least 5 percent are valid
                    den_voxel_raw = torch.nn.functional.interpolate(_den_voxel, size=64, mode='trilinear',
                                                                    align_corners=False) / (
                                            _den_voxel.max() + 1e-8)
                    vel_d2v_smoke = d2v_model(den_voxel_raw).permute([0, -1, 1, 2, 3])  # 1,3,64,64,64
                    vel_d2v_smoke = torch.nn.functional.interpolate(vel_d2v_smoke, size=32, mode='trilinear',
                                                                    align_corners=False).permute([0, 2, 3, 4, 1])
                    # 1,32,32,32,3
                    # scale according to vel_pred_smoke_view
                    vel_pred_smoke_view = vel_world2smoke(_vel.detach(), voxel_tran_inv, voxel_scale,
                                                          train_reso_scale).view(32, 32, 32, 3)
                    scale_factor = torch.mean(
                        torch.sqrt(torch.sum(torch.square(vel_pred_smoke_view * den_mask), dim=-1) + 1e-8))
                    scale_factor = scale_factor / torch.mean(
                        torch.sqrt(torch.sum(torch.square(vel_d2v_smoke * den_mask), dim=-1) + 1e-8))
                    vel_d2v_smoke = vel_d2v_smoke * scale_factor

                    # jac in trainZ, trainY, trainX resolution space
                    d2v_jac, d2v_vort = jacobian3D(vel_d2v_smoke)  # 1,32,32,32,9 and 1,32,32,32,3
                    d2v_udx = vel_smoke2world(d2v_jac[..., 0::3], voxel_tran, voxel_scale, train_reso_scale)
                    d2v_udy = vel_smoke2world(d2v_jac[..., 1::3], voxel_tran, voxel_scale, train_reso_scale)
                    d2v_udz = vel_smoke2world(d2v_jac[..., 2::3], voxel_tran, voxel_scale, train_reso_scale)
                    d2v_u_jac = torch.cat([d2v_udx, d2v_udy, d2v_udz], dim=-1).view(32, 32, 32, 9).detach()

                    smoke_baseX = off_smoke2world(torch.Tensor([1.0 / 256, 0.0, 0.0]), voxel_tran, voxel_scale)
                    smoke_baseY = off_smoke2world(torch.Tensor([0.0, 1.0 / 256, 0.0]), voxel_tran, voxel_scale)
                    smoke_baseZ = off_smoke2world(torch.Tensor([0.0, 0.0, 1.0 / 256]), voxel_tran, voxel_scale)
                else:
                    trainD2V = False  # train d2v next time.

        vel_optimizer.zero_grad()
        split_nse = PDE_EQs(
            _d_t.detach(), _d_x.detach(), _d_y.detach(), _d_z.detach(),
            _vel, _u_t, _u_x, _u_y, _u_z)
        nse_errors = [mean_squared_error(x, 0.0) for x in split_nse]
        nseloss_fine = 0.0
        for ei, wi in zip(nse_errors, split_nse_wei):
            nseloss_fine = ei * wi + nseloss_fine
        vel_loss = nseloss_fine * args.nseW * vel_fading

        if trainD2V:
            worldU_smokeX = smoke_baseX[0] * _u_x + smoke_baseX[1] * _u_y + smoke_baseX[2] * _u_z
            worldU_smokeY = smoke_baseY[0] * _u_x + smoke_baseY[1] * _u_y + smoke_baseY[2] * _u_z
            worldU_smokeZ = smoke_baseZ[0] * _u_x + smoke_baseZ[1] * _u_y + smoke_baseZ[2] * _u_z
            cur_jac = torch.cat([worldU_smokeX, worldU_smokeY, worldU_smokeZ], dim=-1)  # 9
            cur_jac = cur_jac.view(32, 32, 32, 9)

            d2v_jac_scale = torch.mean(torch.sqrt(torch.sum(torch.square(d2v_u_jac * den_mask), dim=-1))).detach()
            cur_jac_scale = torch.mean(torch.sqrt(torch.sum(torch.square(cur_jac * den_mask), dim=-1))).detach()

            d2v_error = mean_squared_error(cur_jac * den_mask, d2v_u_jac * den_mask) / torch.mean(den_mask)
            vel_loss += d2v_error * args.d2vW * d2v_fading

        vel_loss.backward()
        vel_optimizer.step()

    if trainImg:
        # ---- Image (radiance) step: pick a batch of pixels/rays, render them and
        # optimize the scene model against the target frame. ----
        rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)
        if trainVGG:
            # The VGG loss needs a spatially coherent patch, so take a dw x dw grid of
            # pixels spaced `strides` apart instead of scattered random pixels.
            # `strides` cycles through vgg_strides-1, vgg_strides, vgg_strides+1.
            strides = args.vgg_strides + i % 3 - 1
            dw = int(max(20, min(40, N_rand ** 0.5)))  # patch side, clamped to [20, 40]
            vgg_min_border = 10
            # Shrink the stride if a dw-wide strided patch would not fit inside the image.
            strides = min(strides, min(H - vgg_min_border, W - vgg_min_border) / dw)
            strides = int(strides)

            coords = torch.stack(
                torch.meshgrid(torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W), indexing='ij'),
                -1)  # (H, W, 2); last dim is (row, col)

            # Center the patch on the centroid of the target image, where each pixel is
            # weighted by how much it differs from the background color.
            # as_tensor accepts a list / ndarray / tensor background on any device.
            bkgd_for_crop = torch.as_tensor(args.white_bkgd, dtype=target.dtype, device=target.device)
            target_grey = torch.mean(torch.abs(target - bkgd_for_crop), dim=-1, keepdim=True)  # H,W,1
            # coords lives on the default (CPU) device; move it to target's device so
            # the product also works when target is on CUDA.
            img_wei = coords.to(device=target_grey.device, dtype=torch.float32) * target_grey
            center_coord = torch.sum(img_wei, dim=(0, 1)) / torch.sum(target_grey)
            center_coord = center_coord.cpu().numpy()
            # Gaussian jitter around the centroid with sigma = random_R / 3, i.e. 3*sigma == random_R.
            random_R = dw * strides / 2.0
            random_x = np.random.normal(center_coord[1], random_R / 3.0) - 0.5 * dw * strides
            random_y = np.random.normal(center_coord[0], random_R / 3.0) - 0.5 * dw * strides

            # Clamp the patch's top-left corner so the whole patch stays at least
            # vgg_min_border pixels inside the image.
            offset_w = int(min(max(vgg_min_border, random_x), W - dw * strides - vgg_min_border))
            offset_h = int(min(max(vgg_min_border, random_y), H - dw * strides - vgg_min_border))

            coords_crop = coords[offset_h:offset_h + dw * strides:strides,
                                 offset_w:offset_w + dw * strides:strides, :]
            select_coords = torch.reshape(coords_crop, [-1, 2]).long()
        else:
            if i < args.precrop_iters:
                # Early iterations: sample only from a central crop covering
                # precrop_frac of each image dimension.
                dH = int(H // 2 * args.precrop_frac)
                dW = int(W // 2 * args.precrop_frac)
                coords = torch.stack(
                    torch.meshgrid(
                        torch.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH),
                        torch.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW),
                        indexing='ij',
                    ), -1)
                if i == start:
                    print(
                        f"[Config] Center cropping of size {2 * dH} x {2 * dW} is enabled until iter {args.precrop_iters}")
            else:
                coords = torch.stack(
                    torch.meshgrid(torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W), indexing='ij'),
                    -1)  # (H, W, 2)

            coords = torch.reshape(coords, [-1, 2])  # (H * W, 2)
            # N_rand distinct pixels, drawn uniformly without replacement.
            select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
            select_coords = coords[select_inds].long()  # (N_rand, 2)

        rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
        rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
        batch_rays = torch.stack([rays_o, rays_d], 0)
        target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)

        if args.train_warp and vel_model is not None and (global_step >= tempoInStep + velInStep):
            # fading * delta_T; warp_fading changes during training, so refresh it every iteration.
            render_kwargs_train['warp_fading_dt'] = warp_fading * t_info[-1]

        #####  core radiance optimization loop  #####
        rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                        verbose=i < 10, retraw=True,
                                        time_step=time_locate,
                                        bkgd_color=args.white_bkgd,
                                        **render_kwargs_train)
        optimizer.zero_grad()

        # Photometric loss on the fine output. For the hybrid model, the static-part
        # rendering ('rgbh1') is blended in with weight (1 - tempo_fading) until
        # tempo_fading reaches 1.
        img_loss = img2mse(rgb, target_s)
        if (args.net_model == "hybrid") and ('rgbh1' in extras) and (tempo_fading < (1.0 - 1e-8)):
            img_loss = img_loss * tempo_fading + img2mse(extras['rgbh1'], target_s) * (1.0 - tempo_fading)

        # All accumulations into `loss` below are out-of-place. `loss` starts out as the
        # same tensor object as `img_loss`, so an in-place `+=` would silently modify
        # `img_loss` as well.
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            # Same photometric loss for the coarse output.
            img_loss0 = img2mse(extras['rgb0'], target_s)
            if (args.net_model == "hybrid") and ('rgbh10' in extras) and (tempo_fading < (1.0 - 1e-8)):
                img_loss0 = img_loss0 * tempo_fading + img2mse(extras['rgbh10'], target_s) * (1.0 - tempo_fading)

            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)

        if trainVGG:
            # Perceptual loss on the dw x dw patch. vgg_loss is a sequence of terms
            # (presumably one per VGG layer); each is scaled by its vgg_fading weight,
            # and terms whose weight is ~0 are skipped.
            vgg_loss_func = vggTool.compute_cos_loss
            vgg_tar = torch.reshape(target_s, [dw, dw, 3])
            vgg_img = torch.reshape(rgb, [dw, dw, 3])
            vgg_loss = vgg_loss_func(vgg_img, vgg_tar)
            w_vgg = args.vggW / float(len(vgg_loss))
            vgg_loss_sum = 0
            for _w, _wf in zip(vgg_loss, vgg_fading):
                if _wf > 1e-8:
                    vgg_loss_sum = _w * _wf * w_vgg + vgg_loss_sum

            if 'rgb0' in extras:
                vgg_img0 = torch.reshape(extras['rgb0'], [dw, dw, 3])
                vgg_loss0 = vgg_loss_func(vgg_img0, vgg_tar)
                for _w, _wf in zip(vgg_loss0, vgg_fading):
                    if _wf > 1e-8:
                        vgg_loss_sum = _w * _wf * w_vgg + vgg_loss_sum
            loss = loss + vgg_loss_sum

        if (args.ghostW > 0.0) and args.white_bkgd is not None:
            # Ghost loss (ghost_loss_func) against the background color, scaled by ghost_fading.
            w_ghost = ghost_fading * args.ghostW
            if w_ghost > 1e-8:
                static_back = args.white_bkgd
                ghost_loss = ghost_loss_func(rgb, static_back, acc, den_penalty=0.0)
                # Hybrid model: additionally compare the dynamic part ('rgbh2') against the
                # static part ('rgbh1'). A static-vs-background term was disabled upstream.
                if (args.net_model == "hybrid") and global_step > tempoInStep \
                        and ('rgbh1' in extras) and ('rgbh2' in extras):
                    ghost_loss = ghost_loss + 0.1 * ghost_loss_func(extras['rgbh2'], extras['rgbh1'],
                                                                    extras['acch2'], den_penalty=0.0)

                if 'rgb0' in extras:
                    ghost_loss0 = ghost_loss_func(extras['rgb0'], static_back, extras['acc0'], den_penalty=0.0)
                    if (args.net_model == "hybrid") and global_step > tempoInStep \
                            and ('rgbh10' in extras) and ('rgbh20' in extras):
                        ghost_loss0 = ghost_loss0 + 0.1 * ghost_loss_func(extras['rgbh20'], extras['rgbh10'],
                                                                          extras['acch20'], den_penalty=0.0)
                    ghost_loss = ghost_loss + ghost_loss0

                loss = loss + ghost_loss * w_ghost

        if (args.net_model == "hybrid") and (args.overlayW > 0):
            # Density should come either from smoke or from the static scene, not both.
            # 2ab / (a^2 + b^2) is 0 when either density is 0 and 1 when they are equal.
            w_overlay = args.overlayW * ghost_fading  # with fading

            smoke_den, static_den = F.relu(extras['raw'][..., -1]), F.relu(extras['raw_static'][..., -1])
            overlay_loss = torch.div(2.0 * (smoke_den * static_den),
                                     (torch.square(smoke_den) + torch.square(static_den) + 1e-8))
            overlay_loss = torch.mean(overlay_loss)
            loss = loss + overlay_loss * w_overlay

        loss.backward()
        optimizer.step()

    ###   update learning rate   ###
    # Exponential schedule: the rate is multiplied by `decay_rate` once every
    # `decay_steps` global steps, applied smoothly per step.
    decay_rate = 0.1
    decay_steps = args.lrate_decay * 1000
    new_lrate = args.lrate * decay_rate ** (global_step / decay_steps)

    # The velocity optimizer shares the schedule, but only on steps where it trains.
    _lr_scheduled_optims = [optimizer]
    if trainVel and vel_optimizer is not None:
        _lr_scheduled_optims.append(vel_optimizer)
    for _lr_optim in _lr_scheduled_optims:
        for param_group in _lr_optim.param_groups:
            param_group['lr'] = new_lrate

    global_step += 1

  0%|          | 0/600000 [00:00<?, ?it/s]

[Config] Center cropping of size 480 x 270 is enabled until iter 1000


  0%|          | 302/600000 [00:50<28:01:23,  5.94it/s]


KeyboardInterrupt: 