Skip to content

Commit

Permalink
feat: tensorflow support, also toch.functional has now reduced suppor…
Browse files Browse the repository at this point in the history
…t due to some hard-to-trace edges
  • Loading branch information
MLRichter committed Jan 14, 2022
1 parent 6895bfb commit f340bf6
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 50 deletions.
39 changes: 39 additions & 0 deletions rfa_toolbox/encodings/tensorflow_keras/ingest_architecture.py
Expand Up @@ -26,6 +26,17 @@ def find_processable_node(
working_layers: List[Dict[str, Any]],
processed_nodes: Dict[str, EnrichedNetworkNode],
) -> Dict[str, Any]:
"""
Finds the first node in the list of working_layers, which is not yet processed.
Args:
working_layers: all unprocessed layers
processed_nodes: all processed nodes, the dicts maps node-ids
to their EnrichedNetworkNode-instances
Returns:
The first node in the list of working_layers, which is not yet processed.
"""
for layer in working_layers:
if "inbound_nodes" in layer:
inbound_nodes = layer["inbound_nodes"]
Expand All @@ -41,6 +52,11 @@ def find_processable_node(


def obtain_layer_definition(node_dict: Dict[str, Any]) -> LayerDefinition:
"""Obtain the layer-definition from a node-dict.
The transformations of nodes into their respective layers
is done by handler-objects, which are registered in this module in
the variable PARSERS.
"""
for parser in PARSERS:
if parser.can_handle(node=node_dict):
return parser(node_dict)
Expand All @@ -50,6 +66,14 @@ def obtain_layer_definition(node_dict: Dict[str, Any]) -> LayerDefinition:
def create_node_from_dict(
node_dict: Dict[str, Any], processed_nodes: Dict[str, EnrichedNetworkNode]
) -> EnrichedNetworkNode:
"""Create the node-representation of a layer.
Args:
node_dict: the node in dictionary representation, as extracted
from the keras-model
processed_nodes: a dictionary, which maps already processed nodes
to their EnrichedNetworkNode-instances, used for
obtaining predecessors
"""
predecessors = (
[]
if not node_dict["inbound_nodes"]
Expand All @@ -65,6 +89,7 @@ def create_node_from_dict(


def create_graph(layers: List[Dict[str, Any]]) -> EnrichedNetworkNode:
"""Create a graph of the model from a list of layers"""
processed_nodes: Dict[str, EnrichedNetworkNode] = {}
working_layers: List[Dict[str, Any]] = layers[:]
while working_layers:
Expand All @@ -82,6 +107,15 @@ def create_graph(layers: List[Dict[str, Any]]) -> EnrichedNetworkNode:


def model_dict_to_enriched_graph(model_dict: Dict[str, Any]) -> EnrichedNetworkNode:
"""Turn a dictionary extracted from the json-representation of a Keras
model into the rfa-toolbox specific graph representation.
Args:
model_dict: the json-representation of the model
Returns:
a node of the graph
"""
if "config" not in model_dict:
raise AttributeError("Model-json export has no config")
layer_config = model_dict["config"]
Expand All @@ -93,10 +127,15 @@ def model_dict_to_enriched_graph(model_dict: Dict[str, Any]) -> EnrichedNetworkN


def keras_model_to_dict(model: Model) -> Dict[str, Any]:
"""Creates a model into a dictionary based on it's json-representation"""
return loads(model.to_json())


def create_graph_from_model(model: Model) -> EnrichedNetworkNode:
"""Create a graph model from tensorflow
Args:
model: the model, thus must be a Keras-model.
"""
model_dict = keras_model_to_dict(model)
return model_dict_to_enriched_graph(model_dict)

Expand Down
55 changes: 5 additions & 50 deletions rfa_toolbox/encodings/tensorflow_keras/layer_handlers.py
Expand Up @@ -40,16 +40,7 @@ def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
@attrs(frozen=True, slots=True, auto_attribs=True)
class KernelBasedHandler(LayerInfoHandler):
def can_handle(self, node: Dict[str, Any]) -> bool:
"""Checks if this handler can process the
node in the compute graph of the model.
Args:
node: the node in question
Returns:
True if the node can be processed into a
valid LayerDefinition by this handler.
"""
"""Handles only layers featuring a kernel_size and filters"""
return "kernel_size" in node["config"] and "filters" in node["config"]

def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
Expand All @@ -71,16 +62,7 @@ def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
@attrs(frozen=True, slots=True, auto_attribs=True)
class PoolingBasedHandler(LayerInfoHandler):
def can_handle(self, node: Dict[str, Any]) -> bool:
"""Checks if this handler can process the
node in the compute graph of the model.
Args:
node: the node in question
Returns:
True if the node can be processed into a
valid LayerDefinition by this handler.
"""
"""Handles only layers featuring a pool_size"""
return "pool_size" in node["config"]

def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
Expand All @@ -101,16 +83,7 @@ def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
@attrs(frozen=True, slots=True, auto_attribs=True)
class DenseHandler(LayerInfoHandler):
def can_handle(self, node: Dict[str, Any]) -> bool:
"""Checks if this handler can process the
node in the compute graph of the model.
Args:
node: the node in question
Returns:
True if the node can be processed into a
valid LayerDefinition by this handler.
"""
"""Handles only layers feature units as attribute"""
return "units" in node["config"]

def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
Expand All @@ -123,16 +96,7 @@ def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
@attrs(frozen=True, slots=True, auto_attribs=True)
class InputHandler(LayerInfoHandler):
def can_handle(self, node: Dict[str, Any]) -> bool:
"""Checks if this handler can process the
node in the compute graph of the model.
Args:
node: the node in question
Returns:
True if the node can be processed into a
valid LayerDefinition by this handler.
"""
"""This is strictly meant for handling input nodes"""
return node["class_name"] == "InputLayer"

def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
Expand All @@ -144,16 +108,7 @@ def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
@attrs(frozen=True, slots=True, auto_attribs=True)
class AnyHandler(LayerInfoHandler):
def can_handle(self, node: Dict[str, Any]) -> bool:
"""Checks if this handler can process the
node in the compute graph of the model.
Args:
node: the node in question
Returns:
True if the node can be processed into a
valid LayerDefinition by this handler.
"""
"""This is a catch-all handler"""
return True

def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
Expand Down

0 comments on commit f340bf6

Please sign in to comment.