/
classes_heatmap_per_class.py
63 lines (50 loc) · 3.14 KB
/
classes_heatmap_per_class.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from typing import Tuple, Optional
import numpy as np
from data_gradients.common.registry.registry import register_feature_extractor
from data_gradients.utils.data_classes import DetectionSample
from data_gradients.feature_extractors.common.heatmap import BaseClassHeatmap
from data_gradients.utils.detection import scale_bboxes
@register_feature_extractor()
class DetectionClassHeatmap(BaseClassHeatmap):
    """
    Provides a visual representation of object distribution across images in the dataset using heatmaps.
    It helps identify common areas where objects are frequently detected, allowing insights into potential
    biases in object placement or dataset collection.
    """

    def __init__(self, n_rows: int = 12, n_cols: int = 2, heatmap_shape: Tuple[int, int] = (200, 200)):
        """
        :param n_rows:          How many rows per split.
        :param n_cols:          How many columns per split.
        :param heatmap_shape:   Heatmap, in (H, W) format. Increase for more resolution, at the expense of processing speed.
        """
        super().__init__(n_rows=n_rows, n_cols=n_cols, heatmap_shape=heatmap_shape)

    def update(self, sample: DetectionSample) -> None:
        """Accumulate the bounding boxes of one sample into the heatmap of its split.

        The boxes are rescaled with `scale_bboxes` from the image shape to `self.heatmap_shape`, then every
        heatmap cell covered by a box is incremented by one in the channel of the box's class.

        :param sample: Sample providing `image`, `bboxes_xyxy`, `class_ids`, `class_names` and `split`.
        """
        if not self.class_names:
            self.class_names = sample.class_names

        original_shape = sample.image.shape[:2]
        bboxes_xyxy = scale_bboxes(old_shape=original_shape, new_shape=self.heatmap_shape, bboxes_xyxy=sample.bboxes_xyxy)

        # Allocate the per-split accumulator only the first time a split is seen. Passing `np.zeros(...)` as the
        # default of `dict.get` would build (and discard) a full-size array for every single sample.
        split_heatmap = self.heatmaps_per_split.get(sample.split)
        if split_heatmap is None:
            max_class_id = max(sample.class_names.keys())
            # One (H, W) channel per class id, from 0 up to the largest id.
            split_heatmap = np.zeros((max_class_id + 1, *self.heatmap_shape))

        for class_id, (x1, y1, x2, y2) in zip(sample.class_ids, bboxes_xyxy):
            # Clamp to 0: a box sticking out of the top/left border has negative coordinates, which numpy slicing
            # would interpret as offsets from the end of the axis (dropping the box or marking the opposite side).
            # Coordinates beyond the bottom/right border need no clamping, slices are truncated there already.
            x1, y1, x2, y2 = (max(int(coord), 0) for coord in (x1, y1, x2, y2))
            split_heatmap[class_id, y1:y2, x1:x2] += 1

        self.heatmaps_per_split[sample.split] = split_heatmap

    def _generate_title(self) -> str:
        """Return the title of this feature."""
        return "Bounding Box Density"

    def _generate_description(self) -> str:
        """Return the description of this feature (contains `<br/>` markup)."""
        return (
            "The heatmap represents areas of high object density within the images, providing insights into the spatial distribution of objects. "
            "By examining the heatmap, you can quickly detect whether objects are predominantly concentrated in specific regions or if they are evenly "
            "distributed throughout the scene. This information can serve as a heuristic to assess if the objects are positioned appropriately "
            "within the expected areas of interest.<br/>"
            "Note that images are resized to a square of the same dimension, which can affect the aspect ratio of objects. "
            "This is done to focus on localization of objects in the scene (e.g. top-right, center, ...) independently of the original image sizes."
        )

    def _generate_notice(self) -> Optional[str]:
        """Return a notice when there are more classes than grid cells (`n_cols * n_rows`), `None` otherwise."""
        n_cells = self.n_cols * self.n_rows
        if len(self.class_names) <= n_cells:
            return None
        return (
            f"Only the {n_cells} classes with highest density are shown.<br/>"
            "You can increase the number of classes by changing `n_cols` and `n_rows` in the configuration file."
        )