In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import Dataset
import time

from nerfstudio.models.nerfacto import NerfactoModelConfig
from nerfstudio.data.scene_box import SceneBox
from random import randrange

  warn(f"Failed to load image Python extension: {e}")


In [2]:
class RandomDataset(Dataset):
    def __init__(self, num_samples, input_size, target_size, id=0):
        self.num_samples = num_samples
        self.input_size = input_size
        self.target_size = target_size
        self.id = id
        self.count = 0

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        print("Self.id: ", self.id, " - ", self.count, " - ", randrange(50))
        self.count += 1
        # print("Get item called")
        time.sleep(2)
        input_data = torch.randn(self.input_size)
        target_data = torch.randn(self.target_size)
        return str(self.id)

# Usage example
num_samples = 1000
input_size = (3, 224, 224)
target_size = (10,)

random_dataset = RandomDataset(num_samples, input_size, target_size)

datasets = [RandomDataset(num_samples, input_size, target_size, id=i) for i in range(10)]

print(random_dataset[0])

Self.id:  0  -  0  -  35
0


In [3]:
class IterableDataset(Dataset):
    def __init__(self, iter_datasets):
        self.iter_datasets = iter_datasets
        
    def __len__(self):
        return len(self.iter_datasets)
    
    def __getitem__(self, index):
        return next(self.iter_datasets[index])
    

In [4]:
class ContinuousDataLoader:
    def __init__(self, dataloader):
        self.data_loader = dataloader
        self.data_iter = iter(self.data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.data_loader)
            return next(self.data_iter)

In [13]:
iter_datasets = [iter(dataset) for dataset in datasets]
meta_dataset = IterableDataset(iter_datasets)
meta_dataloader = torch.utils.data.DataLoader(meta_dataset, batch_size=1, shuffle=False, num_workers=5, prefetch_factor=2)
from itertools import cycle

inifinite_meta_dataloader = cycle(meta_dataloader)

while True:
    time1 = time.time()
    print(next(inifinite_meta_dataloader))
    time2 = time.time()
    print("Took ", time2 - time1, " seconds")
    time.sleep(1)

Self.id: Self.id: Self.id: Self.id:     023  1  -   -   -   -   2828 2628   -  29  -  -   6
9 - 
 
30
Self.id:  4  -  26  -  7
Self.id:  7 Self.id:  -   266Self.id:    -    - 89  26
 -    - 26  43 - 
 37
Self.id:  5  -  26  -  31
Self.id:  9  -  26  -  21
['0']
Took  2.1125333309173584  seconds
['1']
Took  0.0003235340118408203  seconds
['2']
Took  0.0002429485321044922  seconds
['3']
Took  0.0002887248992919922  seconds
['4']
Took  0.00046253204345703125  seconds
['5']
Took  0.0004875659942626953  seconds
['6']
Took  0.0002741813659667969  seconds
['7']
Took  0.0004222393035888672  seconds
['8']
Took  0.00043845176696777344  seconds
['9']
Took  0.0006718635559082031  seconds
['0']
Took  0.02902054786682129  seconds
['1']
Took  0.00017404556274414062  seconds


KeyboardInterrupt: 

In [17]:
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from torch.utils.data import DataLoader
from itertools import cycle

class PrefetchLoader:
    def __init__(self, datasets, batch_size, prefetch_batches):
        self.datasets = cycle(datasets)
        self.batch_size = batch_size
        self.prefetch_batches = prefetch_batches
        self.executor = ThreadPoolExecutor(max_workers=prefetch_batches)
        self.queue = Queue(maxsize=prefetch_batches)
        self.queued = 0

    def __iter__(self):
        return self

    def prefetch(self):
        dataset = next(self.datasets)
        loader = DataLoader(dataset, batch_size=self.batch_size)
        self.queue.put(next(iter(loader)))
        self.queued -= 1

    def __next__(self):
        print("Queue size: ", self.queue.qsize(), "queued", self.queued)
        if not self.queue.full():
            print("Prefetching", self.prefetch_batches - self.queue.qsize(), "batches")
            self.queued +=  self.prefetch_batches - self.queue.qsize()
            futures = [self.executor.submit(self.prefetch) for _ in range(self.prefetch_batches - self.queue.qsize())]
        return self.queue.get()


