In [None]:
# This test is to be run in Google Colab using a GPU backend (tested on A100 runtime)

In [None]:
# Test that the custom SRN priors Dataset ("cars_priors") returns the same images and camera transforms as the base SRN cars Dataset

In [1]:
# Clone necessary repos
!git clone --recursive https://github.com/Kacper-M-Michalik/splatter-image.git
!git clone --recursive https://github.com/graphdeco-inria/gaussian-splatting.git

Cloning into 'splatter-image'...
remote: Enumerating objects: 391, done.[K
remote: Counting objects: 100% (206/206), done.[K
remote: Compressing objects: 100% (100/100), done.[K
remote: Total 391 (delta 149), reused 124 (delta 106), pack-reused 185 (from 2)[K
Receiving objects: 100% (391/391), 3.02 MiB | 3.38 MiB/s, done.
Resolving deltas: 100% (218/218), done.
Cloning into 'gaussian-splatting'...
remote: Enumerating objects: 1053, done.[K
remote: Total 1053 (delta 0), reused 0 (delta 0), pack-reused 1053 (from 1)[K
Receiving objects: 100% (1053/1053), 78.70 MiB | 16.16 MiB/s, done.
Resolving deltas: 100% (604/604), done.
Submodule 'SIBR_viewers' (https://gitlab.inria.fr/sibr/sibr_core.git) registered for path 'SIBR_viewers'
Submodule 'submodules/diff-gaussian-rasterization' (https://github.com/graphdeco-inria/diff-gaussian-rasterization.git) registered for path 'submodules/diff-gaussian-rasterization'
Submodule 'submodules/fused-ssim' (https://github.com/rahul-goel/fused-ssim.gi

In [2]:
# Get SRN cars dataset for cars test
%cd /content
!mkdir SRN
%cd /content/SRN
!mkdir srn_cars
%cd /content/SRN/srn_cars
!gdown --id 19yDsEJjx9zNpOKz9o6AaK-E8ED6taJWU -O cars.zip
!unzip cars.zip
%cd /content

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000032.txt  
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000015.txt  
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000092.txt  
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000121.txt  
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000192.txt  
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000200.txt  
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000074.txt  
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000142.txt  
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000166.txt  
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000018.txt  
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000077.txt  
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000033.txt  
  inflating: cars_val/87ee241d3d0d1dda4ff3c6764341833/pose/000224.txt  

In [3]:
# Imports
import sys
import os
import torch
from torch.utils.data import DataLoader
from omegaconf import OmegaConf

# Make the cloned splatter-image repo importable so get_dataset can be imported from it.
# The path is made absolute: a relative sys.path entry is re-resolved against the current
# working directory on every import, so a later `%cd` would break lazy imports from the repo.
# Assumes the working directory is the one the repo was cloned into (/content in Colab).
splatter_root = os.path.abspath("splatter-image")
if splatter_root not in sys.path:
    sys.path.append(splatter_root)
from splatter_datasets.dataset_factory import get_dataset

In [9]:
# Construct configs for the base SRN cars dataset and the custom priors dataset.
# Both are built identically from the repo's default config and differ only in data.category.
def make_srn_cfg(category):
    """Load splatter-image's default config and set the dataset category.

    Args:
        category: value for ``cfg.data.category`` (e.g. "cars" or "cars_priors").

    Returns:
        The loaded OmegaConf config with ``data.category`` set and placeholder camera values.
    """
    cfg = OmegaConf.load(os.path.join(splatter_root, "configs", "default_config.yaml"))
    cfg.data.category = category
    # Placeholder ("dummy") values that need to be filled in. They are identical for both
    # configs, so anything derived from them is computed the same way for both datasets.
    cfg.data.znear = 1
    cfg.data.zfar = 2
    cfg.data.fov = 1
    return cfg


srn_cfg = make_srn_cfg("cars")
srn_priors_cfg = make_srn_cfg("cars_priors")

print(OmegaConf.to_yaml(srn_cfg))
print(OmegaConf.to_yaml(srn_priors_cfg))

{'defaults': [{'wandb': 'defaults'}, {'hydra': 'defaults'}, {'cam_embd': 'defaults'}, '_self_'], 'general': {'device': 0, 'random_seed': 0, 'num_devices': 1, 'mixed_precision': False}, 'data': {'training_resolution': 128, 'subset': -1, 'input_images': 1, 'origin_distances': False, 'use_pred_depth': False, 'use_pred_normal': False, 'category': 'cars', 'znear': 1, 'zfar': 2, 'fov': 1}, 'opt': {'iterations': 800001, 'base_lr': 5e-05, 'batch_size': 8, 'betas': [0.9, 0.999], 'loss': 'l2', 'imgs_per_obj': 4, 'ema': {'use': True, 'update_every': 10, 'update_after_step': 100, 'beta': 0.9999}, 'lambda_lpips': 0.0, 'pretrained_ckpt': None, 'pretrained_hf': False, 'lora_finetune': False}, 'model': {'max_sh_degree': 1, 'inverted_x': False, 'inverted_y': True, 'name': 'SingleUNet', 'opacity_scale': 1.0, 'opacity_bias': -2.0, 'scale_bias': 0.02, 'scale_scale': 0.003, 'xyz_scale': 0.1, 'xyz_bias': 0.0, 'depth_scale': 1.0, 'depth_bias': 0.0, 'network_without_offset': False, 'network_with_offset': True

In [17]:
splits = ["test", "train", "val"]
# Keys present in both datasets' samples that must match exactly (torch.equal: same shape,
# dtype and values).
COMPARED_KEYS = [
    "gt_images",
    "world_view_transforms",
    "view_to_world_transforms",
    "full_proj_transforms",
    "camera_centers",
]


def first_mismatch(old_dataset, new_dataset, keys=COMPARED_KEYS):
    """Return ``(index, key)`` of the first sample where the two datasets differ, else None.

    Both DataLoaders use shuffle=False, so the i-th batch of each loader comes from the same
    dataset index (same uuid). The torch RNG is reseeded with the sample index right before
    each fetch so that, if a dataset samples views randomly, both datasets draw the same
    views. NOTE(review): this only helps if that randomness comes from torch's global RNG.
    """
    loader_old = DataLoader(old_dataset, batch_size=1, shuffle=False)
    loader_new = DataLoader(new_dataset, batch_size=1, shuffle=False)
    iter_old, iter_new = iter(loader_old), iter(loader_new)
    for index in range(len(old_dataset)):
        torch.manual_seed(index)
        batch_old = next(iter_old)
        torch.manual_seed(index)
        batch_new = next(iter_new)
        for key in keys:
            if not torch.equal(batch_old[key], batch_new[key]):
                return index, key
    return None


success = True
for split in splits:
    old_dataset = get_dataset(srn_cfg, split)
    new_dataset = get_dataset(srn_priors_cfg, split)

    # Ensures both datasets are of same length
    if len(old_dataset) != len(new_dataset):
        print("Mismatch in dataset length for {} split ({} vs {})".format(
            split, len(old_dataset), len(new_dataset)))
        success = False
        break

    # An empty split compares nothing, so make that visible instead of passing silently.
    if len(old_dataset) == 0:
        print("Warning: {} split is empty for both datasets, nothing compared".format(split))
        continue

    mismatch = first_mismatch(old_dataset, new_dataset)
    if mismatch is not None:
        index, key = mismatch
        print("Found mismatched entry in datasets at split:{}, index:{}, key:{}".format(
            split, index, key))
        success = False
        break

if success:
    print("Test passed successfully, datasets are identical!")
else:
    # Raise so Restart & Run All (or nbval-style runners) surface the failure loudly.
    raise AssertionError("Test failed, dataset classes are not identical!")

704
Started downloading datasets
Downloaded datasets
Converted poses
Converted rgbs
Converted depths
Converted normals
Dataset intrin length: 704
352
Started downloading datasets
Downloaded datasets
Converted poses
Converted rgbs
Converted depths
Converted normals
Dataset intrin length: 352
Test passed successfully, datasets are identical!
