-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
304 lines (262 loc) · 11.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import dataclasses
import enum
import importlib
import inspect
import os
import random
import sys
import time
import typing
from pathlib import Path
from queue import Queue, Empty
import logging
from typing import TypeVar
from uuid import uuid4
import pydsdl
import pycyphal
import yukon
logger = logging.getLogger(__name__)
def quit_application(state: "yukon.domain.god_state.GodState") -> None:
    """
    Flag the GUI as no longer running and put a ``None`` item on the application queues.

    ``None`` is presumably the sentinel that tells the queue consumers to stop -- confirm against the
    consumer loops, which are not visible from this function.

    :param state: The application state whose ``gui``, ``dronecan_traffic_queues`` and ``queues`` are touched
    :return: None
    """
    state.gui.gui_running = False
    dronecan_queues = state.dronecan_traffic_queues
    dronecan_queues.input_queue.put_nowait(None)
    dronecan_queues.output_queue.put_nowait(None)
    god_queue = state.queues.god_queue
    god_queue.put_nowait(None)
def add_path_to_cyphal_path(path: str) -> None:
    """
    Make sure the given directory is listed in the CYPHAL_PATH environment variable.

    CYPHAL_PATH is treated as a list of directories joined by ``os.pathsep``. Entries are compared in their
    resolved (absolute) form, so a directory that is already listed is not appended a second time. When the
    variable is unset or empty it is set to the resolved path.

    :param path: The directory to add
    :return: None
    """
    normalized_path = str(Path(path).resolve())
    cyphal_path = os.environ.get("CYPHAL_PATH")
    if not cyphal_path:
        # Nothing to merge with, the new path becomes the whole value
        os.environ["CYPHAL_PATH"] = normalized_path
        return
    # The value has to be split into its entries first: iterating over the raw string would walk it
    # character by character and the membership test below would never match.
    # Empty entries (e.g. from a trailing separator) are skipped because Path("") resolves to the
    # current working directory.
    known_paths = {str(Path(entry).resolve()) for entry in cyphal_path.split(os.pathsep) if entry}
    if normalized_path not in known_paths:
        # The existing value is kept verbatim, the new entry goes at the end
        os.environ["CYPHAL_PATH"] = f"{cyphal_path}{os.pathsep}{normalized_path}"
def add_path_to_sys_path(path: str) -> None:
    """
    Append a directory to sys.path unless it is already there.

    Both the existing sys.path entries and the new path are compared in their resolved (absolute) form.

    :param path: The directory to make importable
    :return: None
    """
    known_entries = {str(Path(entry).resolve()) for entry in sys.path}
    candidate = Path(path).resolve()
    if str(candidate) in known_entries:
        return
    sys.path.append(str(candidate))
    logger.debug("Added %r to sys.path", candidate)
@dataclasses.dataclass
class Datatype:
    """
    Describes one data type class that was found while scanning a package.

    Equality compares every field (generated by the dataclass), the hash only looks at the name.
    Defining ``__hash__`` explicitly keeps the instances usable in hash-based containers.
    """

    # Whether the class has a fixed port id
    is_fixed_id: bool
    # The name of the data type
    name: str
    # The class object itself
    class_reference: typing.Any
    # Whether the class belongs to a service
    is_service: bool

    def __hash__(self) -> int:
        """Hash by name only."""
        return hash(self.name)
