In [None]:
import train
import torch
import pathlib

# Single compute device for this notebook; later cells move modules/tensors onto it.
device = torch.device("cuda")

# Build the training data manager on `device` (reusing the variable instead of
# constructing a second torch.device keeps the two from drifting apart).
# NOTE(review): data_path is an absolute, machine-specific location — adjust it
# when running on another machine.
manager = train.load_data(device,
                          data_path=pathlib.Path("C:/Users/imeho/Documents/DataSets/InstantPINF/ScalarReal"))

In [None]:
from nerfstudio.model_components.ray_samplers import UniformSampler
from nerfstudio.model_components.scene_colliders import NearFarCollider

# Ray sampling setup: a UniformSampler configured for 192 samples per ray and
# a near/far collider with planes at 1.1 and 1.5.
sampler_uniform = UniformSampler(num_samples=192)
collider = NearFarCollider(
    far_plane=1.5,
    near_plane=1.1,
)

In [None]:
import src.encoder as encoder
import taichi as ti
import numpy as np

# Initialise Taichi on the CUDA backend, then build the hash encoder.
ti.init(arch=ti.cuda)

# Four-component resolutions: 16 -> 256 for the first three entries and
# 16 -> 128 for the last one, spread over 16 scales.
xyzt_encoder = encoder.HashEncoderHyFluid(
    num_scales=16,
    max_params=1 << 19,  # == 2 ** 19
    min_res=np.array([16] * 4),
    max_res=np.array([256] * 3 + [128]),
)
xyzt_encoder.to(device)

In [None]:
from nerfstudio.field_components.mlp import MLP
import src.radam as radam

# MLP head: consumes the concatenated per-scale encoder features and emits a
# single ReLU-activated (hence non-negative) value per input row.
mlp_base = MLP(
    in_dim=xyzt_encoder.num_scales * xyzt_encoder.features_per_level,
    num_layers=2,
    layer_width=64,
    out_dim=1,
    out_activation=torch.nn.ReLU(),
)
mlp_base.to(device)

# One learnable scalar; the training loop maps it through tanh to a grey level.
learned_rgb = torch.nn.Parameter(torch.tensor([0.0], device=device))

# Two parameter groups: network weights (+ learned colour) get weight decay,
# the encoder's parameters get a smaller eps instead.
grad_vars = [*mlp_base.parameters(), learned_rgb]
embedding_params = [*xyzt_encoder.parameters()]

optimizer = radam.RAdam(
    [
        dict(params=grad_vars, weight_decay=1e-6),
        dict(params=embedding_params, eps=1e-15),
    ],
    lr=0.01,
    betas=(0.9, 0.99),
)

In [None]:
def raw2alpha(raw, dists, act_fn=torch.nn.functional.relu):
    """Convert raw density values into per-sample alpha (opacity).

    Evaluates ``1 - exp(-act_fn(raw) * dists)`` element-wise.

    Args:
        raw: Tensor of raw (pre-activation) density values.
        dists: Tensor broadcastable against ``raw``; it scales the activated density.
        act_fn: Activation applied to ``raw`` before use. Defaults to ReLU.

    Returns:
        Tensor of alpha values with the broadcast shape of ``raw`` and ``dists``;
        values lie in [0, 1] whenever ``act_fn(raw) * dists`` is non-negative.
    """
    # -expm1(-x) is algebraically identical to 1 - exp(-x) but does not suffer
    # catastrophic cancellation when x is tiny (thin samples / low density),
    # where the naive form rounds to exactly 0 in float32.
    return -torch.expm1(-act_fn(raw) * dists)

In [None]:
from nerfstudio.model_components.losses import MSELoss
import tqdm
import matplotlib.pyplot as plt

import src.datamanager as datamanager

# Training hyper-parameters (same values the loop previously hard-coded inline).
NUM_ITERATIONS = 100
LR_BASE = 5e-4        # learning rate of the schedule at iteration 0
LR_DECAY_RATE = 0.1   # factor applied to the rate every LR_DECAY_STEPS iterations
LR_DECAY_STEPS = 250

