In [None]:
import src.dataloader
import torch
import pathlib

# All modules and tensors in this notebook are placed on the CUDA device.
device = torch.device("cuda")

# NOTE(review): machine-specific absolute path -- consider moving it to a
# configurable data directory so the notebook runs on other machines.
dataset_root = pathlib.Path("C:/Users/imeho/Documents/DataSets/InstantPINF/ScalarReal")
dataloader = src.dataloader.load_train_data(dataset_root, "train", device=device)

# Take the first batch the loader yields; the later cells reuse it.
image_batch = next(iter(dataloader))

# Report where the image tensor lives and how many MiB its storage occupies
# (element size in bytes times element count).
image_tensor = image_batch["image"]
print(f'image device: {image_tensor.device}')
memory_image_cuda = image_tensor.numel() * image_tensor.element_size()
print(f'Memory of image: {memory_image_cuda / 1024 / 1024:.2f} MB')

In [None]:
from nerfstudio.model_components.ray_samplers import UniformSampler
from nerfstudio.model_components.scene_colliders import NearFarCollider

# Near/far planes handed to the collider.
# NOTE(review): 1.1 / 1.5 look scene-specific -- confirm against the dataset.
near_far_bounds = {"near_plane": 1.1, "far_plane": 1.5}
collider = NearFarCollider(**near_far_bounds).to(device)

# Uniform sampler configured for 192 samples.
sampler_uniform = UniformSampler(num_samples=192).to(device)

In [None]:
import src.encoder
import taichi as ti
import numpy as np

# Initialise the Taichi runtime on the CUDA backend.
ti.init(arch=ti.cuda)

# Hash encoder settings. min_res / max_res each carry four entries,
# presumably one per (x, y, z, t) axis -- TODO confirm against src.encoder.
encoder_settings = dict(
    min_res=np.array([16, 16, 16, 16]),
    max_res=np.array([256, 256, 256, 128]),
    num_scales=16,
    max_params=2 ** 19,
)
xyzt_encoder = src.encoder.HashEncoderHyFluid(**encoder_settings)
# Kept as the trailing bare expression so the cell still displays its result.
xyzt_encoder.to(device)

In [None]:
from nerfstudio.field_components.mlp import MLP
import src.radam

# Small MLP fed with the encoder output (num_scales * features_per_level
# inputs); it emits one non-negative value per sample through the ReLU output
# activation.
mlp_base = MLP(
    in_dim=xyzt_encoder.num_scales * xyzt_encoder.features_per_level,
    num_layers=2,
    layer_width=64,
    out_dim=1,
    out_activation=torch.nn.ReLU(),
).to(device)

# Single learnable scalar; the training loop maps it through tanh to a colour.
learned_rgb = torch.nn.Parameter(torch.tensor([0.0], device=device))

# Two parameter groups: network weights (plus the colour scalar) and the
# encoder parameters.
grad_vars = [*mlp_base.parameters(), learned_rgb]
embedding_params = [*xyzt_encoder.parameters()]

network_group = {'params': grad_vars, 'weight_decay': 1e-6}
embedding_group = {'params': embedding_params, 'eps': 1e-15}
optimizer = src.radam.RAdam([network_group, embedding_group], lr=0.01, betas=(0.9, 0.99))

In [None]:
def raw2alpha(raw, dists, act_fn=torch.nn.functional.relu):
    """Turn raw network outputs into per-sample opacities.

    Computes ``1 - exp(-act_fn(raw) * dists)`` element-wise.

    Args:
        raw: tensor of raw network outputs.
        dists: tensor broadcastable with ``raw``; multiplied with the
            activated output inside the exponent.
        act_fn: activation applied to ``raw`` first (ReLU by default).

    Returns:
        Tensor with the broadcast shape of ``raw`` and ``dists``.
    """
    activated = act_fn(raw)
    return 1. - torch.exp(-activated * dists)

In [None]:
import nerfstudio.data.pixel_samplers
import nerfstudio.model_components.ray_generators