def get_all_datatypes(path: Path) -> typing.List[Datatype]:
    """
    Collect the data types of every package found below the given folder.

    The path is to a folder like .compiled which contains dsdl packages. Every ``__init__.py`` below it is
    inspected: a package whose ``__init__.py`` holds nothing but comment lines and blank lines is skipped, every
    other package is handed to scan_package_look_for_classes.

    :param path: The folder that contains the packages
    :return: The data types that were found, without duplicates, in the order they were first seen
    """
    all_classes: typing.List[Datatype] = []
    # Sorted so that the scan order does not depend on the order the file system lists the files in
    for init_file in sorted(path.glob("**/__init__.py")):
        # Only the line structure matters here, so undecodable bytes are replaced rather than raised on
        with open(init_file, "r", encoding="utf-8", errors="replace") as file:
            # A line counts as content when it is neither blank nor starts with a #
            has_content = any(line.strip() != "" and not line.startswith("#") for line in file)
        if has_content:
            all_classes.extend(scan_package_look_for_classes(path, init_file.relative_to(path).parent))
    # Removes duplicates like set() would, but keeps the first-seen order so the result is reproducible
    return list(dict.fromkeys(all_classes))
def scan_package_look_for_classes(
    package_root: typing.Union[str, Path],
    package_path: typing.Union[str, Path],
    timeout: float = 5.0,
) -> typing.List[Datatype]:
    """
    Import a package and collect the data type classes that can be reached from it.

    The package root is added to sys.path (and stays there), the package path is turned into a dotted module
    name and imported. Starting at that package, every module and class that shows up as a member is visited
    breadth-first, each one only once. A class is reported when it has a ``_serialize_`` or ``_deserialize_``
    attribute and pycyphal.dsdl.get_model succeeds for it.

    :param package_root: The folder that contains the package
    :param package_path: The path of the package, relative to the package root
    :param timeout: The number of seconds after which the scan stops and returns what it has found so far
    :return: One Datatype per class that was found, an empty list if the package could not be imported
    """
    package_root = Path(package_root)
    package_path = Path(package_path)
    classes: typing.List[Datatype] = []
    add_path_to_sys_path(str(package_root))
    proper_module_path = str(package_path).replace(os.path.sep, ".")
    try:
        package = importlib.import_module(proper_module_path)
    except Exception:
        # Not every folder with an __init__.py is importable, that is not an error for the scan.
        # KeyboardInterrupt and SystemExit are deliberately not caught here.
        logger.debug("Could not import %r, skipping it", proper_module_path, exc_info=True)
        return []
    # Maps id(object) to the object. The objects are kept referenced so that an id cannot be reused
    # by something else while the scan is running.
    visited: typing.Dict[int, typing.Any] = {id(package): package}
    queue: Queue = Queue()
    queue.put(package)
    loop_start = time.monotonic()
    try:
        while True:
            if time.monotonic() - loop_start > timeout:
                logger.error("Scanning %s took more than %s seconds, stopping early", proper_module_path, timeout)
                break
            # Raises Empty once everything has been visited, which ends the loop
            module_or_class = queue.get_nowait()
            # Get all modules and classes that are seen in the imported module
            members = inspect.getmembers(module_or_class, lambda x: inspect.ismodule(x) or inspect.isclass(x))
            for _, member in members:
                if member.__name__ == "object" or member.__name__ == "type":
                    continue
                # Modules and classes reference each other; without this check the same objects would be
                # queued over and over until the timeout is hit
                if id(member) in visited:
                    continue
                visited[id(member)] = member
                queue.put(member)
            if not inspect.isclass(module_or_class):
                continue
            _class = module_or_class
            if not (hasattr(_class, "_deserialize_") or hasattr(_class, "_serialize_")):
                continue
            try:
                model = pycyphal.dsdl.get_model(_class)
            except Exception:
                logger.exception("Failed to get model for %s", _class)
                continue
            # The request and response classes of a service are recognized by their class name
            is_potentially_a_service = _class.__name__ == "Request" or _class.__name__ == "Response"
            classes.append(
                Datatype(
                    hasattr(_class, "_FIXED_PORT_ID_"), model.full_name, _class, is_service=is_potentially_a_service
                )
            )
    except Empty:
        pass
    return classes