In [18]:
my_loader = PrefetchLoader(datasets, batch_size=1, prefetch_batches=10)

for data in my_loader:
    print(data)
    time.sleep(1)

Queue size:  0 queued 0
Prefetching 10 batches
Self.id:  0  -  56  -  35
Self.id:  1  -  56  -  32
Self.id:  2  -  56  -  7
Self.id:  3  -  54  -  7
Self.id:  4  -  54  -  30
Self.id:  7  -  53  -  30
Self.id:  5  -  54  -  3
Self.id:  6  -  54  -  38
Self.id:  8  -  53  -  44
Self.id:  9  -  53  -  9
['0']
Queue size:  9 queued 0
Prefetching 1 batches
['2']
Self.id:  0  -  57  -  40
Queue size:  8 queued 1
Prefetching 2 batches
['1']
Self.id:  1  -  57  -  13
Self.id:  2  -  57  -  31
Queue size:  7 queued 3
Prefetching 3 batches
['4']
Self.id:  4  -  55  -  35
Self.id:  5  -  55  -  42
Self.id:  3  -  55  -  43
Queue size:  7 queued 5
Prefetching 3 batches
['3']
Self.id:  6  -  55  -  25
Self.id:  8  -  54  -  41
Self.id:  7  -  54  -  40
Queue size:  8 queued 6
Prefetching 2 batches
['7']
Self.id:  0  -  58  -  25
Self.id:  9  -  54  -  14
Queue size:  10 queued 5
['5']
Queue size:  10 queued 4
['9']
Queue size:  10 queued 3
['8']
Queue size:  10 queued 2
['6']
Queue size:  10 queue

KeyboardInterrupt: 

In [26]:
import torch
from torch.utils.data import DataLoader

class CustomDataLoader(DataLoader):
    def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0, prefetch_factor=2):
        super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
        self.prefetch_factor = prefetch_factor
        self.prefetched_samples = []
        self._original_iterator = iter(super().__iter__())

    def __iter__(self):
        self.prefetched_samples = []
        self._iterator = iter(self._original_iterator)
        self._prefetch_samples()
        return self

    def __next__(self):
        if len(self.prefetched_samples) == 0:
            self._prefetch_samples()

        if len(self.prefetched_samples) == 0:
            raise StopIteration

        return self.prefetched_samples.pop(0)

    def _prefetch_samples(self):
        while len(self.prefetched_samples) < self.prefetch_factor:
            try:
                batch = next(self._iterator)
                self.prefetched_samples.append(batch)
            except StopIteration:
                break



In [27]:
# Iterate over the random dataset using the custom data loader
batch_size = 1
num_workers = 4
prefetch_factor = 2

custom_loader = CustomDataLoader(random_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                prefetch_factor=prefetch_factor)
custom_iter = iter(custom_loader)


Self.id:  
0Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0


Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0
Self.id:  0


In [28]:
while True:
    time1 = time.time()
    try:
        batch = next(custom_iter)
    except StopIteration:
        break
    time2 = time.time()
    print("Getting batch took {} seconds".format(time2-time1))
    print(len(batch))
    print(batch[0].shape)
    print(batch[1].shape)
    # Use the batch for training or processing
    pass

