Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions aperturedb/Query.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get_specific(obj: BaseModel) -> dict:
RangeType.FRAME: "frame_number_range",
RangeType.FRACTION: "time_fraction_range"
}
start, stop = obj.start, obj.stop
start, stop = obj.start, obj.stop
if obj.range_type == RangeType.TIME:
start, stop = int(start), int(stop)
start = "{:0>2}:{:0>2}:{:0>2}".format(
Expand All @@ -94,7 +94,7 @@ def get_specific(obj: BaseModel) -> dict:
elif obj.range_type == RangeType.FRAME:
start = int(obj.start)
stop = int(obj.stop)
return{
return {
range_types[obj.range_type]: {
"start": start,
"stop": stop,
Expand Down Expand Up @@ -377,10 +377,10 @@ def spec(cls,
operations=operations,
with_class=with_class,
limit=limit,
sort = sort,
list = list,
sort=sort,
list=list,
blobs=blobs,
group_by_src = group_by_src,
group_by_src=group_by_src,
set=set,
vector=vector,
k_neighbors=k_neighbors
Expand Down
23 changes: 19 additions & 4 deletions aperturedb/Utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def visualize_schema(self, filename: str = None, format: str = "png") -> Source:
f'{idx_str}, {typ}</FONT></TD></TR>'
)
for connection, data in connections.items():
data_list = [data] if isinstance(data, dict) else data
data_list = self._normalize_class_data(data)
for data in data_list:
if data['src'] == entity:
matched = data["matched"]
Expand Down Expand Up @@ -226,7 +226,7 @@ def visualize_schema(self, filename: str = None, format: str = "png") -> Source:

if isinstance(connections, dict):
for connection, data in connections.items():
data_list = [data] if isinstance(data, dict) else data
data_list = self._normalize_class_data(data)
for data in data_list:
dot.edge(f'{data["src"]}:{connection}',
f'{data["dst"]}')
Expand Down Expand Up @@ -273,6 +273,21 @@ def _object_summary(self, name, object):

return total_elements

@staticmethod
def _normalize_class_data(data):
"""
Normalize class data returned from GetSchema.
ApertureDB can return connections as a dict where the keys are connection names
and values are the dicts we actually want, or as a single dict with "matched", etc,
or as a list. We normalize it to a list of dicts.
"""
if isinstance(data, dict):
if "matched" in data:
return [data]
else:
return list(data.values())
return data if isinstance(data, list) else [data]

def summary(self):
"""
Print a summary of the database.
Expand Down Expand Up @@ -315,8 +330,8 @@ def summary(self):
total_edges = 0
for c in connections_classes:
connections = r["connections"]["classes"][c]
connections_list = [connections] if isinstance(
connections, dict) else connections

connections_list = self._normalize_class_data(connections)

for connection in connections_list:
total_edges += self._object_summary(c, connection)
Expand Down
4 changes: 2 additions & 2 deletions test/test_Images.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def test_resolve_rotate():
operations = [{"type": "rotate", "angle": 90}]
resolved = resolve(points, meta, operations)
assert len(resolved) == 1
# Note: 9 instead of 10 due to float truncation in .astype(int)
assert resolved[0][0] == 90 and resolved[0][1] == 9
# Allow 9 or 10 due to float truncation/rounding differences across platforms
assert resolved[0][0] == 90 and abs(resolved[0][1] - 10) <= 1


class MockClient:
Expand Down
67 changes: 67 additions & 0 deletions test/test_Utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from unittest.mock import patch, MagicMock
import json
from aperturedb.Utils import Utils


class TestUtils():

def test_remove_all_objects(self, utils):
Expand All @@ -10,3 +15,65 @@ def test_remove_all_indexes(self, utils):

def test_get_descriptorset_list(self, utils):
assert utils.get_descriptorset_list() == []


class TestUtilsSummaryNormalization():

def test_summary_normalization(self):
# We don't use the 'utils' fixture because it requires a live DB connection
mock_connector = MagicMock()
utils = Utils(mock_connector)

mock_schema = {
"entities": {
"returned": 1,
"classes": {
"Person": {
"matched": 10,
"properties": {
"name": [10, True, "string"]
}
}
}
},
"connections": {
"returned": 3,
"classes": {
"Knows": {
"matched": 5,
"properties": {},
"src": "Person",
"dst": "Person"
},
"Likes": {
"Likes_1": {
"matched": 3,
"properties": {},
"src": "Person",
"dst": "Movie"
},
"Likes_2": {
"matched": 4,
"properties": {},
"src": "Person",
"dst": "Book"
}
},
"Owns": [
{
"matched": 2,
"properties": {},
"src": "Person",
"dst": "Car"
}
]
}
}
}
mock_status = json.dumps(
[{"GetStatus": {"version": "1.0", "status": "OK", "info": ""}}])

with patch.object(utils, 'get_schema', return_value=mock_schema), \
patch.object(utils, 'status', return_value=mock_status):
# should not raise
utils.summary()
Loading