In [1]:
import numpy as np

np.random.seed(0)

In [2]:
import kagglehub

DATASET_PATH = kagglehub.dataset_download('balraj98/berkeley-segmentation-dataset-500-bsds500')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import os
import dataclasses
import typing

import scipy.io
import skimage.io


def load_image(filename: str) -> np.ndarray:
    return skimage.io.imread(filename)


def load_ground_truth(filename) -> np.ndarray:
    mat = scipy.io.loadmat(filename)
    gts = mat['groundTruth']
    num_annotators = gts.shape[1]

    return np.array([gts[0, i]['Segmentation'][0, 0].astype(np.uint8) for i in range(num_annotators)])


@dataclasses.dataclass(frozen=True)
class DataSet:
    id: int

    image: np.ndarray
    ground_truth: np.ndarray

    def __len__(self) -> int:
        return self.ground_truth.shape[0]

    def __getitem__(self, index: int) -> np.ndarray:
        return self.ground_truth[index]

    def __iter__(self) -> typing.Iterator[np.ndarray]:
        return iter(self.ground_truth)


@dataclasses.dataclass(frozen=True)
class GroundTruthDataSet:
    id: int
    id_ground_truth: int

    image: np.ndarray
    source: np.ndarray

    def __len__(self) -> int:
        return len(np.unique(self.source))


def load_files_from_directory(path: str) -> list[str]:
    files_list = []

    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)

        if os.path.isfile(full_path):
            files_list.append(full_path)

    return files_list


def load_dataset_bsds500(all_filenames) -> list[DataSet]:
    all_data_set = []

    for img_filename in all_filenames:
        filename = img_filename.split("/")[-1].split(".")[0]
        ground_truth_filename = DATASET_PATH + f"/ground_truth/test/{filename}.mat"

        if not filename.isdigit():
            continue

        _id = int(filename)

        try:

            data_set = DataSet(
                id=_id,
                image=load_image(img_filename),
                ground_truth=load_ground_truth(ground_truth_filename),
            )

            all_data_set.append(data_set)
        except OSError:
            continue

    return all_data_set


def load_ground_truth_dataset_bsds500(all_filenames) -> list[GroundTruthDataSet]:
    all_gt_data_set = []

    for img_filename in all_filenames:
        filename = img_filename.split("/")[-1].split(".")[0]
        ground_truth_filename = DATASET_PATH + f"/ground_truth/test/{filename}.mat"

        if not filename.isdigit():
            continue

        _id = int(filename)

        try:
            image = load_image(img_filename)

            for n, ground_truth in enumerate(load_ground_truth(ground_truth_filename)):
                gd_dataset = GroundTruthDataSet(
                    id=_id,
                    id_ground_truth=n,
                    image=image,
                    source=ground_truth,
                )

                all_gt_data_set.append(gd_dataset)
        except OSError:
            continue

    return all_gt_data_set


all_image_test = load_files_from_directory(DATASET_PATH + "/images/test")
all_gd_dataset_bsd500 = load_ground_truth_dataset_bsds500(all_image_test)

len(all_gd_dataset_bsd500)

1063

In [4]:
all_gd_dataset_bsd500[0].image.reshape(-1, 3)

array([[55, 73, 73],
       [56, 74, 74],
       [57, 75, 75],
       ...,
       [53, 53, 45],
       [47, 49, 38],
       [41, 44, 35]], shape=(154401, 3), dtype=uint8)

In [5]:
from ktree.ntree import NTreeDynamic

data = all_gd_dataset_bsd500[0].image.reshape(-1, 3)

x0, y0, z0 = min(data[:, 0]), min(data[:, 1]), min(data[:, 2])
x1, y1, z1 = max(data[:, 0]), max(data[:, 1]), max(data[:, 2])
shape = np.array([(x0, x1), (y0, y1), (z0, z1)]).astype(float).tolist()

tree = NTreeDynamic(0)

for d in data:
    tree.insert(d)

sort_elements = tree.sort()

for nodes in sort_elements:
    print("Cluster:", nodes)
    print("Length: ", len(nodes.data))


Cluster: Cluster(axis=[(16.0, 135.5), (28.0, 141.0), (24.0, 139.5)])
Length:  29012
Cluster: Cluster(axis=[(16.0, 135.5), (28.0, 141.0), (139.5, 255.0)])
Length:  4705
Cluster: Cluster(axis=[(16.0, 135.5), (141.0, 254.0), (139.5, 255.0)])
Length:  2176
Cluster: Cluster(axis=[(135.5, 255.0), (141.0, 254.0), (139.5, 255.0)])
Length:  118174
Cluster: Cluster(axis=[(135.5, 255.0), (28.0, 141.0), (139.5, 255.0)])
Length:  258
Cluster: Cluster(axis=[(135.5, 255.0), (28.0, 141.0), (24.0, 139.5)])
Length:  69
Cluster: Cluster(axis=[(135.5, 255.0), (141.0, 254.0), (24.0, 139.5)])
Length:  7