Getting batch took 0.0005524158477783203 seconds
2
torch.Size([1, 3, 224, 224])
torch.Size([1, 10])
Getting batch took 0.0003113746643066406 seconds
2
torch.Size([1, 3, 224, 224])
torch.Size([1, 10])
Getting batch took 0.00818181037902832 seconds
2
torch.Size([1, 3, 224, 224])
torch.Size([1, 10])
Getting batch took 0.00014162063598632812 seconds
2
torch.Size([1, 3, 224, 224])
torch.Size([1, 10])
Getting batch took 0.8961060047149658 seconds
2
torch.Size([1, 3, 224, 224])
torch.Size([1, 10])
Getting batch took 0.0002777576446533203 seconds
2
torch.Size([1, 3, 224, 224])
torch.Size([1, 10])
Getting batch took 0.40130186080932617 seconds
2
torch.Size([1, 3, 224, 224])
torch.Size([1, 10])
Getting batch took 0.0008094310760498047 seconds
2
torch.Size([1, 3, 224, 224])
torch.Size([1, 10])
Getting batch took 1.6627304553985596 seconds
2
torch.Size([1, 3, 224, 224])
torch.Size([1, 10])
Getting batch took 0.0006418228149414062 seconds
2
torch.Size([1, 3, 224, 224])
torch.Size([1, 10])
Getting b

KeyboardInterrupt: 