# Pixel sampler configured for 1024 rays per batch.
pixel_sampler_config = nerfstudio.data.pixel_samplers.PixelSamplerConfig(num_rays_per_batch=1024)
ps = pixel_sampler_config.setup()

# Ray generator built from the cameras attached to the training dataset.
training_cameras = dataloader.dataset.cameras
rg = nerfstudio.model_components.ray_generators.RayGenerator(training_cameras).to(device)

In [None]:
def image_idx_to_frame(image_indices, all_frames):
    """Convert flat image indices into frame numbers relative to their video.

    ``all_frames`` is read as a sequence of frame counts, one per video, with
    the videos laid out back to back in the flat index space (presumably the
    layout the dataset uses -- confirm against the loader).  Each index is
    mapped to the last video whose start offset is ``<=`` the index, and that
    offset is subtracted.

    The lookup is one vectorised ``searchsorted`` call, so no per-element
    ``.item()`` (and therefore no per-element host/device sync) is needed.

    Args:
        image_indices: integer tensor of flat image indices (any shape).
        all_frames: sequence of ints, the number of frames in each video.

    Returns:
        int64 tensor with the same shape and device as ``image_indices``
        holding the frame number of each index inside its own video.
    """
    device = image_indices.device
    # Start offset of every video: [0, n0, n0 + n1, ...].
    video_starts = torch.cumsum(
        torch.tensor([0] + list(all_frames[:-1]), device=device), dim=0
    )
    # searchsorted needs matching dtypes and prefers contiguous input; the
    # caller passes a strided column slice (``indices[:, 0]``).
    flat_idx = image_indices.to(dtype=video_starts.dtype).contiguous()
    # right=True makes an index equal to a start offset belong to that video.
    video_idx = torch.searchsorted(video_starts, flat_idx, right=True) - 1
    return flat_idx - video_starts[video_idx]

In [None]:
import nerfstudio.model_components.losses
import tqdm
import time
import matplotlib.pyplot as plt

# MSE loss between the rendered colour and the target pixels.
rgb_loss = nerfstudio.model_components.losses.MSELoss()

# Wall-clock seconds spent in each phase, one entry per iteration.
encoding_times = []
loss_calc_times = []
optimizer_times = []


