In [1]:
import os
import sys
sys.path.append("/home/yang_liu/python_workspace/gaussian-splatting-lightning")

from matplotlib import pyplot as plt

import torch
import math
import random
import numpy as np
from gsplat import project_gaussians
from gsplat.rasterize import rasterize_gaussians

from internal.utils.gaussian_model_loader import GaussianModelLoader
from internal.dataparsers.colmap_dataparser import ColmapParams, ColmapDataParser
from internal.dataparsers.colmap_joint_dataparser import ColmapJointDataParser

torch.set_grad_enabled(False)

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [2]:
# Load the trained Gaussian model together with its renderer.
model_output_dir = "../outputs/mc_fuse_w"
model, renderer = GaussianModelLoader.search_and_load(
    model_output_dir,
    sh_degree=3,
    device="cuda",
)
gaussian_count = model.get_xyz.shape[0]
print("Gaussian count: {}".format(gaussian_count))

# Parse the joint COLMAP dataset. With eval_image_select_mode="ratio" and
# eval_ratio=1.0, the parser log reports identical train/val image counts,
# i.e. every image is also part of the evaluation split.
colmap_params = ColmapParams(
    eval_image_select_mode="ratio",
    eval_ratio=1.0,
)
# The second positional argument is the current working directory; its role is
# defined by ColmapJointDataParser (presumably an output path — confirm).
joint_parser = ColmapJointDataParser(
    os.path.expanduser("../data/matrix_city/aerial_street_fusion/fuse/train"),
    os.path.abspath(""),
    global_rank=0,
    params=colmap_params,
)
dataparser_outputs = joint_parser.get_outputs()
print("Test camera count: {}".format(len(dataparser_outputs.test_set.cameras)))

Gaussian count: 1265420
appearance group by camera id
loading colmap 3D points
[colmap_joint dataparser] train set images: 453, val set images: 453, loaded mask: 0
Test camera count: 453


In [3]:
# Build a training dataloader whose sampler draws images so that every camera
# type receives the same total sampling probability (class-balanced sampling).
from internal.dataset import Dataset, CacheDataLoader

# `camera.camera_type` of every training image, in dataset order.
# NOTE: despite the variable name these are camera *types*, not camera ids.
camera_ids = [camera.camera_type for _, _, _, camera in dataparser_outputs.train_set]
camera_ids_tensor = torch.tensor(camera_ids)

# `return_inverse` gives, for each sample, the *position* of its type in the
# sorted unique list. Indexing `weight` by that position (rather than by the raw
# type value) stays correct when type values are not contiguous 0..K-1; indexing
# by the raw value would pick the wrong class or go out of range in that case.
_, camera_type_index, class_sample_count = torch.unique(
    camera_ids_tensor, sorted=True, return_inverse=True, return_counts=True
)
# Inverse-frequency weights: each type gets equal total probability mass.
weight = 1. / class_sample_count.float()
samples_weight = weight[camera_type_index]
# WeightedRandomSampler samples with replacement by default, so one pass of
# len(samples_weight) draws may repeat some images and skip others.
sampler = torch.utils.data.WeightedRandomSampler(samples_weight, len(samples_weight))

dataloader = CacheDataLoader(
    Dataset(dataparser_outputs.train_set, undistort_image=False),
    max_cache_num=-1,  # the cell output reports "cache all images" for this value
    shuffle=None,  # ordering is expected to come from `sampler` — confirm in CacheDataLoader
    sampler=sampler,
    seed=torch.initial_seed(),  # torch's process-level initial seed (not the global rank)
    num_workers=8,
    distributed=False,
    world_size=1,
    global_rank=0,
)

cache all images


#46423 caching images (1st: 0):   0%|          | 0/453 [00:00<?, ?it/s]

#46423 caching images (1st: 0): 100%|██████████| 453/453 [00:14<00:00, 30.36it/s]


In [4]:
# Check the camera-type balance produced by one pass over the weighted sampler.
# NOTE: `list(sampler)` draws a fresh set of indices; it is not the same draw the
# dataloader makes when it iterates `sampler` itself.
cam_samples = list(sampler)
cam_ids_sampled = camera_ids_tensor[torch.as_tensor(cam_samples, dtype=torch.long)].numpy()

# Report every camera type present in the training set; types never drawn show 0.
for cam_type in torch.unique(camera_ids_tensor, sorted=True).tolist():
    print("Camera {} count: {}".format(cam_type, (cam_ids_sampled == cam_type).sum()))
print(cam_ids_sampled)

Camera 0 count: 229
Camera 1 count: 224
[0 1 1 0 1 1 0 0 0 1 1 1 1 1 1 0 1 1 1 0 0 1 0 0 1 0 0 1 0 1 0 0 1 1 0 1 0
 0 1 1 1 1 0 1 0 0 1 0 1 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 1 0 0 1 1 0 0 1 1 1
 1 0 1 1 0 1 0 0 0 1 1 1 0 0 0 1 0 1 1 1 1 0 0 1 0 0 0 0 1 0 0 0 0 1 0 0 0
 1 0 1 1 0 0 0 0 1 1 0 1 1 0 0 1 1 1 0 0 1 1 0 0 0 0 0 1 1 1 0 1 0 1 0 0 1
 0 1 0 0 1 0 0 0 1 1 0 1 1 1 0 1 1 0 1 1 0 1 1 0 0 0 0 0 1 1 1 1 1 0 1 0 0
 1 0 1 1 0 0 1 0 0 1 1 0 0 1 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 1 1 1 1 0 1
 0 1 0 1 0 1 1 1 0 0 0 1 1 1 1 1 0 0 0 0 0 1 1 0 0 1 0 0 1 1 1 1 1 0 1 0 0
 1 1 0 1 1 1 1 1 1 1 0 0 1 1 0 0 1 1 1 1 1 1 1 0 1 1 0 1 1 1 0 1 1 1 0 1 0
 0 0 0 1 0 0 1 0 1 0 0 0 1 0 1 1 1 0 0 0 0 0 0 0 1 0 1 1 1 0 1 1 0 0 1 1 0
 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 1 0 0 0 1 1 1 0 0 0 0 1 1 1 1 0 0 0 0
 1 1 1 1 0 0 1 1 1 0 1 1 0 1 0 1 1 1 1 1 0 0 0 0 1 1 0 1 0 0 1 0 0 1 0 1 1
 1 0 0 1 1 0 0 0 0 1 0 1 1 1 1 1 0 0 0 1 1 1 0 1 0 0 0 0 1 0 0 1 1 0 1 0 1
 1 0 1 1 0 0 1 1 0]
