## Ensure that all Quadtree implementations are equivalent
We provide three Quadtree implementations: dict lookup, tensor lookup, and Z-curve.

This notebook shows that they produce identical results and compares their runtimes. The Z-curve implementation is only included when the image size is a power of 2.

In [1]:
# Enable IPython's autoreload extension in mode 2, which reloads imported modules
# before each cell executes, so edits to the project source take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2
# NOTE(review): `setup` is a local module imported for its side effects; presumably
# it prepares the environment (e.g. the import path) so `mixed_res` resolves — confirm.
import setup

In [2]:
import torch
from mixed_res.quadtree_impl.quadtree_dict_lookup import DictLookupQuadtreeRunner
from mixed_res.quadtree_impl.quadtree_tensor_lookup import TensorLookupQuadtreeRunner
from mixed_res.quadtree_impl.quadtree_z_curve import ZCurveQuadtreeRunner
from mixed_res.patch_scorers.random_patch_scorer import RandomPatchScorer
from mixed_res.quadtree_impl.utils import sort_by_meta, is_power_of_2

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Experiment configuration shared by the cells below.
seed = 1337           # seed for the global torch RNG used to draw `images`
batch_size = 128      # number of images in the batch
image_size = 256      # height and width of each (square) image
num_patches = 100     # passed to every Quadtree runner
min_patch_size = 16   # passed to every Quadtree runner
max_patch_size = 64   # passed to every Quadtree runner

# Seed the global RNG so the random input batch is the same on every
# Restart & Run All; without this the inputs change from run to run.
torch.manual_seed(seed)

# Random input batch of shape (batch_size, 3, image_size, image_size).
images = torch.randn(batch_size, 3, image_size, image_size, device=device)

### Assert equivalence

In [3]:
# Run every Quadtree implementation on the same image batch with the same patch
# scorer and check that, once sorted by metadata, the results match.
patch_scorer = RandomPatchScorer(seed=1337)

# The Z-curve implementation is only exercised when image_size is a power of 2.
# Evaluate that guard once instead of repeating it before every step.
use_z_curve = is_power_of_2(image_size)

# Init Quadtree runners from different implementations
runner_dict_lookup = DictLookupQuadtreeRunner(num_patches, min_patch_size, max_patch_size)
runner_tensor_lookup = TensorLookupQuadtreeRunner(num_patches, min_patch_size, max_patch_size)
# Always bind the Z-curve names (to None when skipped) so that re-running this cell
# after changing image_size cannot leave a stale runner/result from an earlier
# execution lingering in the kernel.
runner_z_curve = (
    ZCurveQuadtreeRunner(num_patches, min_patch_size, max_patch_size) if use_z_curve else None
)
res_z_curve = None

# Run Quadtrees
res_dict_lookup = runner_dict_lookup.run_batch_quadtree(images, patch_scorer)
res_tensor_lookup = runner_tensor_lookup.run_batch_quadtree(images, patch_scorer)
if use_z_curve:
    res_z_curve = runner_z_curve.run_batch_quadtree(images, patch_scorer)

# Sort results by metadata (patch location and scale) to make them comparable
res_dict_lookup = sort_by_meta(res_dict_lookup)
res_tensor_lookup = sort_by_meta(res_tensor_lookup)
if use_z_curve:
    res_z_curve = sort_by_meta(res_z_curve)

# Assert that results are equivalent (dict_lookup is the common reference)
assert torch.allclose(res_dict_lookup, res_tensor_lookup), \
    "dict_lookup and tensor_lookup results differ"
print("dict_lookup and tensor_lookup are equivalent")
if use_z_curve:
    assert torch.allclose(res_dict_lookup, res_z_curve), \
        "dict_lookup and z_curve results differ"
    print("dict_lookup and z_curve are equivalent")

dict_lookup and tensor_lookup are equivalent
dict_lookup and z_curve are equivalent


### Compare runtimes

In [4]:
%timeit runner_z_curve.run_batch_quadtree(images, patch_scorer)
%timeit runner_tensor_lookup.run_batch_quadtree(images, patch_scorer)
%timeit runner_dict_lookup.run_batch_quadtree(images, patch_scorer)

7.35 ms ± 16.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
21.2 ms ± 543 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
42.1 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
