In [None]:
# This test is to be run in Google Colab using a GPU backend (tested on A100 runtime)

In [None]:
# Test whether the custom `cars_priors` Dataset returns the same images and camera transforms as the base SRN `cars` Dataset

In [None]:
# Clone necessary repos
!git clone --recursive https://github.com/Kacper-M-Michalik/splatter-image.git
!git clone --recursive https://github.com/graphdeco-inria/gaussian-splatting.git

In [None]:
# Get SRN cars dataset for cars test
%cd /content
!mkdir SRN
%cd /content/SRN
!mkdir srn_cars
%cd /content/SRN/srn_cars
!gdown --id 19yDsEJjx9zNpOKz9o6AaK-E8ED6taJWU -O cars.zip
!unzip cars.zip
%cd /content

In [None]:
# Imports
import sys
import os
import torch
from torch.utils.data import DataLoader
from omegaconf import OmegaConf

# Make the cloned splatter-image repo importable so get_dataset can be imported from it.
# Store an absolute path: a relative sys.path entry depends on the working directory,
# which a later cell could change. This cell runs with cwd=/content (set by the
# dataset-download cell), so this resolves to /content/splatter-image.
splatter_root = os.path.abspath("splatter-image")
if splatter_root not in sys.path:
    sys.path.append(splatter_root)
from splatter_datasets.dataset_factory import get_dataset

In [None]:
# Construct configs for the two datasets under comparison:
#   srn_cfg        -> base SRN dataset   (category "cars")
#   srn_priors_cfg -> custom dataset     (category "cars_priors")
# Both are built from the same default config with the same placeholder camera values,
# so the category is the only difference between them.


def make_srn_cfg(category, znear=1, zfar=2, fov=1):
    """Load splatter-image's default config and fill in the dataset fields.

    Args:
        category: value written to ``cfg.data.category`` (presumably selects which
            dataset class ``get_dataset`` builds — confirm in dataset_factory).
        znear, zfar, fov: dummy camera values written to ``cfg.data``. Their exact values
            do not matter for this test; what matters is that both configs get the same
            ones, so anything derived from them is identical for both datasets.

    Returns:
        A freshly loaded OmegaConf config; each call returns an independent object.
    """
    cfg = OmegaConf.load(os.path.join(splatter_root, "configs", "default_config.yaml"))
    cfg.data.category = category
    cfg.data.znear = znear
    cfg.data.zfar = zfar
    cfg.data.fov = fov
    return cfg


srn_cfg = make_srn_cfg("cars")
srn_priors_cfg = make_srn_cfg("cars_priors")

# to_yaml renders the nested config readably (print(cfg) gives a one-line dict repr).
print(OmegaConf.to_yaml(srn_cfg))
print(OmegaConf.to_yaml(srn_priors_cfg))

In [None]:
# Splits to compare; each split is built once from each config.
splits = ["test", "train", "val"]

# Batch entries produced by both datasets that must match exactly.
COMPARED_KEYS = (
    "gt_images",
    "world_view_transforms",
    "view_to_world_transforms",
    "full_proj_transforms",
    "camera_centers",
)


def find_first_mismatch(old_dataset, new_dataset, keys=COMPARED_KEYS, seed=0):
    """Compare two datasets item by item, in index order.

    Items are loaded through a DataLoader with batch_size=1 and shuffle=False, so both
    datasets are visited in the same order and batches are collated the same way.
    Every entry named in ``keys`` is compared with ``torch.equal``, which requires the
    same shape, dtype and values.

    Args:
        old_dataset: reference dataset.
        new_dataset: dataset under test; assumed to have ``len(old_dataset)`` items
            (the caller checks this first).
        keys: batch keys to compare.
        seed: base seed; item ``i`` is fetched from both datasets with the global torch
            RNG seeded to ``seed + i``.

    Returns:
        ``None`` if every item matches on every key, otherwise the ``(index, key)`` of
        the first mismatch.
    """
    iter_old = iter(DataLoader(old_dataset, batch_size=1, shuffle=False))
    iter_new = iter(DataLoader(new_dataset, batch_size=1, shuffle=False))
    for index in range(len(old_dataset)):
        # Reseed the global torch RNG before each fetch. If __getitem__ does any torch-based
        # random sampling (e.g. random view selection for the train split — confirm in the
        # dataset code), both datasets then draw identical numbers. Iterating the loaders
        # with zip() would interleave the two fetches on one RNG stream, so identical
        # datasets would report spurious mismatches. numpy / `random` RNGs are not reseeded.
        torch.manual_seed(seed + index)
        batch_old = next(iter_old)
        torch.manual_seed(seed + index)
        batch_new = next(iter_new)
        for key in keys:
            if not torch.equal(batch_old[key], batch_new[key]):
                return index, key
    return None


success = True

for split in splits:
    old_dataset = get_dataset(srn_cfg, split)
    new_dataset = get_dataset(srn_priors_cfg, split)

    # Both datasets must contain the same number of items before comparing contents.
    if len(old_dataset) != len(new_dataset):
        print("Mismatch in dataset length for {} split".format(split))
        success = False
        break

    mismatch = find_first_mismatch(old_dataset, new_dataset)
    if mismatch is not None:
        index, key = mismatch
        print("Found mismatched entry in datasets at split:{}, index:{}, key:{}".format(split, index, key))
        success = False
        break

    print("Completed {} split successfully".format(split))

if success:
    print("Test passed successfully, datasets are identical!")
else:
    print("Test failed, dataset classes are not identical!")