loss_history = []
# The criterion does not depend on the iteration, so build it once up front
# instead of re-instantiating it on every pass through the loop.
rgb_loss = MSELoss()

for i in tqdm.tqdm(range(NUM_ITERATIONS)):
    # Fetch a batch of rays with their target pixels; the collider's return
    # value is unused here (presumably it fills near/far on the bundle in
    # place — TODO confirm), then sample points along each ray.
    ray_bundle, batch = manager.next_train()
    collider(ray_bundle)
    ray_samples_uniform = sampler_uniform(ray_bundle)
    positions = ray_samples_uniform.frustums.get_positions()

    # Look up one frame value per ray from the first column of the pixel
    # indices and broadcast it to every sample on that ray, giving 4-component
    # (position + frame) inputs.
    frames = datamanager.image_idx_to_frame(
        image_indices=batch['indices'][:, 0],
        all_frames=manager.train_dataset.metadata['all_frames'])
    frames_expanded = frames.to(device).view(positions.shape[0], 1, 1).expand(-1, positions.shape[1], -1)
    xyzt = torch.cat((positions, frames_expanded), dim=-1)
    xyzt_flat = xyzt.reshape(-1, 4)
    xyzt_encoded = xyzt_encoder(xyzt_flat)

    # Network output, restored to (rays, samples, channels).
    raw_flat = mlp_base(xyzt_encoded)
    raw = raw_flat.reshape(xyzt.shape[0], xyzt.shape[1], raw_flat.shape[-1])

    # Volume rendering with one shared grey colour in [0.2, 1.0]:
    # weights = alpha_i * prod_{j < i} (1 - alpha_j).
    dists = ray_samples_uniform.deltas
    rgb = torch.ones(3, device=device) * (0.6 + torch.tanh(learned_rgb) * 0.4)
    alpha = raw2alpha(raw[..., -1], dists[..., -1])
    # Prepending a column of ones and dropping the last column turns the
    # inclusive cumprod into an exclusive one; 1e-10 keeps the product from
    # collapsing to exactly zero.
    leading_ones = torch.ones((alpha.shape[0], 1), device=device)
    transmittance = torch.cumprod(torch.cat([leading_ones, 1. - alpha + 1e-10], -1), -1)[:, :-1]
    weights = alpha * transmittance
    rgb_map = torch.sum(weights[..., None] * rgb, -2)

    image = batch['image'].to(device)
    loss = rgb_loss(rgb_map, image)
    loss_history.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Exponential learning-rate decay, applied after the step.
    # NOTE(review): the optimizer is constructed with lr=0.01, so the very
    # first step runs at 0.01 before this schedule (base 5e-4) takes over —
    # confirm that this is intended.
    new_lrate = LR_BASE * (LR_DECAY_RATE ** (i / LR_DECAY_STEPS))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lrate

# Draw the recorded per-iteration training loss on a single labelled axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(loss_history, label="Loss")
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.set_title("Loss over Training Iterations")
ax.legend()
ax.grid()
plt.show()

In [1]:
import src.dataloader as dl
import torch
import pathlib

# Load the "train" split onto the GPU.
# NOTE(review): the dataset path is absolute and machine-specific — adjust it
# when running on another machine.
dataloader = dl.load_train_data(pathlib.Path("C:/Users/imeho/Documents/DataSets/InstantPINF/ScalarReal"), "train",
                                device=torch.device("cuda"))
# Deliberately NOT named `iter`: rebinding the builtin makes `iter(dataloader)`
# raise a TypeError the next time this cell is run in the same kernel.
dataloader_iter = iter(dataloader)
image_batch = next(dataloader_iter)

# Report where the image tensor lives and how much memory it occupies.
print(f'image device: {image_batch["image"].device}')
memory_image_cuda = image_batch['image'].element_size() * image_batch['image'].numel()
print(f'Memory of image: {memory_image_cuda / 1024 / 1024:.2f} MB')

Output()

image device: cuda:0
Memory of image: 11390.62 MB


In [10]:
import nerfstudio.data.pixel_samplers