def get_datatype_return_dto(all_classes: typing.List[Datatype]) -> typing.Any:
    """
    Turn a list of data types into a dictionary made of plain values.

    The result has two entries. ``fixed_id_messages`` maps the fixed port id (as a string) to the description
    of the data type, ``variable_id_messages`` is a list of descriptions sorted by name. A description holds
    the ``short_name``, the ``name`` and the ``is_service`` flag of the data type. A data type that raises
    while it is being described is logged and left out.

    :param all_classes: The data types to describe
    :return: The dictionary described above
    """
    fixed_id_messages: typing.Dict[str, typing.Any] = {}
    variable_id_messages: typing.List[typing.Any] = []
    for datatype in all_classes:
        try:
            if datatype.is_fixed_id:
                description = {
                    "short_name": datatype.class_reference.__name__,
                    "name": datatype.name,
                    "is_service": datatype.is_service,
                }
                fixed_id_messages[str(datatype.class_reference._FIXED_PORT_ID_)] = description
            else:
                description = {
                    "short_name": datatype.class_reference.__name__,
                    "name": datatype.name,
                    "is_service": datatype.is_service,
                }
                variable_id_messages.append(description)
        except Exception as e:
            logger.error(str(e))
            logger.exception("Failed to get datatype for %s", datatype)
    # Sort the variable id messages by name
    variable_id_messages.sort(key=lambda description: description["name"])
    return {
        "fixed_id_messages": fixed_id_messages,
        "variable_id_messages": variable_id_messages,
    }
# The user will provide only primitive values, all composite types are automatically generated around them
@enum.unique
class PrimitiveFieldType(enum.Enum):
    """The kinds of primitive values a field can hold."""

    # Floating point number
    Real = 0
    # Integer without a sign
    UnsignedInteger = 1
    # Integer with a sign
    Integer = 2
    # True or false
    Boolean = 3
    # Text
    String = 4
    # Anything that is not one of the above
    Unknown = 5
def determine_primitive_field_type(datatype: pydsdl.SerializableType) -> PrimitiveFieldType:
    """
    Determine the primitive field type of a data type

    :param datatype: The data type to determine the primitive field type of
    :return: The primitive field type, PrimitiveFieldType.Unknown when the data type is not a primitive
        that is recognized here
    """
    if not isinstance(datatype, pydsdl.PrimitiveType):
        return PrimitiveFieldType.Unknown
    if isinstance(datatype, pydsdl.SignedIntegerType):
        return PrimitiveFieldType.Integer
    if isinstance(datatype, pydsdl.UnsignedIntegerType):
        return PrimitiveFieldType.UnsignedInteger
    if isinstance(datatype, pydsdl.FloatType):
        return PrimitiveFieldType.Real
    if isinstance(datatype, pydsdl.BooleanType):
        return PrimitiveFieldType.Boolean
    # NOTE(review): StringType does not seem to be part of the public pydsdl API -- confirm. It is looked up
    # defensively so that a primitive which matches none of the checks above ends up as Unknown instead of
    # raising an AttributeError on the lookup.
    string_type = getattr(pydsdl, "StringType", None)
    if string_type is not None and isinstance(datatype, string_type):
        return PrimitiveFieldType.String
    return PrimitiveFieldType.Unknown
@dataclasses.dataclass
class SimplifiedFieldDTO:
    """A flat description of one field of a data type."""

    # The dotted path of a nested field, the plain name of a top-level field
    field_name: str
    # The primitive type of the field; for an array this is the type of its elements
    field_type: PrimitiveFieldType
    # Whether the field is an array
    is_array: bool
    # The name of the field itself, without the enclosing fields
    short_name: str
    # The model name given when the description was created
    model_name: str
