Skip to content

Commit

Permalink
pass positional arguments from call to create_nx_graph
Browse files Browse the repository at this point in the history
  • Loading branch information
pstjohn committed Jan 4, 2022
1 parent 0fd16c4 commit aaf81a3
Showing 1 changed file with 50 additions and 41 deletions.
91 changes: 50 additions & 41 deletions nfp/preprocessing/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import networkx as nx
import numpy as np
import tensorflow as tf

from nfp.preprocessing.tokenizer import Tokenizer

logger = logging.getLogger(__name__)
Expand All @@ -24,11 +23,11 @@ class Preprocessor(ABC):
the returned arrays
"""

def __init__(self, output_dtype: str = 'int32'):
def __init__(self, output_dtype: str = "int32"):
self.output_dtype = output_dtype

@abstractmethod
def create_nx_graph(self, structure: Any, **kwargs) -> nx.DiGraph:
def create_nx_graph(self, structure: Any, *args, **kwargs) -> nx.DiGraph:
"""Given an input structure object, convert it to a networkx digraph
with node, edge, and graph features assigned.
Expand All @@ -48,8 +47,9 @@ def create_nx_graph(self, structure: Any, **kwargs) -> nx.DiGraph:
pass

@abstractmethod
def get_edge_features(self, edge_data: list,
max_num_edges: int) -> Dict[str, np.ndarray]:
def get_edge_features(
self, edge_data: list, max_num_edges: int
) -> Dict[str, np.ndarray]:
"""Given a list of edge features from the nx.Graph, processes and
concatenates them to an array.
Expand All @@ -70,8 +70,9 @@ def get_edge_features(self, edge_data: list,
pass

@abstractmethod
def get_node_features(self, node_data: list,
max_num_nodes: int) -> Dict[str, np.ndarray]:
def get_node_features(
self, node_data: list, max_num_nodes: int
) -> Dict[str, np.ndarray]:
"""Given a list of node features from the nx.Graph, processes and
concatenates them to an array.
Expand Down Expand Up @@ -108,7 +109,9 @@ def get_graph_features(self, graph_data: dict) -> Dict[str, np.ndarray]:
pass

@staticmethod
def get_connectivity(graph: nx.DiGraph, max_num_edges: int) -> Dict[str, np.ndarray]:
def get_connectivity(
graph: nx.DiGraph, max_num_edges: int
) -> Dict[str, np.ndarray]:
"""Get the graph connectivity from the networkx graph
Parameters
Expand All @@ -125,17 +128,20 @@ def get_connectivity(graph: nx.DiGraph, max_num_edges: int) -> Dict[str, np.ndar
array of (node_index, node_index) pairs indicating the start and end
nodes for each edge.
"""
connectivity = np.zeros((max_num_edges, 2), dtype='int64')
connectivity = np.zeros((max_num_edges, 2), dtype="int64")
if len(graph.edges) > 0: # Handle odd case with no edges
connectivity[:len(graph.edges)] = np.asarray(graph.edges)
return {'connectivity': connectivity}

def __call__(self,
structure: Any,
train: bool = False,
max_num_nodes: Optional[int] = None,
max_num_edges: Optional[int] = None,
**kwargs) -> Dict[str, np.ndarray]:
connectivity[: len(graph.edges)] = np.asarray(graph.edges)
return {"connectivity": connectivity}

def __call__(
self,
structure: Any,
*args,
train: bool = False,
max_num_nodes: Optional[int] = None,
max_num_edges: Optional[int] = None,
**kwargs,
) -> Dict[str, np.ndarray]:
"""Convert an input graph structure into a featurized set of node, edge,
and graph-level features.
Expand All @@ -159,13 +165,17 @@ def __call__(self,
Dict[str, np.ndarray]
A dictionary of key, array pairs as a single sample.
"""
nx_graph = self.create_nx_graph(structure, **kwargs)
nx_graph = self.create_nx_graph(structure, *args, **kwargs)

max_num_edges = len(nx_graph.edges) if max_num_edges is None else max_num_edges
assert len(nx_graph.edges) <= max_num_edges, "max_num_edges too small for given input"
assert (
len(nx_graph.edges) <= max_num_edges
), "max_num_edges too small for given input"

max_num_nodes = len(nx_graph.nodes) if max_num_nodes is None else max_num_nodes
assert len(nx_graph.nodes) <= max_num_nodes, "max_num_nodes too small for given input"
assert (
len(nx_graph.nodes) <= max_num_nodes
), "max_num_nodes too small for given input"

# Make sure that Tokenizer classes are correctly initialized
for _, tokenizer in getmembers(self, lambda x: type(x) == Tokenizer):
Expand All @@ -176,35 +186,31 @@ def __call__(self,
graph_features = self.get_graph_features(nx_graph.graph)
connectivity = self.get_connectivity(nx_graph, max_num_edges)

return {
**node_features,
**edge_features,
**graph_features,
**connectivity
}

def construct_feature_matrices(self,
*args,
train=False,
**kwargs) -> Dict[str, np.ndarray]:
return {**node_features, **edge_features, **graph_features, **connectivity}

def construct_feature_matrices(
self, *args, train=False, **kwargs
) -> Dict[str, np.ndarray]:
"""
.. deprecated:: 0.3.0
`construct_feature_matrices` will be removed in 0.4.0, use
`__call__` instead
"""
warnings.warn(
"construct_feature_matrices is deprecated, use `call` instead as "
"of nfp 0.4.0", DeprecationWarning)
"of nfp 0.4.0",
DeprecationWarning,
)
return self(*args, train=train, **kwargs)

def to_json(self, filename: str) -> None:
"""Serialize the class's data to a JSON file"""
with open(filename, 'w') as f:
with open(filename, "w") as f:
json.dump(self, f, default=lambda x: x.__dict__)

def from_json(self, filename: str) -> None:
"""Sets the class's data with attributes taken from the save file"""
with open(filename, 'r') as f:
with open(filename, "r") as f:
json_data = json.load(f)
load_from_json(self, json_data)

Expand Down Expand Up @@ -232,12 +238,14 @@ def create_nx_graph(self, structure: Any, **kwargs) -> nx.MultiDiGraph:
pass

@staticmethod
def get_connectivity(graph: nx.DiGraph, max_num_edges: int) -> Dict[str, np.ndarray]:
def get_connectivity(
graph: nx.DiGraph, max_num_edges: int
) -> Dict[str, np.ndarray]:
# Don't include keys in the connectivity matrix
connectivity = np.zeros((max_num_edges, 2), dtype='int64')
connectivity = np.zeros((max_num_edges, 2), dtype="int64")
if len(graph.edges) > 0: # Handle odd case with no edges
connectivity[:len(graph.edges)] = np.asarray(graph.edges)[:, :2]
return {'connectivity': connectivity}
connectivity[: len(graph.edges)] = np.asarray(graph.edges)[:, :2]
return {"connectivity": connectivity}


def load_from_json(obj: Any, data: Dict):
Expand All @@ -260,10 +268,11 @@ def load_from_json(obj: Any, data: Dict):
try:
if isinstance(val, type(data[key])):
obj.__dict__[key] = data[key]
elif hasattr(val, '__dict__'):
elif hasattr(val, "__dict__"):
load_from_json(val, data[key])

except KeyError:
logger.warning(
f"{key} not found in JSON file, it may have been created with"
" an older nfp version")
" an older nfp version"
)

0 comments on commit aaf81a3

Please sign in to comment.