In [4]:
import numpy as np

np.random.seed(0)

In [1]:
import kagglehub

DATASET_PATH = kagglehub.dataset_download('balraj98/berkeley-segmentation-dataset-500-bsds500')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
import dataclasses
import typing

import scipy.io
import skimage.io


def load_img(filename):
    return skimage.io.imread(filename)


def load_mat(filename):
    mat = scipy.io.loadmat(filename)
    gts = mat['groundTruth']
    num_annotators = gts.shape[1]

    return [gts[0, i]['Segmentation'][0, 0].astype(np.uint8) for i in range(num_annotators)]


@dataclasses.dataclass(frozen=True)
class TestingImage:
    filename: str

    img_test: np.ndarray[typing.Any, np.uint8]
    ground_truth: np.ndarray[typing.Any, np.uint8]


def list_files_in_directory(path):
    files_list = []

    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)

        if os.path.isfile(full_path):
            files_list.append(full_path)

    return files_list


def load_bsd_test_images(all_filenames):
    for img_filename in all_filenames:
        filename = img_filename.split("/")[-1].split(".")[0]
        ground_truth_filename = DATASET_PATH + f"/ground_truth/test/{filename}.mat"

        try:
            yield TestingImage(
                filename=filename,
                img_test=load_img(img_filename),
                ground_truth=load_mat(ground_truth_filename),
            )
        except OSError:
            continue


all_image_test = list_files_in_directory(DATASET_PATH + "/images/test")
all_testing_img = [*load_bsd_test_images(all_image_test)]

len(all_testing_img)

200

In [26]:
from ktree.ntree import NTree

points = np.array([np.random.randint(0, 255, 3) for _ in range(30)]).astype(np.uint8)

x0, y0, z0 = min(points[:, 0]), min(points[:, 1]), min(points[:, 2])
x1, y1, z1 = max(points[:, 0]), max(points[:, 1]), max(points[:, 2])

tree = NTree(np.array([(x0, x1), (y0, y1), (z0, z1)]).astype(int), 0)

for p in points:
    tree.insert(p)

sort_elements = tree.sort()

for nodes in sort_elements:
    print("Cluster:", nodes)
    print("Length: ", len(nodes.data))


Cluster: Cluster(axis=[(np.int64(7), np.float64(120.0)), (np.int64(0), np.float64(123.0)), (np.int64(0), np.float64(124.0))])
Length:  3
Cluster: Cluster(axis=[(np.float64(120.0), np.int64(233)), (np.int64(0), np.float64(123.0)), (np.float64(124.0), np.int64(248))])
Length:  1
Cluster: Cluster(axis=[(np.float64(120.0), np.int64(233)), (np.float64(123.0), np.int64(246)), (np.float64(124.0), np.int64(248))])
Length:  8
Cluster: Cluster(axis=[(np.float64(120.0), np.int64(233)), (np.int64(0), np.float64(123.0)), (np.int64(0), np.float64(124.0))])
Length:  5
Cluster: Cluster(axis=[(np.int64(7), np.float64(120.0)), (np.float64(123.0), np.int64(246)), (np.float64(124.0), np.int64(248))])
Length:  5
Cluster: Cluster(axis=[(np.float64(120.0), np.int64(233)), (np.float64(123.0), np.int64(246)), (np.int64(0), np.float64(124.0))])
Length:  1
Cluster: Cluster(axis=[(np.int64(7), np.float64(120.0)), (np.float64(123.0), np.int64(246)), (np.int64(0), np.float64(124.0))])
Length:  4
Cluster: Cluster(ax