In [1]:
import train
import torch
import pathlib

# Single source of truth for the compute device; later cells reuse `device`.
device = torch.device("cuda")

# Location of the dataset on disk.
# NOTE(review): machine-specific absolute path — change this when running elsewhere.
DATA_PATH = pathlib.Path("C:/Users/imeho/Documents/DataSets/InstantPINF/ScalarReal")

# Reuse `device` instead of constructing a second torch.device("cuda").
manager = train.load_data(device, data_path=DATA_PATH)

Output()

Output()

In [2]:
from nerfstudio.model_components.ray_samplers import UniformSampler
from nerfstudio.model_components.scene_colliders import NearFarCollider

# Fixed near/far planes handed to the collider, and the number of samples
# requested from the uniform sampler.
NEAR_PLANE, FAR_PLANE = 1.1, 1.5
SAMPLES_PER_RAY = 192

collider = NearFarCollider(near_plane=NEAR_PLANE, far_plane=FAR_PLANE)
sampler_uniform = UniformSampler(num_samples=SAMPLES_PER_RAY)

In [3]:
import src.encoder as encoder
import taichi as ti
import numpy as np

# Start Taichi on the CUDA backend, then build the 4-D hash encoder.
ti.init(arch=ti.cuda)

# One resolution per input axis; presumably ordered (x, y, z, t) — verify
# against HashEncoderHyFluid.
coarsest_res = np.array([16, 16, 16, 16])
finest_res = np.array([256, 256, 256, 128])

xyzt_encoder = encoder.HashEncoderHyFluid(
    min_res=coarsest_res,
    max_res=finest_res,
    num_scales=16,
    max_params=1 << 19,  # same value as 2 ** 19
)
# Left as the cell's final expression so the module repr is displayed.
xyzt_encoder.to(device)

[Taichi] version 1.7.2, llvm 15.0.1, commit 0131dce9, win, python 3.11.0
[Taichi] Starting on arch=cuda


HashEncoderHyFluid()

In [5]:
from nerfstudio.field_components.mlp import MLP
import src.radam as radam

# Width of the encoder output that feeds the MLP: one feature vector per scale.
encoded_dim = xyzt_encoder.num_scales * xyzt_encoder.features_per_level

# MLP head producing a single ReLU-activated scalar per encoded sample.
mlp_base = MLP(
    in_dim=encoded_dim,
    num_layers=2,
    layer_width=64,
    out_dim=1,
    out_activation=torch.nn.ReLU(),
)
mlp_base.to(device)

# One trainable scalar, initialised to zero, optimised together with the MLP.
learned_rgb = torch.nn.Parameter(torch.tensor([0.0], device=device))

grad_vars = [*mlp_base.parameters(), learned_rgb]
embedding_params = list(xyzt_encoder.parameters())

# Two parameter groups: weight decay on the MLP + scalar, a custom eps on the
# encoder parameters. Both share the same initial lr and betas.
optimizer = radam.RAdam(
    [
        {'params': grad_vars, 'weight_decay': 1e-6},
        {'params': embedding_params, 'eps': 1e-15},
    ],
    lr=0.01,
    betas=(0.9, 0.99),
)

In [6]:
def raw2alpha(raw, dists, act_fn=torch.nn.functional.relu):
    """Convert raw per-sample values into alpha (opacity) values.

    Computes ``1 - exp(-act_fn(raw) * dists)`` element-wise.

    Args:
        raw: Tensor of raw (pre-activation) values.
        dists: Tensor multiplied element-wise with the activated values; must
            be broadcastable against ``raw``.
        act_fn: Callable applied to ``raw`` before the exponential. Defaults
            to ``torch.nn.functional.relu``.

    Returns:
        Tensor with the broadcast shape of ``raw`` and ``dists``.
    """
    return 1. - torch.exp(-act_fn(raw) * dists)

In [None]:
from nerfstudio.model_components.losses import MSELoss
import tqdm
import matplotlib.pyplot as plt

import src.datamanager as datamanager

# --- Training configuration ---------------------------------------------------
NUM_ITERS = 100
LR_BASE = 5e-4         # base learning rate of the exponential schedule
LR_DECAY_RATE = 0.1    # the lr is multiplied by this factor every LR_DECAY_STEPS
LR_DECAY_STEPS = 250

# Built once: constructing the loss module inside the loop is loop-invariant work.
rgb_loss = MSELoss()
loss_history = []

for i in tqdm.tqdm(range(NUM_ITERS)):
    # --- Sample rays and positions along them ---------------------------------
    ray_bundle, batch = manager.next_train()
    # The return value is ignored, so this relies on the collider updating
    # `ray_bundle` in place — presumably by setting nears/fars; confirm.
    collider(ray_bundle)
    ray_samples_uniform = sampler_uniform(ray_bundle)
    positions = ray_samples_uniform.frustums.get_positions()

    # --- Build (x, y, z, t) inputs --------------------------------------------
    frames = datamanager.image_idx_to_frame(
        image_indices=batch['indices'][:, 0],
        all_frames=manager.train_dataset.metadata['all_frames'])
    # One frame value per ray, repeated for every sample along that ray.
    frames_expanded = frames.to(device).view(positions.shape[0], 1, 1).expand(-1, positions.shape[1], -1)
    xyzt = torch.cat((positions, frames_expanded), dim=-1)
    xyzt_flat = xyzt.reshape(-1, 4)
    xyzt_encoded = xyzt_encoder(xyzt_flat)

    # --- Network forward pass -------------------------------------------------
    raw_flat = mlp_base(xyzt_encoded)
    raw = raw_flat.reshape(xyzt.shape[0], xyzt.shape[1], raw_flat.shape[-1])

    # --- Alpha compositing along each ray -------------------------------------
    dists = ray_samples_uniform.deltas
    # A single learned grey level shared by all three channels, kept in (0.2, 1.0).
    rgb = torch.ones(3, device=device) * (0.6 + torch.tanh(learned_rgb) * 0.4)
    alpha = raw2alpha(raw[..., -1], dists[..., -1])
    # Exclusive cumulative product of (1 - alpha): a leading column of ones is
    # prepended and the last column dropped. 1e-10 keeps the factors positive.
    transmittance = torch.cumprod(
        torch.cat([torch.ones((alpha.shape[0], 1), device=device), 1. - alpha + 1e-10], -1),
        -1)[:, :-1]
    weights = alpha * transmittance
    rgb_map = torch.sum(weights[..., None] * rgb, -2)

    # --- Loss and optimisation step -------------------------------------------
    image = batch['image'].to(device)
    loss = rgb_loss(rgb_map, image)
    loss_history.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Exponential learning-rate decay, applied to every parameter group.
    # NOTE(review): this runs AFTER the step, so iteration 0 uses whatever lr
    # the optimizer was constructed with, not LR_BASE — confirm that is intended.
    new_lrate = LR_BASE * (LR_DECAY_RATE ** (i / LR_DECAY_STEPS))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lrate

# --- Plot the loss curve ------------------------------------------------------
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(loss_history, label="Loss")
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.set_title("Loss over Training Iterations")
ax.legend()
ax.grid()
plt.show()

 49%|████▉     | 49/100 [00:27<00:23,  2.14it/s]