/
utils.py
203 lines (158 loc) · 7 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
from typing_extensions import Literal
from .serialization import Serializable
from .types import BoxOrKeypointType, ScalarType, ScaleType, SizeType
if TYPE_CHECKING:
import torch
# Number of elements a sequence must have for `to_tuple` to accept it as a range.
PAIR = 2
def get_shape(img: Union["np.ndarray", "torch.Tensor"]) -> SizeType:
    """Return the two size dimensions of an image.

    Args:
        img: Image stored as a NumPy array or as a torch tensor.

    Returns:
        The first two entries of ``img.shape`` for a NumPy array, or the last
        two entries of ``img.shape`` for a torch tensor.

    Raises:
        RuntimeError: If ``img`` is not a NumPy array and is not recognised as
            a torch tensor (which includes the case where torch cannot be
            imported).
    """
    is_numpy_image = isinstance(img, np.ndarray)
    if not is_numpy_image:
        try:
            # torch is imported lazily so the module works without it installed.
            import torch

            if torch.is_tensor(img):
                return img.shape[-2:]
        except ImportError:
            # No torch available: fall through to the unsupported-type error.
            pass

        raise RuntimeError(
            f"Albumentations supports only numpy.ndarray and torch.Tensor data type for image. Got: {type(img)}",
        )

    return img.shape[:2]
def format_args(args_dict: Dict[str, Any]) -> str:
    """Render a mapping as a comma-separated ``key=value`` string.

    String values are wrapped in single quotes; every other value is rendered
    with ``str``. Pairs appear in the mapping's iteration order and are joined
    with ``", "``.

    Args:
        args_dict: Mapping of argument names to values.

    Returns:
        The formatted string; empty when ``args_dict`` is empty.
    """

    def render(value: Any) -> str:
        if isinstance(value, str):
            return f"'{value}'"
        return str(value)

    return ", ".join(f"{name}={render(value)}" for name, value in args_dict.items())
class Params(Serializable, ABC):
    """Base holder for the settings a data processor reads.

    Args:
        format: Name of the format the target data is supplied in.
        label_fields: Names of the data entries that hold per-item labels,
            or ``None`` when there are none.
    """

    def __init__(self, format: str, label_fields: Optional[Sequence[str]] = None):
        self.label_fields = label_fields
        self.format = format

    def to_dict_private(self) -> Dict[str, Any]:
        """Return the stored settings as a dictionary."""
        settings: Dict[str, Any] = {}
        settings["format"] = self.format
        settings["label_fields"] = self.label_fields
        return settings
