### How to use the PixelCounter and the WeightAlgorithm classes

In [6]:
# Built-in
import os
import logging
logging.basicConfig(level=logging.DEBUG)

# Third-party packages
import numpy as np
from tqdm import tqdm

# Segwork framework
from segwork.data import DroneDataset, NumpyPixelCounter

In [7]:
DATA_DIR = os.path.join('data')
ASSETS_DIR = 'assets'

In [8]:
# Instantiate dataset
dataset = DroneDataset( 
    root = os.path.join(DATA_DIR, 'semantic_drone_dataset'),
    pil_target=False,               # Target: numpy.ndarray
)

In [9]:
# The `num_classes = dataset.num_classes - 1` because background is not weighted.
pixel_counter = \
    NumpyPixelCounter( num_classes=dataset.num_classes - 1, dtype=np.longlong)
pixel_counter

Counter of pixels with 23 classes.
Pixel count:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0].
Class count:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

In [5]:
print(f'Number of pixel per image: {dataset.HEIGHT * dataset.WIDTH}')
label = dataset.load_weight_label(2)
pixel_counter.update(label)

Number of pixel per image: 24000000


(array([   33858, 15296778,   365294,  2688081,  2233916,        0,
               0,        0,  1046133,        0,   743212,   218419,
               0,        0,        0,   150142,        0,        0,
               0,   572787,        0,        0,   651380], dtype=int64),
 array([24000000, 24000000, 24000000, 24000000, 24000000,        0,
               0,        0, 24000000,        0, 24000000, 24000000,
               0,        0,        0, 24000000,        0,        0,
               0, 24000000,        0,        0, 24000000], dtype=int64))

In [23]:
pixel_counter.save_counters(os.path.join(ASSETS_DIR, 'pixel_count.npz'))

Weight file pixel_count.npz already exists, replacing file. Pass exist_ok=False attr to prevent it.


In [21]:
pixel_counter.reset_counters()

Counter of pixels with 23 classes.
Pixel count:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0].
Class count:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

In [27]:
pixel_counter.load_counters(os.path.join(ASSETS_DIR, 'pixel_count_all.npz'))

True

In [10]:
for idx in tqdm(range(dataset.num_data_points)):
    label = dataset.load_weight_label(idx)
    pixel_counter.update(label)

  0%|          | 1/400 [00:00<02:44,  2.42it/s]


FileNotFoundError: [Errno 2] No such file or directory: 'data\\semantic_drone_dataset\\training_set\\gt\\semantic\\label_numpy\\001.npy'

In [7]:
from segwork.data import balance as bl
import importlib
importlib.reload(bl)

weight_algorithm = bl.NumpyMedianFrequencyWeight( pixel_counter = pixel_counter)

In [47]:
weights = weight_algorithm.compute()
weights

array([16.85484274,  0.09279877,  0.95600028,  0.12590574,  0.41609928,
        0.462153  ,  2.69340118,  0.57643398,  0.46594422,  0.24907447,
        1.        ,  2.487816  , 10.87277168,  2.03968625, 28.96594261,
        3.21416904, 13.60558168,  0.74984716,  6.58144315,  0.83559121,
        0.57470975,  6.68112193,  1.01167484])

In [14]:
pixel_counter = NumpyPixelCounter( num_classes=dataset.num_classes - 1, dtype=np.longlong)
weight_algorithm = bl.NumpyMedianFrequencyWeight( pixel_counter = pixel_counter)
weights = dataset.compute_class_weights( weight_algorithm=weight_algorithm)

100%|██████████| 400/400 [02:12<00:00,  3.01it/s]


In [11]:
pixel_counter = NumpyPixelCounter( num_classes=dataset.num_classes - 1, dtype=np.longlong)
weight_algorithm = bl.NumpyMedianFrequencyWeight( pixel_counter = pixel_counter)
weights = dataset.compute_class_weights( weight_algorithm=weight_algorithm, path='pixel_count_all.npz')
weights

INFO:segwork.data.balance:Pixel counts loaded from pixel_count_all.npz


array([16.85484274,  0.09279877,  0.95600028,  0.12590574,  0.41609928,
        0.462153  ,  2.69340118,  0.57643398,  0.46594422,  0.24907447,
        1.        ,  2.487816  , 10.87277168,  2.03968625, 28.96594261,
        3.21416904, 13.60558168,  0.74984716,  6.58144315,  0.83559121,
        0.57470975,  6.68112193,  1.01167484])

### Pytorch calculator

In [1]:
# Built-in
import os
import logging
logging.basicConfig(level=logging.DEBUG)

# Third-party packages
import torch
import torchvision
import numpy as np
from tqdm import tqdm

# Segwork framework
from segwork.data import DroneDataset, NumpyPixelCounter,  PytorchPixelCounter

In [2]:
DATA_DIR = os.path.join(os.pardir, 'data')
ASSETS_DIR = 'assets'
# Instantiate dataset
dataset = DroneDataset( 
    root = os.path.join(DATA_DIR, 'semantic_drone_dataset'),
    pil_target=False,               # Target: numpy.ndarray
)
# The `num_classes = dataset.num_classes - 1` because background is not weighted.
pixel_counter = PytorchPixelCounter( num_classes=dataset.num_classes - 1)
pixel_counter

Counter of pixels with 23 classes.
Pixel count:
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0').
Class count:
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')

In [3]:
print(f'Number of pixel per image: {dataset.HEIGHT * dataset.WIDTH}')
label = dataset.load_weight_label(2)
pytorch_label = torch.from_numpy(label).to('cuda')
pixel_counter.update(pytorch_label)

Number of pixel per image: 24000000


(tensor([   33858, 15296778,   365294,  2688081,  2233916,        0,        0,
                0,  1046133,        0,   743212,   218419,        0,        0,
                0,   150142,        0,        0,        0,   572787,        0,
                0,   651380], device='cuda:0'),
 tensor([ True,  True,  True,  True,  True, False, False, False,  True, False,
          True,  True, False, False, False,  True, False, False, False,  True,
         False, False,  True], device='cuda:0'))

In [4]:
pixel_counter.save_counters(os.path.join(ASSETS_DIR, 'pixel_counts.pt'))

In [6]:
pixel_counter.reset_counters()
pixel_counter.load_counters(os.path.join(ASSETS_DIR, 'pixel_counts.pt'))

INFO:segwork.data.balance:Pixel counts loaded from assets\pixel_counts.pt


True

In [5]:
from segwork.data import NumpyMedianFrequencyWeight

In [None]:
weight_algorithm = NumpyMedianFrequencyWeight(pixel_counter=pixel_counter)