### PixelCounter and WeightAlgorithm classes

In [1]:
# Built-in
import os
import logging
logging.basicConfig(level=logging.DEBUG)

# Third-party packages
import numpy as np
from tqdm import tqdm

# Segwork framework
from segwork.data import DroneDataset, NumpyPixelCounter

In [2]:
DATA_DIR = os.path.join(os.pardir, 'data')
ASSETS_DIR = 'assets'

In [3]:
# Instantiate dataset
dataset = DroneDataset( 
    root = os.path.join(DATA_DIR, 'semantic_drone_dataset'),
    pil_target=False,               # Target: numpy.ndarray
)

Instantiating the numpy pixel counter

In [7]:
# The `num_classes = dataset.num_classes - 1` because background is not weighted.
pixel_counter = \
    NumpyPixelCounter( num_classes=dataset.num_classes - 1, dtype=np.longlong)
pixel_counter

Counter of pixels with 23 classes.
Pixel count:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0].
Class count:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

In [8]:
pixel_counter.pixel_count == 0

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True])

In [5]:
print(f'Number of pixel per image: {dataset.HEIGHT * dataset.WIDTH}')
label = dataset.load_weight_label(2)
pixel_counter.update(label)

Number of pixel per image: 24000000


(array([   33858, 15296778,   365294,  2688081,  2233916,        0,
               0,        0,  1046133,        0,   743212,   218419,
               0,        0,        0,   150142,        0,        0,
               0,   572787,        0,        0,   651380], dtype=int64),
 array([24000000, 24000000, 24000000, 24000000, 24000000,        0,
               0,        0, 24000000,        0, 24000000, 24000000,
               0,        0,        0, 24000000,        0,        0,
               0, 24000000,        0,        0, 24000000], dtype=int64))

In [23]:
pixel_counter.save_counters(os.path.join(ASSETS_DIR, 'pixel_count.npz'))

Weight file pixel_count.npz already exists, replacing file. Pass exist_ok=False attr to prevent it.


In [21]:
pixel_counter.reset_counters()

Counter of pixels with 23 classes.
Pixel count:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0].
Class count:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

In [11]:
pixel_counter.load_counters(os.path.join(ASSETS_DIR, 'pixel_count_all.npz'))

INFO:segwork.data.balance:Pixel counts loaded from assets\pixel_count_all.npz


True

In [12]:
for idx in tqdm(range(dataset.num_data_points)):
    label = dataset.load_weight_label(idx)
    pixel_counter.update(label)

100%|██████████| 400/400 [02:06<00:00,  3.17it/s]


### Using a weight algorithm

In [14]:
from segwork.data import NumpyMedianFrequencyWeight
weight_algorithm = NumpyMedianFrequencyWeight( pixel_counter = pixel_counter)

In [15]:
weights = weight_algorithm.compute()
weights

array([16.85484274,  0.09279877,  0.95600028,  0.12590574,  0.41609928,
        0.462153  ,  2.69340118,  0.57643398,  0.46594422,  0.24907447,
        1.        ,  2.487816  , 10.87277168,  2.03968625, 28.96594261,
        3.21416904, 13.60558168,  0.74984716,  6.58144315,  0.83559121,
        0.57470975,  6.68112193,  1.01167484])