class DataProcessor(ABC):
    """Base class for converting and validating targets that accompany an image.

    A processor keeps a list of data field names (``data_fields``) that it is
    responsible for. ``preprocess`` attaches the configured label fields to
    those entries and converts them to the ``albumentations`` format;
    ``postprocess`` filters them, converts them back and detaches the labels.

    Args:
        params: Settings object providing ``format`` and ``label_fields``.
        additional_targets: Optional mapping of extra data field names to the
            name of the target they should be treated as.
    """

    def __init__(self, params: Params, additional_targets: Optional[Dict[str, str]] = None):
        self.params = params
        self.data_fields = [self.default_data_name]
        if additional_targets is not None:
            self.add_targets(additional_targets)

    @property
    @abstractmethod
    def default_data_name(self) -> str:
        """Name of the data entry this processor handles by default."""
        raise NotImplementedError

    def add_targets(self, additional_targets: Dict[str, str]) -> None:
        """Add targets to transform them the same way as one of existing targets."""
        for target_name, target_kind in additional_targets.items():
            # Only register names mapped to this processor's own target, once each.
            if target_kind == self.default_data_name and target_name not in self.data_fields:
                self.data_fields.append(target_name)

    def ensure_data_valid(self, data: Dict[str, Any]) -> None:
        """Hook for validating the input data; does nothing by default."""

    def ensure_transforms_valid(self, transforms: Sequence[object]) -> None:
        """Hook for validating the transforms; does nothing by default."""

    def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Filter the handled fields, convert them back and detach label fields.

        Args:
            data: Mapping that contains ``"image"`` and any handled fields.

        Returns:
            The same mapping, updated in place.
        """
        rows, cols = get_shape(data["image"])

        for data_name in self.data_fields:
            if data_name in data:
                data[data_name] = self.filter(data[data_name], rows, cols)
                data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="from")

        return self.remove_label_fields_from_data(data)

    def preprocess(self, data: Dict[str, Any]) -> None:
        """Attach label fields and convert the handled fields, in place.

        Args:
            data: Mapping that contains ``"image"`` and any handled fields.
        """
        data = self.add_label_fields_to_data(data)

        # Read the image size with the same helper `postprocess` uses so both
        # directions agree on (rows, cols) for every supported image type.
        rows, cols = get_shape(data["image"])
        for data_name in self.data_fields:
            if data_name in data:
                data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="to")

    def check_and_convert(
        self,
        data: List[BoxOrKeypointType],
        rows: int,
        cols: int,
        direction: Literal["to", "from"] = "to",
    ) -> List[BoxOrKeypointType]:
        """Validate or convert ``data`` depending on the configured format.

        When ``params.format`` is ``"albumentations"`` the data is only checked
        and returned as is; otherwise it is converted in the given direction.

        Raises:
            ValueError: If ``direction`` is neither ``"to"`` nor ``"from"``.
        """
        if self.params.format == "albumentations":
            self.check(data, rows, cols)
            return data

        if direction == "to":
            return self.convert_to_albumentations(data, rows, cols)
        if direction == "from":
            return self.convert_from_albumentations(data, rows, cols)

        raise ValueError(f"Invalid direction. Must be `to` or `from`. Got `{direction}`")

    @abstractmethod
    def filter(self, data: Sequence[BoxOrKeypointType], rows: int, cols: int) -> Sequence[BoxOrKeypointType]:
        """Filter ``data`` for an image of ``rows`` x ``cols``; used by ``postprocess``."""

    @abstractmethod
    def check(self, data: List[BoxOrKeypointType], rows: int, cols: int) -> None:
        """Validate ``data`` that is already in the ``albumentations`` format."""

    @abstractmethod
    def convert_to_albumentations(self, data: List[BoxOrKeypointType], rows: int, cols: int) -> List[BoxOrKeypointType]:
        """Convert ``data`` from ``params.format`` to the ``albumentations`` format."""

    @abstractmethod
    def convert_from_albumentations(
        self,
        data: List[BoxOrKeypointType],
        rows: int,
        cols: int,
    ) -> List[BoxOrKeypointType]:
        """Convert ``data`` from the ``albumentations`` format to ``params.format``."""

    def add_label_fields_to_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Append each label field's values to the items of every handled field.

        For each handled field present in ``data`` and each name in
        ``params.label_fields``, the i-th label is appended to the i-th item,
        so items grow by one element per label field.

        Returns:
            The same mapping, updated in place.

        Raises:
            ValueError: If a handled field and a label field differ in length.
        """
        label_fields = self.params.label_fields
        if label_fields is None:
            return data

        for data_name in self.data_fields:
            if data_name not in data:
                continue
            for field in label_fields:
                items, labels = data[data_name], data[field]
                if len(items) != len(labels):
                    raise ValueError(
                        f"The lengths of bboxes and labels do not match. Got {len(items)} "
                        f"and {len(labels)} respectively.",
                    )
                data[data_name] = [[*item, label] for item, label in zip(items, labels)]
        return data

    def remove_label_fields_from_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Split the trailing label values off the items of every handled field.

        The last ``len(params.label_fields)`` elements of each item are written
        back to the corresponding label entries of ``data`` and removed from
        the item.

        Returns:
            The same mapping, updated in place.
        """
        label_fields = self.params.label_fields
        if not label_fields:
            return data

        label_count = len(label_fields)
        for data_name in self.data_fields:
            if data_name not in data:
                continue
            items = data[data_name]
            for idx, field in enumerate(label_fields):
                # Labels sit at the tail of each item, in label_fields order.
                data[field] = [item[idx - label_count] for item in items]
            data[data_name] = [item[:-label_count] for item in items]
        return data
def to_tuple(
    param: ScaleType,
    low: Optional[ScaleType] = None,
    bias: Optional[ScalarType] = None,
) -> Union[Tuple[int, int], Tuple[float, float]]:
    """Convert input argument to a min-max tuple.

    Args:
        param: Input value which could be a scalar or a sequence of exactly 2 scalars.
        low: Second element of the tuple, provided as an optional argument for when `param` is a scalar.
            It is not consulted when `param` is a sequence.
        bias: An offset added to both elements of the tuple.

    Returns:
        A tuple of two scalars ordered as (min, max), optionally adjusted by `bias`.
        A scalar `param` without `low` yields the range symmetric around 0,
        `(-abs(param), abs(param))`.

    Raises:
        ValueError: If `low` and `bias` are both given, or if `param` is neither
            a scalar nor a sequence of 2 elements.
    """
    # Validate mutually exclusive arguments
    if low is not None and bias is not None:
        msg = "Arguments 'low' and 'bias' cannot be used together."
        raise ValueError(msg)

    if isinstance(param, Sequence) and len(param) == PAIR:
        min_val, max_val = min(param), max(param)
    # Handle scalar input
    elif isinstance(param, (int, float)):
        if isinstance(low, (int, float)):
            # Use low and param to create a tuple
            min_val, max_val = (low, param) if low < param else (param, low)
        else:
            # Create a symmetric tuple around 0. abs() keeps the result ordered
            # as (min, max) when a negative scalar is passed.
            magnitude = abs(param)
            min_val, max_val = -magnitude, magnitude
    else:
        msg = "Argument 'param' must be either a scalar or a sequence of 2 elements."
        raise ValueError(msg)

    # Apply bias if provided
    if bias is not None:
        return (bias + min_val, bias + max_val)

    return min_val, max_val