Skip to content

Commit

Permalink
feat(swcloader): add type hints, allow exceptions to error program
Browse files Browse the repository at this point in the history
  • Loading branch information
sanjayankur31 committed Jul 3, 2024
1 parent db6d0a2 commit c7f2c11
Showing 1 changed file with 71 additions and 68 deletions.
139 changes: 71 additions & 68 deletions pyneuroml/swc/LoadSWC.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import re
import typing

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,7 +56,16 @@ class SWCNode:
GLIA_PROCESSES: "Glia Processes",
}

def __init__(self, node_id, type_id, x, y, z, radius, parent_id):
def __init__(
self,
node_id: typing.Union[str, int],
type_id: typing.Union[str, int],
x: typing.Union[str, float],
y: typing.Union[str, float],
z: typing.Union[str, float],
radius: typing.Union[str, float],
parent_id: typing.Union[str, int],
):
try:
self.id = int(node_id)
self.type = int(type_id)
Expand All @@ -64,7 +74,7 @@ def __init__(self, node_id, type_id, x, y, z, radius, parent_id):
self.z = float(z)
self.radius = float(radius)
self.parent_id = int(parent_id)
self.children = []
self.children: typing.List[SWCNode] = []
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid data types in SWC line: {e}")

Expand Down Expand Up @@ -99,13 +109,12 @@ class SWCGraph:
"SCALE",
]

def __init__(self):
self.nodes = []
self.root = None
self.metadata = {}
self.logger = logging.getLogger(__name__)
def __init__(self) -> None:
self.nodes: typing.List[SWCNode] = []
self.root: typing.Optional[SWCNode] = None
self.metadata: typing.Dict[str, str] = {}

def add_node(self, node):
def add_node(self, node: SWCNode):
"""
Add a node to the SWC graph.
Expand All @@ -114,7 +123,7 @@ def add_node(self, node):
:raises ValueError: If a node with the same ID already exists in the graph or if multiple root nodes are detected
"""
if any(existing_node.id == node.id for existing_node in self.nodes):
self.logger.error(f"Duplicate node ID: {node.id}")
logger.error(f"Duplicate node ID: {node.id}")
raise ValueError(f"Duplicate node ID: {node.id}")

if node.parent_id == -1:
Expand All @@ -123,21 +132,21 @@ def add_node(self, node):
"Attempted to add multiple root nodes. Only one root node is allowed."
)
self.root = node
self.logger.debug(f"Root node set: {node}")
logger.debug(f"Root node set: {node}")
else:
parent = next((n for n in self.nodes if n.id == node.parent_id), None)
if parent:
parent.children.append(node)
self.logger.debug(f"Node {node.id} added as child to node {parent.id}")
logger.debug(f"Node {node.id} added as child to node {parent.id}")
else:
raise ValueError(
f"Parent node {node.parent_id} not found for node {node.id}"
)

self.nodes.append(node)
self.logger.debug(f"New node added: {node}")
logger.debug(f"New node added: {node}")

def get_node(self, node_id):
def get_node(self, node_id: int) -> SWCNode:
"""
Get a node from the graph by its ID.
Expand All @@ -152,7 +161,7 @@ def get_node(self, node_id):
raise ValueError(f"Node {node_id} not found in the SWC tree")
return node

def add_metadata(self, key, value):
def add_metadata(self, key: str, value: str):
"""
Add metadata to the SWC graph.
Expand All @@ -164,17 +173,17 @@ def add_metadata(self, key, value):

if key in self.HEADER_FIELDS:
self.metadata[key] = value
self.logger.debug(f"Added metadata: {key}: {value}")
logger.debug(f"Added metadata: {key}: {value}")
else:
self.logger.warning(f"Ignoring unrecognized header field: {key}: {value}")
logger.warning(f"Ignoring unrecognized header field: {key}: {value}")

def get_parent(self, node_id):
def get_parent(self, node_id: int) -> typing.Optional[SWCNode]:
"""
Get the parent node of a given node in the SWC tree.
:param node_id: The ID of the node for which to retrieve the parent
:type node_id: int
:return: The parent Node object if the node has a parent, otherwise None
:return: The parent node if the node has a parent, otherwise None
:rtype: SWCNode or None
:raises ValueError: If the specified node_id is not found in the SWC tree
Expand All @@ -186,7 +195,7 @@ def get_parent(self, node_id):
return None
return self.get_node(node.parent_id)

def get_children(self, node_id):
def get_children(self, node_id: int) -> typing.List[SWCNode]:
"""
Get a list of child nodes for a given node.
Expand All @@ -201,7 +210,9 @@ def get_children(self, node_id):
children = [node for node in self.nodes if node.parent_id == node_id]
return children