def get_all_fields_recursive(
    field: pydsdl.Field,
    properties: typing.List[SimplifiedFieldDTO],
    previous_components: typing.List[str],
    model_name: str,
    depth: int = 0,
) -> None:
    """
    Recursively get all fields of a composite type. Fills in the properties list.

    :param field: The field to get the fields of
    :param properties: The list of properties to append to
    :param previous_components: The names of the enclosing fields, outermost first. Together with the name of
        the field they make up the full dotted path. The list is not modified.
    :param model_name: Stored as the model_name of every property that gets appended
    :param depth: The depth of the recursion
    :return: None
    """
    try:
        # A new list is built instead of appending to the one that was passed in. The callers keep using
        # their list for the next sibling; appending to it would leave the name of this field (and of
        # everything below it) in the paths of all the siblings that follow.
        components = [*previous_components, field.name]
        parent_path = ".".join(components)
        # This is where the attribute error comes from when it's not a compound type, that's ok
        for child in field.data_type.fields:
            if isinstance(child, pydsdl.PaddingField):
                continue
            path = parent_path + "." + child.name
            if isinstance(child.data_type, pydsdl.PrimitiveType):
                properties.append(
                    SimplifiedFieldDTO(
                        path, determine_primitive_field_type(child.data_type), False, child.name, model_name
                    )
                )
            elif isinstance(child.data_type, pydsdl.ArrayType):
                child_element_type = determine_primitive_field_type(child.data_type.element_type)
                properties.append(SimplifiedFieldDTO(path, child_element_type, True, child.name, model_name))
            else:
                get_all_fields_recursive(child, properties, components, model_name, depth + 1)
    except AttributeError:
        # No longer a CompositeType, a leaf node of some other type
        pass
def get_all_field_dtos(obj: typing.Any) -> typing.List[SimplifiedFieldDTO]:
    """
    Describe the fields of a data type as a flat list.

    Primitive fields and arrays of primitives are described directly, everything else is handed to
    get_all_fields_recursive which appends to the same list.

    :param obj: The object to get the properties of, it is passed to pycyphal.dsdl.get_model
    :return: The descriptions of the fields
    """
    model = pycyphal.dsdl.get_model(obj)
    collected: typing.List[SimplifiedFieldDTO] = []
    for field in model.fields_except_padding:
        data_type = field.data_type
        if isinstance(data_type, pydsdl.PrimitiveType):
            collected.append(
                SimplifiedFieldDTO(
                    field_name=field.name,
                    field_type=determine_primitive_field_type(data_type),
                    is_array=False,
                    short_name=field.name,
                    model_name=str(model) + "." + field.name,
                )
            )
            continue
        if not isinstance(data_type, pydsdl.ArrayType):
            get_all_fields_recursive(field, collected, [], str(model))
            continue
        element_kind = determine_primitive_field_type(data_type.element_type)
        if element_kind == PrimitiveFieldType.Unknown:
            # NOTE(review): the element type is passed where get_all_fields_recursive expects a field. That
            # function reads .name and .data_type and swallows AttributeError, so this call possibly adds
            # nothing for arrays of composite types -- confirm.
            get_all_fields_recursive(data_type.element_type, collected, [field.name], str(model))
        else:
            collected.append(
                SimplifiedFieldDTO(
                    field_name=field.name,
                    field_type=element_kind,
                    is_array=True,
                    short_name=field.name,
                    model_name=str(model) + "." + field.name,
                )
            )
    return collected
# These are for calculating the tolerance for the MonotonicClusteringSynchronizer
T = TypeVar("T")


def tolerance_from_key_delta(old: T, new: T) -> T:
    """
    Return half of the difference between two keys.

    :param old: The old key, it is subtracted from the new one
    :param new: The new key
    :return: The difference multiplied by 0.5
    """
    delta = new - old  # type: ignore
    return delta * 0.5  # type: ignore
def clamp(lo_hi: tuple[T, T], val: T) -> T:
    """
    Limit a value to a range.

    :param lo_hi: The lower and the upper limit, in this order
    :param val: The value to limit
    :return: The lower limit if the value is not above it, the upper limit if that is below the value,
        the value itself otherwise
    """
    lower, upper = lo_hi
    # Same comparisons that max() and min() make, so ties resolve to the same object
    raised = val if val > lower else lower  # type: ignore
    return upper if upper < raised else raised  # type: ignore