In [33]:
class CustomMultiDataLoader(DataLoader):
    def __init__(self, datasets, batch_size=1, shuffle=False, num_workers=0, prefetch_factor=2):
        super().__init__(datasets[0], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
        self.prefetch_factor = prefetch_factor
        self.prefetched_samples = []
        self._datasets = datasets
        self._original_iterators = [iter(dataset) for dataset in self._datasets]
        self._current_dataset_index = 0
        self._original_iterator = self._original_iterators[self._current_dataset_index]

    def __iter__(self):
        self.prefetched_samples = []
        self._current_dataset_index = 0
        self._original_iterator = self._original_iterators[self._current_dataset_index]
        self._prefetch_samples()
        return self

    def __next__(self):
        if len(self.prefetched_samples) == 0:
            self._prefetch_samples()

        if len(self.prefetched_samples) == 0:
            raise StopIteration

        return self.prefetched_samples.pop(0)

    def _prefetch_samples(self):
        while len(self.prefetched_samples) < self.prefetch_factor:
            try:
                self._current_dataset_index = (self._current_dataset_index + 1) % len(self._datasets)
                self._original_iterator = self._original_iterators[self._current_dataset_index]
                batch = next(self._original_iterator)
                self.prefetched_samples.append(batch)
            except StopIteration:
                print("StopIteration")
                self._current_dataset_index = (self._current_dataset_index + 1) % len(self._datasets)
                self._original_iterator = self._original_iterators[self._current_dataset_index]


In [34]:
batch_size = 1
num_workers = 2
prefetch_factor = 2
custom_loader = CustomMultiDataLoader(datasets, batch_size=batch_size, prefetch_factor=prefetch_factor, shuffle=True, num_workers=num_workers)
multi_iter = iter(custom_loader)
while True:
    time1 = time.time()
    try:
        batch = next(multi_iter)
    except StopIteration:
        break
    time2 = time.time()
    print("Getting batch took {} seconds".format(time2-time1))
    print(len(batch))
    print(batch[0].shape)
    print(batch[1].shape)
    # Use the batch for training or processing
    pass

Self.id:  1
Self.id:  2
Getting batch took 4.982948303222656e-05 seconds
2
torch.Size([3, 224, 224])
torch.Size([10])
Getting batch took 2.6464462280273438e-05 seconds
2
torch.Size([3, 224, 224])
torch.Size([10])
Self.id:  3
Self.id:  4
Getting batch took 4.0117106437683105 seconds
2
torch.Size([3, 224, 224])
torch.Size([10])
Getting batch took 0.0005445480346679688 seconds
2
torch.Size([3, 224, 224])
torch.Size([10])
Self.id:  5
Self.id:  6
Getting batch took 4.008026361465454 seconds
2
torch.Size([3, 224, 224])
torch.Size([10])
Getting batch took 4.315376281738281e-05 seconds
2
torch.Size([3, 224, 224])
torch.Size([10])
Self.id:  7
Self.id:  8


KeyboardInterrupt: 

In [2]:
load_path_0 = "/YOUR/PATH/HERE/klever_depth_normal_models_nesf/0/depth-nerfacto/07_04_23/nerfstudio_models/step-000007999.ckpt"
load_path_1 = "/YOUR/PATH/HERE/klever_depth_normal_models_nesf/1/depth-nerfacto/07_04_23/nerfstudio_models/step-000007999.ckpt"

In [7]:
def load_model(load_path):
    time1 = time.time()
    loaded_state = torch.load(load_path, map_location="cpu")["pipeline"]
    time2 = time.time()

    model=NerfactoModelConfig(eval_num_rays_per_chunk=1 << 15,
                                        predict_normals=True
                                        )
    time_2_5 = time.time()
    scene_box = SceneBox(aabb = torch.zeros((2,3)))
    model = model.setup(scene_box=scene_box,
                num_train_data=271,
                metadata={})

    time3 = time.time()
    state = {key.replace("module.", ""): value for key, value in loaded_state.items()}
    state = {key.replace("_model.", ""): value for key, value in state.items()}
    time4 = time.time()



    missing_keys, unexpected_keys = model.load_state_dict(state, strict=False)
    # gt_sd = model.state_dict()
    print("Missing keys: {}".format(missing_keys))
    print("Unexpected keys: {}".format(unexpected_keys))
    time5 = time.time()

    print("Loading took {} seconds".format(time2-time1))
    print("Creating config took {} seconds".format(time_2_5-time2))
    print("Model setup took {} seconds".format(time3-time_2_5))
    print("State dict conversion took {} seconds".format(time4-time3))
    print("Loading state dict took {} seconds".format(time5-time4))
    return model

def get_model():
    time1 = time.time()
    model=NerfactoModelConfig(eval_num_rays_per_chunk=1 << 15,
                                        predict_normals=True
                                        )
    time_2 = time.time()
    scene_box = SceneBox(aabb = torch.zeros((2,3)))
    model = model.setup(scene_box=scene_box,
                num_train_data=271,
                metadata={})
    time3 = time.time()
    print("Creating config took {} seconds".format(time_2-time1))
    print("Model setup took {} seconds".format(time3-time_2))
    return model

def get_state_dicts(count):
    state_dicts = []
    for i in range(count):
        time1 = time.time()
        load_path = "/YOUR/PATH/HERE/klever_depth_normal_models_nesf/{}/depth-nerfacto/07_04_23/nerfstudio_models/step-000007999.ckpt".format(i)
        loaded_state = torch.load(load_path, map_location="cpu")["pipeline"]
        state = {key.replace("module.", ""): value for key, value in loaded_state.items()}
        state = {key.replace("_model.", ""): value for key, value in state.items()}
        state_dicts.append(state)
        time2 = time.time()
        print("Loading took {} seconds".format(time2-time1))
    return state_dicts

def load_state_dict_to_model(model, state_dict):
    time3 = time.time()
    state_dict = {key.replace("module.", ""): value for key, value in state_dict.items()}
    state_dict = {key.replace("_model.", ""): value for key, value in state_dict.items()}
    time4 = time.time()



    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    # gt_sd = model.state_dict()
    print("Missing keys: {}".format(missing_keys))
    print("Unexpected keys: {}".format(unexpected_keys))
    time5 = time.time()

    print("State dict conversion took {} seconds".format(time4-time3))
    print("Loading state dict took {} seconds".format(time5-time4))
    return model
    


In [6]:
state_dicts = get_state_dicts(10)

Loading took 0.5905206203460693 seconds
Loading took 0.1249542236328125 seconds
Loading took 0.5644783973693848 seconds
Loading took 0.5164320468902588 seconds
Loading took 0.5863499641418457 seconds
Loading took 0.4617340564727783 seconds
Loading took 0.49069833755493164 seconds
Loading took 0.48426032066345215 seconds
Loading took 0.5144634246826172 seconds
Loading took 0.4097611904144287 seconds


In [13]:
model = get_model()

# compute average model loading time
time1 = time.time()
K=50
for i in range(K):
    model = get_model()
time2 = time.time()
print("Average model loading time: {}".format((time2-time1)/K))

Creating config took 5.5789947509765625e-05 seconds
Model setup took 0.9349899291992188 seconds
Creating config took 4.673004150390625e-05 seconds
Model setup took 0.7964234352111816 seconds
Creating config took 2.5033950805664062e-05 seconds
Model setup took 0.7067139148712158 seconds
Creating config took 1.8358230590820312e-05 seconds
Model setup took 0.8017966747283936 seconds
Creating config took 2.288818359375e-05 seconds
Model setup took 0.7316091060638428 seconds
Creating config took 1.8358230590820312e-05 seconds
Model setup took 0.8432517051696777 seconds
Creating config took 2.3603439331054688e-05 seconds
Model setup took 0.7980501651763916 seconds
Creating config took 1.8358230590820312e-05 seconds
Model setup took 0.7507121562957764 seconds
Creating config took 1.7404556274414062e-05 seconds
Model setup took 0.87148118019104 seconds
Creating config took 1.8358230590820312e-05 seconds
Model setup took 0.8608250617980957 seconds
Creating config took 2.8848648071289062e-05 sec

In [15]:
time1 = time.time()
for state_dict in state_dicts:
    model = load_state_dict_to_model(model, state_dict)
    print(model.device)
    
time2 = time.time()
print("Average model loading time: {}".format((time2-time1)/len(state_dicts)))

Missing keys: []
Unexpected keys: []
State dict conversion took 1.9073486328125e-05 seconds
Loading state dict took 0.022670745849609375 seconds
cpu
Missing keys: []
Unexpected keys: []
State dict conversion took 2.7179718017578125e-05 seconds
Loading state dict took 0.02508378028869629 seconds
cpu
Missing keys: []
Unexpected keys: []
State dict conversion took 2.4318695068359375e-05 seconds
Loading state dict took 0.030529260635375977 seconds
cpu
Missing keys: []
Unexpected keys: []
State dict conversion took 2.3603439331054688e-05 seconds
Loading state dict took 0.03043675422668457 seconds
cpu
Missing keys: []
Unexpected keys: []
State dict conversion took 2.3126602172851562e-05 seconds
Loading state dict took 0.03284716606140137 seconds
cpu
Missing keys: []
Unexpected keys: []
State dict conversion took 3.5762786865234375e-05 seconds
Loading state dict took 0.03047919273376465 seconds
cpu
Missing keys: []
Unexpected keys: []
State dict conversion took 2.47955322265625e-05 seconds
Lo

In [6]:
m0 = load_model(load_path_0)
print(m0.scene_box)

Missing keys: []
Unexpected keys: []
Loading took 0.661447286605835 seconds
Creating config took 7.367134094238281e-05 seconds
Model setup took 10.829570770263672 seconds
State dict conversion took 8.726119995117188e-05 seconds
Loading state dict took 0.264911413192749 seconds
SceneBox(aabb=tensor([[-1., -1., -1.],
        [ 1.,  1.,  1.]]))


In [4]:
m1 = load_model(load_path_1)
print(m1.scene_box)

Missing keys: []
Unexpected keys: []
Loading took 0.878633975982666 seconds
Creating config took 4.673004150390625e-05 seconds
Model setup took 2.973586082458496 seconds
State dict conversion took 3.1948089599609375e-05 seconds
Loading state dict took 0.02256464958190918 seconds
SceneBox(aabb=tensor([[-1., -1., -1.],
        [ 1.,  1.,  1.]]))


In [7]:
import pickle

time1 = time.time()
pickle.dump(m0, open("m0.pkl", "wb"))
time2 = time.time()

# load the model from disk
model = pickle.load(open("m0.pkl", 'rb'))
time3 = time.time()
print("Pickling took {} seconds".format(time2-time1))
print("Unpickling took {} seconds".format(time3-time2))

AttributeError: Can't pickle local object 'UniformLinDispPiecewiseSampler.__init__.<locals>.<lambda>'