# Build a pixel sampler configured for 1024 rays per batch and draw one
# sample from the loaded image batch.
ps = nerfstudio.data.pixel_samplers.PixelSamplerConfig(
    num_rays_per_batch=1024,
).setup()
batch = ps.sample(image_batch)

# Show the available keys, then the sampled pixel values and their indices.
print(batch.keys(), batch['image'], batch['indices'], sep='\n')

dict_keys(['image', 'indices'])
tensor([[0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000],
        [0.1020, 0.1020, 0.1020],
        [0.0000, 0.0000, 0.0000]], device='cuda:0')
tensor([[ 146, 1711,  750],
        [ 344,  173,  378],
        [ 316, 1509,  230],
        ...,
        [ 458,  552,  430],
        [ 174, 1376,  639],
        [ 385, 1550,  891]], device='cuda:0')


In [14]:
import nerfstudio.model_components.ray_generators

# Move the dataset's cameras to CUDA, wrap them in a RayGenerator (also moved
# to CUDA), and turn the sampled pixel indices into a ray bundle.
cameras_on_gpu = dataloader.dataset.cameras.to(torch.device("cuda"))
rg = nerfstudio.model_components.ray_generators.RayGenerator(cameras_on_gpu)
rg = rg.to(torch.device("cuda"))
ray_bundle = rg(batch['indices'])

cuda:0


In [22]:
import nerfstudio.data.utils.dataloaders as dls

# Evaluation-style loader over the same dataset; each `next` yields a camera
# together with its batch.
# NOTE: this rebinds `batch`, which previously held the pixel-sampler output.
fied = dls.FixedIndicesEvalDataloader(
    input_dataset=dataloader.dataset,
    device=torch.device("cuda"),
)
camera, batch = next(fied)

# Inspect the camera, the batch keys, the image index and the image shape.
print(camera, batch.keys(), batch['image_idx'], batch['image'].shape, sep='\n')

Cameras(camera_to_worlds=tensor([[[ 0.4863, -0.2431, -0.8393, -0.7697],
         [-0.0189,  0.9574, -0.2882,  0.0132],
         [ 0.8736,  0.1560,  0.4610,  0.3250]]], device='cuda:0'), fx=tensor([[2613.7634]], device='cuda:0', dtype=torch.float64), fy=tensor([[2613.7634]], device='cuda:0', dtype=torch.float64), cx=tensor([[540.]], device='cuda:0', dtype=torch.float64), cy=tensor([[960.]], device='cuda:0', dtype=torch.float64), width=tensor([[1080]], device='cuda:0'), height=tensor([[1920]], device='cuda:0'), distortion_params=None, camera_type=tensor([[1]], device='cuda:0'), times=None, metadata=None)
dict_keys(['image_idx', 'image'])
0
torch.Size([1920, 1080, 3])


In [28]:
# Check which device the eval dataloader's cameras are stored on.
cameras_device = fied.cameras.device
print(cameras_device)

cuda:0


In [26]:
# Generate rays for camera index 0, keeping the image-shaped layout.
camera_ray_bundle = camera.generate_rays(camera_indices=0, keep_shape=True)

# The first two dimensions of the origins tensor are image height and width.
image_height, image_width, *_ = camera_ray_bundle.origins.shape
print(image_height, image_width)

num_rays = len(camera_ray_bundle)
print(f'Number of rays: {num_rays}')
print(camera_ray_bundle.origins.device)

# Walk the bundle in row-major chunks of 1024 rays. Each slice overwrites
# `ray_bundle`; nothing else is done with it inside the loop.
for i in range(0, num_rays, 1024):
    start_idx, end_idx = i, i + 1024
    ray_bundle = camera_ray_bundle.get_row_major_sliced_ray_bundle(start_idx, end_idx)

1920 1080
Number of rays: 2073600
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0


In [2]:
# Call the project's `create_test` helper on the loaded dataset with CUDA as the device.
# NOTE(review): what `create_test` builds or returns is not visible from this
# notebook — confirm against src/dataloader.
dl.create_test(dataloader.dataset, torch.device("cuda"))

Number of rays: 2073600
