cleanup tracing, new spans, new attributes
aMahanna committed Dec 16, 2023
1 parent eee8128 commit 83672c5
Showing 2 changed files with 90 additions and 65 deletions.
adbpyg_adapter/adapter.py (143 changes: 79 additions & 64 deletions)
@@ -21,7 +21,12 @@
from .abc import Abstract_ADBPyG_Adapter
from .controller import ADBPyG_Controller
from .exceptions import ADBMetagraphError, InvalidADBEdgesError, PyGMetagraphError
from .tracing import TRACING_ENABLED, TracingManager, with_tracing
from .tracing import (
TRACING_ENABLED,
TracingManager,
start_as_current_span,
with_tracing,
)
from .typings import (
ADBMap,
ADBMetagraph,
@@ -717,6 +722,7 @@ def __process_adb_v_col(
:param node_data: The PyG NodeStorage object.
:type node_data: torch_geometric.data.storage.NodeStorage
"""
TracingManager.set_attributes(v_col=v_col)

# 1. Fetch ArangoDB vertices
v_col_cursor, v_col_size = self.__fetch_adb_docs(
@@ -736,8 +742,6 @@
node_data=node_data,
)

TracingManager.set_attributes(v_col=v_col, v_col_size=v_col_size)
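Moving the call to the top of the method records the `v_col` attribute on the span as soon as it opens, rather than only after the export succeeds (the derived `v_col_size` is now recorded inside `__fetch_adb_docs` instead). As a hedged sketch, a helper like `TracingManager.set_attributes` can be built on the OpenTelemetry API roughly as follows; the real implementation is not part of this diff, so the names and details here are assumptions:

# Hypothetical analog of TracingManager.set_attributes; not the adapter's code.
from typing import Any

from opentelemetry import trace

def set_attributes(**attributes: Any) -> None:
    # get_current_span() returns a no-op span when nothing is recording,
    # so this is safe to call even with tracing disabled.
    span = trace.get_current_span()
    for key, value in attributes.items():
        span.set_attribute(key, value)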

@with_tracing
def __process_adb_e_col(
self,
@@ -774,6 +778,8 @@
:param is_homogeneous: Whether the ArangoDB graph is homogeneous or not.
:type is_homogeneous: bool
"""
TracingManager.set_attributes(e_col=e_col)

# 1. Fetch ArangoDB edges
e_col_cursor, e_col_size = self.__fetch_adb_docs(
e_col, meta, **adb_export_kwargs
@@ -795,8 +801,6 @@
is_homogeneous=is_homogeneous,
)

TracingManager.set_attributes(e_col=e_col, e_col_size=e_col_size)

@with_tracing
def __fetch_adb_docs(
self,
@@ -848,14 +852,15 @@ def get_aql_return_value(
"""

col_size: int = self.__db.collection(col).count()
TracingManager.set_attributes(col=col, col_size=col_size, meta=str(meta))

with get_export_spinner_progress(f"ADB Export: '{col}' ({col_size})") as p:
p.add_task(col)

cursor: Cursor = self.__db.aql.execute(
f"FOR doc IN @@col RETURN {get_aql_return_value(meta)}",
bind_vars={"@col": col},
**{**adb_export_kwargs, **{"stream": True}},
**{**adb_export_kwargs, "stream": True},
)

return cursor, col_size
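The unpacking cleanup on the `stream` option is purely cosmetic: merging the key into the dict literal is equivalent to splatting a nested single-key dict. A quick standalone illustration (the kwargs values are made up):

# Both spellings produce the same merged dict; the new one drops a level of nesting.
adb_export_kwargs = {"ttl": 3600, "batch_size": 1000}  # illustrative values

before = {**adb_export_kwargs, **{"stream": True}}  # old spelling
after = {**adb_export_kwargs, "stream": True}       # new spelling

assert before == after == {"ttl": 3600, "batch_size": 1000, "stream": True}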
@@ -899,17 +904,15 @@ def __process_adb_cursor(
progress = get_bar_progress(f"(ADB → PyG): '{col}'", progress_color)
progress_task_id = progress.add_task(col, total=col_size)

i = 0
with Live(Group(progress)):
i = 0
while not cursor.empty():
cursor_batch = len(cursor.batch())
df = DataFrame([cursor.pop() for _ in range(cursor_batch)])
df = DataFrame(cursor.batch())
cursor.batch().clear()

i = process_adb_df(i, df, col, adb_map, meta, preserve_key, **kwargs)
progress.advance(progress_task_id, advance=len(df))

df.drop(df.index, inplace=True)

if cursor.has_more():
cursor.fetch()
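The cursor-draining rewrite stops popping documents one at a time: `DataFrame(cursor.batch())` hands the whole current batch to pandas in one call, and `cursor.batch().clear()` releases it before the next `fetch()`, which also makes the old explicit `df.drop(...)` unnecessary. The same loop in isolation against python-arango might look like this (connection details and the collection name are assumptions):

from arango import ArangoClient
from pandas import DataFrame

# Assumed connection parameters, for illustration only.
db = ArangoClient().db("_system", username="root", password="")

cursor = db.aql.execute(
    "FOR doc IN @@col RETURN doc",
    bind_vars={"@col": "accounts"},  # hypothetical collection
    stream=True,
)

while not cursor.empty():
    df = DataFrame(cursor.batch())  # one DataFrame per server-side batch
    cursor.batch().clear()          # free the batch before pulling the next one
    print(f"processed {len(df)} docs")
    if cursor.has_more():
        cursor.fetch()              # retrieve the next batch from the server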

@@ -944,6 +947,8 @@ def __process_adb_vertex_df(
:return: The last PyG Node id value.
:rtype: int
"""
TracingManager.set_attributes(i=i, vertex_df_size=len(df))

# 1. Map each ArangoDB _key to a PyG node id
for adb_key in df["_key"]:
adb_map[v_col][adb_key] = i
@@ -999,6 +1004,8 @@ def __process_adb_edge_df(
but is needed for type hinting.
:rtype: int
"""
TracingManager.set_attributes(edge_df_size=len(df))

# 1. Split the ArangoDB _from & _to IDs into two columns
df[["from_col", "from_key"]] = self.__split_adb_ids(df["_from"])
df[["to_col", "to_key"]] = self.__split_adb_ids(df["_to"])
@@ -1008,50 +1015,56 @@ def __process_adb_edge_df(
df[["from_col", "to_col"]].value_counts().items()
):
edge_type = (from_col, e_col, to_col)
edge_data: EdgeStorage = data if is_homogeneous else data[edge_type]

# 3. Check for partial Edge Collection import
if from_col not in v_cols or to_col not in v_cols:
logger.debug(f"Skipping {edge_type}")
continue

logger.debug(f"Preparing {count} {edge_type} edges")

# 4. Get the edge data corresponding to the current edge type
et_df: DataFrame = df[
(df["from_col"] == from_col) & (df["to_col"] == to_col)
]

# 5. Map each ArangoDB from/to _key to the corresponding PyG node id
from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist()
to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist()

# 6. Set/Update the PyG Edge Index
edge_index = tensor([from_nodes, to_nodes])
edge_data.edge_index = torch.cat(
(edge_data.get("edge_index", tensor([])), edge_index), dim=1
)

# 7. Deal with invalid edges
if torch.any(torch.isnan(edge_data.edge_index)):
if strict:
m = f"Invalid edges found in Edge Collection {e_col}, {from_col} -> {to_col}." # noqa: E501
raise InvalidADBEdgesError(m)
else:
# Remove the invalid edges
edge_data.edge_index = edge_data.edge_index[
:, ~torch.any(edge_data.edge_index.isnan(), dim=0)
]

# 8. Set the PyG Edge Data
self.__set_pyg_data(meta, edge_data, et_df)

# 9. Maintain the ArangoDB _key values
if preserve_key is not None:
if preserve_key not in edge_data:
edge_data[preserve_key] = []
with start_as_current_span("__process_adb_edge_type_df"):
TracingManager.set_attributes(
edge_type=edge_type,
edge_type_df_size=count,
)

edge_data[preserve_key].extend(list(et_df["_key"]))
# 3. Check for partial Edge Collection import
if from_col not in v_cols or to_col not in v_cols:
logger.debug(f"Skipping {edge_type}")
TracingManager.set_attributes(skipped=True)
continue

logger.debug(f"Preparing {count} {edge_type} edges")

# 4. Get the edge data corresponding to the current edge type
et_df: DataFrame = df[
(df["from_col"] == from_col) & (df["to_col"] == to_col)
]

# 5. Map each ArangoDB from/to _key to the corresponding PyG node id
from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist()
to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist()

# 6. Set/Update the PyG Edge Index
edge_data: EdgeStorage = data if is_homogeneous else data[edge_type]
existing_ei = edge_data.get("edge_index", tensor([]))
new_ei = tensor([from_nodes, to_nodes])
edge_data.edge_index = torch.cat((existing_ei, new_ei), dim=1)

# 7. Deal with invalid edges
if torch.any(torch.isnan(edge_data.edge_index)):
if strict:
m = f"Invalid edges found in Edge Collection {e_col}, {from_col} -> {to_col}." # noqa: E501
raise InvalidADBEdgesError(m)
else:
# Remove the invalid edges
edge_data.edge_index = edge_data.edge_index[
:, ~torch.any(edge_data.edge_index.isnan(), dim=0)
]

# 8. Set the PyG Edge Data
self.__set_pyg_data(meta, edge_data, et_df)

# 9. Maintain the ArangoDB _key values
if preserve_key is not None:
if preserve_key not in edge_data:
edge_data[preserve_key] = []

edge_data[preserve_key].extend(list(et_df["_key"]))

return 1 # Useless return value, but needed for type hinting
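Beyond the new per-edge-type child span, the edge-index handling is unchanged in substance: each batch's (from, to) pairs are concatenated onto the running `edge_index`, and any column containing NaN (an endpoint whose `_key` never mapped to a PyG node id) either raises in strict mode or is filtered out. The tensor mechanics in isolation, with made-up values:

import torch
from torch import tensor

# Running edge_index for one edge type (2 x E), plus a new batch in which
# one endpoint failed to map and therefore became NaN.
existing_ei = tensor([[0.0, 1.0], [1.0, 2.0]])
new_ei = tensor([[2.0, float("nan")], [0.0, 3.0]])

edge_index = torch.cat((existing_ei, new_ei), dim=1)

# Non-strict path: drop every column that contains a NaN endpoint.
edge_index = edge_index[:, ~torch.any(edge_index.isnan(), dim=0)]

assert edge_index.shape == (2, 3)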

@@ -1091,6 +1104,7 @@ def __set_pyg_data(
"""
valid_meta: Dict[str, ADBMetagraphValues]
valid_meta = meta if type(meta) is dict else {m: m for m in meta}
TracingManager.set_attributes(meta=str(valid_meta))

for k, v in valid_meta.items():
t = self.__build_tensor_from_dataframe(df, k, v)
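The `valid_meta` normalization accepts the metagraph value either as an explicit mapping or as a bare collection of keys, expanding the latter into an identity mapping before it is stringified onto the span. For example:

# A list of keys becomes an identity mapping; a dict passes through unchanged.
meta = ["x", "y"]
valid_meta = meta if type(meta) is dict else {m: m for m in meta}
assert valid_meta == {"x": "x", "y": "y"}

meta = {"features": "x"}
valid_meta = meta if type(meta) is dict else {m: m for m in meta}
assert valid_meta == {"features": "x"}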
@@ -1125,8 +1139,8 @@ def __build_tensor_from_dataframe(
:rtype: torch.Tensor
:raise adbpyg_adapter.exceptions.ADBMetagraphError: If invalid **meta_val**.
"""
m = f"__build_tensor_from_dataframe(df, '{meta_key}', {type(meta_val)})"
logger.debug(m)
TracingManager.set_attributes(meta_key=meta_key, meta_val=str(meta_val))
logger.debug(f"__build_tensor_from_dataframe(df, {meta_key}, {str(meta_val)})")

if type(meta_val) is str:
return tensor(adb_df[meta_val].to_list())
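In the string case, `meta_val` simply names a DataFrame column and the tensor is built directly from its values. A toy example of that branch:

from pandas import DataFrame
from torch import tensor

adb_df = DataFrame({"age": [31, 22, 45]})  # hypothetical vertex attribute
t = tensor(adb_df["age"].to_list())
assert t.tolist() == [31, 22, 45]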
@@ -1652,6 +1666,7 @@ def __set_adb_data(

valid_meta: Dict[Any, PyGMetagraphValues]
valid_meta = meta if type(meta) is dict else {m: m for m in meta}
TracingManager.set_attributes(meta=str(valid_meta))

pyg_keys = (
set(valid_meta.keys())
@@ -1664,15 +1679,16 @@
data = pyg_data[meta_key]
meta_val = valid_meta.get(meta_key, str(meta_key))

if (
type(meta_val) is str
and type(data) is list
and len(data) == pyg_data_size
):
if len(data) != pyg_data_size:
m = f"Skipping {meta_key} due to invalid length ({len(data)} != {pyg_data_size})" # noqa: E501
logger.debug(m)
continue

if type(meta_val) is str and type(data) is list:
meta_val = "_key" if meta_val in ["_v_key", "_e_key"] else meta_val
df = df.join(DataFrame(data[start_index:end_index], columns=[meta_val]))

if type(data) is Tensor and len(data) == pyg_data_size:
if type(data) is Tensor:
df = df.join(
self.__build_dataframe_from_tensor(
data[start_index:end_index],
@@ -1713,9 +1729,8 @@ def __build_dataframe_from_tensor(
:rtype: pandas.DataFrame
:raise adbpyg_adapter.exceptions.PyGMetagraphError: If invalid **meta_val**.
"""
logger.debug(
f"__build_dataframe_from_tensor(df, '{meta_key}', {type(meta_val)})"
)
TracingManager.set_attributes(meta_key=meta_key, meta_val=str(meta_val))
logger.debug(f"__build_dataframe_from_tensor(df, {meta_key}, {type(meta_val)})")

if type(meta_val) is str:
df = DataFrame(index=range(start_index, end_index), columns=[meta_val])
adbpyg_adapter/tracing.py (12 changes: 11 additions & 1 deletion)
@@ -1,5 +1,6 @@
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, List, Optional, TypeVar, cast
from typing import Any, Callable, Iterator, List, Optional, TypeVar, cast

try:
from opentelemetry import trace
@@ -54,6 +55,15 @@ def decorator(*args: Any, **kwargs: Any) -> Any:
return cast(T, decorator)


@contextmanager
def start_as_current_span(*args: Any, **kwargs: Any) -> Iterator[None]:
if tracer := TracingManager.get_tracer():
with tracer.start_as_current_span(*args, **kwargs):
yield
else:
yield
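The helper mirrors `tracer.start_as_current_span` but degrades to a bare `yield` when no tracer is configured, so call sites never have to branch on whether tracing is enabled. Usage is identical either way (the span name and workload below are illustrative):

# Opens a real child span when a tracer exists; otherwise a no-op context.
with start_as_current_span("my_operation"):
    do_work()  # hypothetical workload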


def create_tracer(
name: str,
enable_console_tracing: bool = False,