def _synced_time():
    """Return a monotonic timestamp taken after queued GPU work has finished.

    CUDA kernels are launched asynchronously.  Reading the clock without a
    synchronisation only measures launch overhead, and the real GPU time is
    billed to whichever later statement happens to block, which makes the
    per-phase numbers misleading.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Also drain kernels launched through Taichi (initialised earlier in the
    # notebook), in case the encoder runs any.
    ti.sync()
    return time.perf_counter()


# Main loop
for i in tqdm.tqdm(range(100)):
    # ---- Phase 1: ray sampling + (x, y, z, t) encoding ----
    start = _synced_time()
    batch = ps.sample(image_batch)
    ray_bundle = rg(batch['indices'])
    # Return value is discarded, so this relies on the collider updating
    # ray_bundle in place -- NOTE(review): confirm.
    collider(ray_bundle)
    ray_samples_uniform = sampler_uniform(ray_bundle)
    positions = ray_samples_uniform.frustums.get_positions()

    # Column 0 of the sampled indices is used as the flat image index.
    frames = image_idx_to_frame(image_indices=batch['indices'][:, 0],
                                all_frames=dataloader.dataset.metadata['all_frames'])
    # Repeat each ray's frame number for every sample along that ray and
    # append it as a fourth coordinate.
    frames_expanded = frames.view(positions.shape[0], 1, 1).expand(-1, positions.shape[1], -1)
    xyzt = torch.cat((positions, frames_expanded), dim=-1)
    xyzt_flat = xyzt.reshape(-1, 4)
    xyzt_encoded = xyzt_encoder(xyzt_flat)
    end = _synced_time()
    encoding_times.append(end - start)  # encoding time

    # ---- Phase 2: MLP forward, compositing along each ray, loss ----
    start = _synced_time()
    raw_flat = mlp_base(xyzt_encoded)
    raw = raw_flat.reshape(xyzt.shape[0], xyzt.shape[1], raw_flat.shape[-1])
    dists = ray_samples_uniform.deltas
    # One colour shared by all samples: 0.6 + 0.4 * tanh(learned_rgb),
    # i.e. bounded to (0.2, 1.0), replicated over three channels.
    rgb = torch.ones(3, device=device) * (0.6 + torch.tanh(learned_rgb) * 0.4)
    alpha = raw2alpha(raw[..., -1], dists[..., -1])
    # Exclusive cumulative product of (1 - alpha) along the ray; the leading
    # column of ones and the final [:, :-1] shift it by one sample.
    leading_ones = torch.ones((alpha.shape[0], 1), device=device)
    transmittance = torch.cumprod(torch.cat([leading_ones, 1. - alpha + 1e-10], -1), -1)[:, :-1]
    weights = alpha * transmittance
    rgb_map = torch.sum(weights[..., None] * rgb, -2)
    loss = rgb_loss(rgb_map, batch['image'])
    end = _synced_time()
    loss_calc_times.append(end - start)  # loss computation time

    # ---- Phase 3: backward pass + optimizer step ----
    start = _synced_time()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Exponential decay: x0.1 every 250 iterations.
    # NOTE(review): the optimizer is constructed with lr=0.01, so the first
    # step runs at 0.01 and every later one at <= 5e-4 -- confirm which base
    # rate is intended.
    new_lrate = 5e-4 * (0.1 ** (i / 250))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lrate

    end = _synced_time()
    optimizer_times.append(end - start)  # optimisation time

# Line plot of the per-iteration timings
plt.figure(figsize=(10, 6))
plt.plot(encoding_times, label='Encoding Time')
plt.plot(loss_calc_times, label='Loss Calculation Time')
plt.plot(optimizer_times, label='Optimizer Time')
plt.xlabel('Iteration')
plt.ylabel('Time (seconds)')
plt.title('Time Analysis per Operation')
plt.legend()
plt.grid(True)
plt.show()

  0%|          | 0/100 [00:00<?, ?it/s]

Encoding took 0.23 seconds


  1%|          | 1/100 [00:00<00:56,  1.77it/s]

Loss calculation took 0.34 seconds
optimizer 0 took 0.00 seconds


  2%|▏         | 2/100 [00:14<13:56,  8.54s/it]

Encoding took 14.07 seconds
Loss calculation took 0.03 seconds
optimizer 1 took 0.02 seconds


  3%|▎         | 3/100 [00:28<17:49, 11.02s/it]

Encoding took 13.93 seconds
Loss calculation took 0.03 seconds
optimizer 2 took 0.02 seconds


  4%|▍         | 4/100 [00:42<19:14, 12.02s/it]

Encoding took 13.52 seconds
Loss calculation took 0.03 seconds
optimizer 3 took 0.00 seconds


  5%|▌         | 5/100 [00:55<19:58, 12.62s/it]

Encoding took 13.63 seconds
Loss calculation took 0.03 seconds
optimizer 4 took 0.02 seconds


  6%|▌         | 6/100 [01:09<20:24, 13.02s/it]

Encoding took 13.78 seconds
Loss calculation took 0.03 seconds
optimizer 5 took 0.00 seconds


  7%|▋         | 7/100 [01:23<20:34, 13.27s/it]

Encoding took 13.75 seconds
Loss calculation took 0.03 seconds
optimizer 6 took 0.00 seconds


  8%|▊         | 8/100 [01:36<20:27, 13.35s/it]

Encoding took 13.46 seconds
Loss calculation took 0.03 seconds
optimizer 7 took 0.02 seconds


  9%|▉         | 9/100 [01:51<20:36, 13.58s/it]

Encoding took 14.05 seconds
Loss calculation took 0.03 seconds
optimizer 8 took 0.02 seconds


 10%|█         | 10/100 [02:04<20:25, 13.62s/it]

Encoding took 13.65 seconds
Loss calculation took 0.03 seconds
optimizer 9 took 0.02 seconds


 11%|█         | 11/100 [02:18<20:13, 13.63s/it]

Encoding took 13.64 seconds
Loss calculation took 0.03 seconds
optimizer 10 took 0.00 seconds


 12%|█▏        | 12/100 [02:32<19:57, 13.61s/it]

Encoding took 13.51 seconds
Loss calculation took 0.03 seconds
optimizer 11 took 0.00 seconds


 13%|█▎        | 13/100 [02:45<19:39, 13.56s/it]

Encoding took 13.41 seconds
Loss calculation took 0.03 seconds
optimizer 12 took 0.00 seconds


 14%|█▍        | 14/100 [02:58<19:24, 13.54s/it]

Encoding took 13.46 seconds
Loss calculation took 0.03 seconds
optimizer 13 took 0.02 seconds


 15%|█▌        | 15/100 [03:12<19:12, 13.56s/it]

Encoding took 13.57 seconds
Loss calculation took 0.03 seconds
optimizer 14 took 0.00 seconds


 16%|█▌        | 16/100 [03:26<19:01, 13.58s/it]

Encoding took 13.61 seconds
Loss calculation took 0.03 seconds
optimizer 15 took 0.00 seconds


 17%|█▋        | 17/100 [03:39<18:46, 13.57s/it]

Encoding took 13.52 seconds
Loss calculation took 0.03 seconds
optimizer 16 took 0.00 seconds


 18%|█▊        | 18/100 [03:53<18:41, 13.68s/it]

Encoding took 13.79 seconds
Loss calculation took 0.03 seconds
optimizer 17 took 0.09 seconds


 19%|█▉        | 19/100 [04:07<18:33, 13.75s/it]

Encoding took 13.87 seconds
Loss calculation took 0.04 seconds
optimizer 18 took 0.01 seconds


 20%|██        | 20/100 [04:21<18:25, 13.82s/it]

Encoding took 13.96 seconds
Loss calculation took 0.03 seconds
optimizer 19 took 0.00 seconds


 21%|██        | 21/100 [04:35<18:11, 13.81s/it]

Encoding took 13.76 seconds
Loss calculation took 0.03 seconds
optimizer 20 took 0.00 seconds


 22%|██▏       | 22/100 [04:49<17:56, 13.81s/it]

Encoding took 13.76 seconds
Loss calculation took 0.03 seconds
optimizer 21 took 0.00 seconds


 23%|██▎       | 23/100 [05:02<17:40, 13.77s/it]

Encoding took 13.67 seconds
Loss calculation took 0.03 seconds
optimizer 22 took 0.00 seconds


 24%|██▍       | 24/100 [05:16<17:28, 13.80s/it]

Encoding took 13.78 seconds
Loss calculation took 0.03 seconds
optimizer 23 took 0.06 seconds


 25%|██▌       | 25/100 [05:30<17:05, 13.68s/it]

Encoding took 13.36 seconds
Loss calculation took 0.03 seconds
optimizer 24 took 0.00 seconds


 26%|██▌       | 26/100 [05:44<16:57, 13.75s/it]

Encoding took 13.89 seconds
Loss calculation took 0.03 seconds
optimizer 25 took 0.00 seconds


 27%|██▋       | 27/100 [05:57<16:39, 13.70s/it]

Encoding took 13.52 seconds
Loss calculation took 0.03 seconds
optimizer 26 took 0.02 seconds


 28%|██▊       | 28/100 [06:11<16:25, 13.69s/it]

Encoding took 13.64 seconds
Loss calculation took 0.03 seconds
optimizer 27 took 0.02 seconds


 29%|██▉       | 29/100 [06:25<16:20, 13.82s/it]

Encoding took 14.02 seconds
Loss calculation took 0.03 seconds
optimizer 28 took 0.05 seconds


 30%|███       | 30/100 [06:39<16:08, 13.83s/it]

Encoding took 13.82 seconds
Loss calculation took 0.05 seconds
optimizer 29 took 0.00 seconds


 31%|███       | 31/100 [06:53<15:55, 13.85s/it]

Encoding took 13.88 seconds
Loss calculation took 0.03 seconds
optimizer 30 took 0.00 seconds