def get_nodes_with_multiple_children(self, type_id=None):
def get_nodes_with_multiple_children(
self, type_id: typing.Optional[int] = None
) -> typing.List[SWCNode]:
"""
Get nodes with multiple children, optionally filtered by type.
Expand All @@ -225,7 +236,7 @@ def get_nodes_with_multiple_children(self, type_id=None):

return nodes

def get_nodes_by_type(self, type_id):
def get_nodes_by_type(self, type_id: int) -> typing.List[SWCNode]:
"""
Get a list of nodes of a specific type.
Expand All @@ -236,28 +247,30 @@ def get_nodes_by_type(self, type_id):
"""
return [node for node in self.nodes if node.type == type_id]

def get_branch_points(self, *types):
def get_branch_points(
self, types: typing.Optional[typing.List[int]]
) -> typing.Union[typing.List[SWCNode], typing.Dict[int, typing.List[SWCNode]]]:
"""
Get all branch points (nodes with multiple children) of the given types.
:param types: One or more node type IDs to filter the branch points by
:type types: int
:return: A dictionary of lists of SWCNode objects that represent branch points of the specified types
:rtype: dict
:return: if node types are given, a dictionary with keys as the node
type and lists of nodes as values; otherwise a list of all nodes
:rtype: list or dict
"""
branch_points = {}

if not types:
# If no types are specified, return all branch points under a None key
branch_points[None] = self.get_nodes_with_multiple_children()
# If no types are specified, return all branch points
return self.get_nodes_with_multiple_children()
else:
branch_points = {}
for type_id in types:
branch_points[type_id] = self.get_nodes_with_multiple_children(type_id)
return branch_points

return branch_points


def parse_header(line):
def parse_header(line: str) -> typing.Optional[typing.Tuple[str, str]]:
"""
Parse a header line from an SWC file.
Expand All @@ -272,55 +285,45 @@ def parse_header(line):
match = re.match(rf"{field}\s+(.+)", line, re.IGNORECASE)
if match:
return field, match.group(1).strip()
return None, None
else:
logger.warn(f"Line beginning with '#' does not match header format: {line}")
return None


def load_swc(filename):
def load_swc(filename: str) -> SWCGraph:
"""
Load an SWC file and create an SWCGraph object.
:param filename: The path to the SWC file to be loaded
:type filename: str
:return: An SWCGraph object representing the loaded SWC file
:rtype: SWCGraph
:raises FileNotFoundError: If the specified file does not exist
:raises IOError: If there's an error reading the file
:raises ValueError: If a non header line with more than the required number
of fields is found
"""

tree = SWCGraph()
try:
with open(filename, "r") as file:
for line_number, line in enumerate(file, 1):
line = line.strip()
if not line:
continue
if line.startswith("#"):
key, value = parse_header(line[1:].strip())
if key:
tree.add_metadata(key, value)
continue

parts = line.split()
if len(parts) != 7:
logger.warning(
f"Line {line_number}: Invalid number of fields. Expected 7, got {len(parts)}. Skipping line: {line}"
)
continue

try:
node_id, type_id, x, y, z, radius, parent_id = parts
node = SWCNode(node_id, type_id, x, y, z, radius, parent_id)
tree.add_node(node)
except ValueError as e:
logger.warning(
f"Line {line_number}: {str(e)}. Skipping line: {line}"
)

except FileNotFoundError:
logger.error(f"File not found: {filename}")
raise
except IOError as e:
logger.error(f"Error reading file {filename}: {str(e)}")
raise
with open(filename, "r") as file:
for line_number, line in enumerate(file, 1):
line = line.strip()
if not line:
continue
if line.startswith("#"):
header = parse_header(line[1:].strip())
if header:
tree.add_metadata(header[0], header[1])
continue

parts = line.split()
if len(parts) != 7:
raise ValueError(
f"Line {line_number}: Invalid number of fields. Expected 7, got {len(parts)}. Skipping line: {line}"
)

# the add_node bit throws errors if things don't work out as
# expected
node_id, type_id, x, y, z, radius, parent_id = parts
node = SWCNode(node_id, type_id, x, y, z, radius, parent_id)
tree.add_node(node)

return tree

0 comments on commit c7f2c11

Please sign in to comment.