Merge pull request #254 from tyarkoni/graph-improvements
Graph improvements
qmac committed Feb 8, 2018
2 parents 535935d + b4eacdb commit dce9fc4
Showing 5 changed files with 293 additions and 80 deletions.
27 changes: 23 additions & 4 deletions pliers/extractors/base.py
@@ -189,7 +189,8 @@ def history(self, history):


def merge_results(results, format='wide', timing=True, metadata=True,
extractor_names=True, object_id=True, aggfunc=None):
extractor_names=True, object_id=True, aggfunc=None,
invalid_results='ignore'):
''' Merges a list of ExtractorResults instances and returns a pandas DF.
Args:
@@ -239,6 +240,13 @@ def merge_results(results, format='wide', timing=True, metadata=True,
same index. Can be a callable or any string value recognized by
pandas. By default (None), 'mean' will be used for numeric columns
and 'first' will be used for object/categorical columns.
invalid_results (str): Specifies the desired action for handling
elements of the passed results argument that are not ExtractorResult
objects. Valid values include:
- 'ignore' will ignore them and merge the valid
ExtractorResults.
- 'fail' will raise an exception on any invalid input.
Returns: a pandas DataFrame. For format details, see 'format' argument.
'''
@@ -253,9 +261,20 @@ def merge_results(results, format='wide', timing=True, metadata=True,
elif extractor_names is False:
extractor_names = 'drop'

dfs = [r.to_df(timing=_timing, metadata=metadata, format='long',
extractor_name=True, object_id=_object_id)
for r in results]
dfs = []
for r in results:
if isinstance(r, ExtractorResult):
dfs.append(r.to_df(timing=_timing, metadata=metadata,
format='long', extractor_name=True,
object_id=_object_id))
elif invalid_results == 'fail':
raise ValueError("At least one of the provided results was not an "
"ExtractorResult. Set the invalid_results "
"parameter to 'ignore' if you wish to ignore "
"this.")

if len(dfs) == 0:
return pd.DataFrame()

data = pd.concat(dfs, axis=0).reset_index(drop=True)

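For orientation, a minimal usage sketch of the new invalid_results argument (this example is not part of the commit's diff; it assumes the LengthExtractor and TextStim classes bundled with pliers are available):

from pliers.extractors import LengthExtractor, merge_results
from pliers.stimuli import TextStim

# Build a results list that deliberately contains one element (None) that is
# not an ExtractorResult, to exercise the new argument.
results = [LengthExtractor().transform(TextStim(text='hello world')), None]

# Default: the invalid element is silently skipped and the rest are merged.
df = merge_results(results, invalid_results='ignore')

# Strict mode: any non-ExtractorResult element raises a ValueError.
try:
    merge_results(results, invalid_results='fail')
except ValueError as e:
    print(e)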
207 changes: 140 additions & 67 deletions pliers/graph.py
@@ -1,7 +1,8 @@
''' The `graph` module contains tools for constructing and executing graphs
of pliers Transformers. '''

from pliers.extractors.base import Extractor, merge_results
from pliers.extractors.base import merge_results
from pliers.stimuli import __all__ as stim_list
from pliers.transformers import get_transformer
from pliers.utils import (listify, flatten, isgenerator, attempt_to_import,
verify_dependencies)
@@ -12,6 +13,7 @@
import json

pgv = attempt_to_import('pygraphviz', 'pgv')
stim_list.insert(0, 'ExtractorResult')


class Node(object):
@@ -31,6 +33,7 @@ def __init__(self, transformer, name=None, **parameters):
if isinstance(transformer, string_types):
transformer = get_transformer(transformer, **parameters)
self.transformer = transformer
self.parameters = parameters
if name is not None:
self.transformer.name = name
self.id = id(transformer)
@@ -42,6 +45,19 @@ def add_child(self, node):
def is_leaf(self):
return len(self.children)

def to_json(self):
spec = {'transformer': self.transformer.__class__.__name__}
if self.name:
spec['name'] = self.name
if self.children:
children = []
for c in self.children:
children.append(c.to_json())
spec['children'] = children
if self.parameters:
spec['parameters'] = self.parameters
return spec
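As a quick illustration (not part of the diff), a named node with no children or parameters serializes to a flat dict; the extractor name here is just an assumed example:

from pliers.graph import Node

node = Node('LengthExtractor', name='text_length')
node.to_json()
# -> {'transformer': 'LengthExtractor', 'name': 'text_length'}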


class Graph(object):
''' Graph-like structure that represents an entire pliers workflow.
@@ -66,7 +82,29 @@ def __init__(self, nodes=None, spec=None):
with open(spec) as spec_file:
self.add_nodes(json.load(spec_file)['roots'])

def add_nodes(self, nodes, parent=None):
@staticmethod
def _parse_node_args(node):
if isinstance(node, dict):
return node

kwargs = {}

if isinstance(node, (list, tuple)):
kwargs['transformer'] = node[0]
if len(node) > 1:
kwargs['children'] = node[1]
if len(node) > 2:
kwargs['name'] = node[2]
elif isinstance(node, Node):
kwargs['transformer'] = node.transformer
kwargs['children'] = node.children
kwargs['name'] = node.name
else:
kwargs['transformer'] = node

return kwargs
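To illustrate the node specifications this helper normalizes (example not in the diff; the transformer names are arbitrary strings at this stage, since resolution to actual Transformer classes happens later in add_node):

from pliers.graph import Graph

Graph._parse_node_args('STFTAudioExtractor')
# -> {'transformer': 'STFTAudioExtractor'}

Graph._parse_node_args(('FrameSamplingFilter',
                        ['GoogleVisionAPIFaceExtractor'], 'sampler'))
# -> {'transformer': 'FrameSamplingFilter',
#     'children': ['GoogleVisionAPIFaceExtractor'], 'name': 'sampler'}

Graph._parse_node_args({'transformer': 'STFTAudioExtractor', 'name': 'stft'})
# -> returned unchanged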

def add_nodes(self, nodes, parent=None, mode='horizontal'):
''' Adds one or more nodes to the current graph.
Args:
@@ -86,10 +124,32 @@ def add_nodes(self, nodes, parent=None):
parent (Node): Optional parent node (i.e., the node containing the
pliers Transformer from which the to-be-created nodes receive
their inputs).
mode (str): Indicates the direction in which to add the new nodes:
* horizontal: each node is added as a child of the 'parent'
argument (or a Graph root by default).
* vertical: the nodes are added in sequence, with the first node
being the child of the 'parent' argument (or a Graph root by
default) and each subsequent node being the child of the
previous node in the list.
'''
for n in nodes:
node_args = self._parse_node_args(n)
self.add_node(parent=parent, **node_args)
if mode == 'horizontal':
self.add_node(parent=parent, **node_args)
elif mode == 'vertical':
parent = self.add_node(parent=parent, return_node=True,
**node_args)
else:
raise ValueError("Invalid mode for adding nodes to a graph: "
"%s" % mode)

def add_chain(self, nodes, parent=None):
''' An alias for add_nodes with the mode preset to 'vertical'. '''
self.add_nodes(nodes, parent, 'vertical')

def add_children(self, nodes, parent=None):
''' An alias for add_nodes with the mode preset to 'horizontal'. '''
self.add_nodes(nodes, parent, 'horizontal')
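A rough sketch of the difference between the two modes (not part of the diff; it assumes the named transformers, plus any optional dependencies or API credentials they need, are available, since add_node instantiates each one):

from pliers.graph import Graph

# 'horizontal' (add_children): the two extractors become siblings that each
# receive the same input.
g = Graph()
g.add_children(['LengthExtractor', 'NumUniqueWordsExtractor'])

# 'vertical' (add_chain): each node feeds the next -- here a frame sampler
# followed by a face extractor.
h = Graph()
h.add_chain(['FrameSamplingFilter', 'GoogleVisionAPIFaceExtractor'])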

def add_node(self, transformer, name=None, children=None, parent=None,
parameters={}, return_node=False):
@@ -131,51 +191,6 @@ def add_node(self, transformer, name=None, children=None, parent=None,
if return_node:
return node

def draw(self, filename):
''' Render a plot of the graph via pygraphviz.
Args:
filename (str): Path to save the generated image to.
'''
verify_dependencies(['pgv'])
if not hasattr(self, '_results'):
raise RuntimeError("Graph cannot be drawn before it is executed. "
"Try calling run() first.")

g = pgv.AGraph(directed=True)
node_list = {}

for elem in self._results:
if not hasattr(elem, 'history'):
continue
log = elem.history

has_parent = True

while has_parent:

# Add nodes
source_from = log.parent[6] if log.parent else ''
s_node = hash((source_from, log[2]))
if s_node not in node_list:
g.add_node(s_node, label=log[2], shape='ellipse')

t_node = hash((log[6], log[7]))
if t_node not in node_list:
g.add_node(t_node, label=log[6], shape='box')

r_node = hash((log[6], log[5]))
if r_node not in node_list:
g.add_node(r_node, label=log[5], shape='ellipse')

# Add edges
g.add_edge(s_node, t_node)
g.add_edge(t_node, r_node)
has_parent = log.parent
log = log.parent

g.draw(filename, prog='dot')

def run(self, stim, merge=True, **merge_kwargs):
''' Executes the graph by calling all Transformers in sequence.
@@ -209,7 +224,7 @@ def run_node(self, node, stim):
node = self.nodes[node]

result = node.transformer.transform(stim)
if isinstance(node.transformer, Extractor):
if len(node.children) == 0:
return listify(result)

stim = result
@@ -219,25 +234,83 @@ def run_node(self, node, stim):
stim = list(stim)
return list(chain(*[self.run_node(c, stim) for c in node.children]))
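For context, an end-to-end sketch (not part of the diff): per the signature above, run() executes every root-to-leaf path and, with merge=True, forwards extra keyword arguments to merge_results(); the extractor names and text are placeholders:

from pliers.graph import Graph
from pliers.stimuli import TextStim

g = Graph()
g.add_children(['LengthExtractor', 'NumUniqueWordsExtractor'])

# Returns a single merged pandas DataFrame of all leaf results.
df = g.run(TextStim(text='A placeholder sentence.'),
           invalid_results='ignore')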

@staticmethod
def _parse_node_args(node):

if isinstance(node, dict):
return node

kwargs = {}

if isinstance(node, (list, tuple)):
kwargs['transformer'] = node[0]
if len(node) > 1:
kwargs['children'] = node[1]
if len(node) > 2:
kwargs['name'] = node[2]
elif isinstance(node, Node):
kwargs['transformer'] = node.transformer
kwargs['children'] = node.children
kwargs['name'] = node.name
else:
kwargs['transformer'] = node

return kwargs

def draw(self, filename, color=True):
''' Render a plot of the graph via pygraphviz.
Args:
filename (str): Path to save the generated image to.
color (bool): If True, will color graph nodes based on their type,
otherwise will draw a black-and-white graph.
'''
verify_dependencies(['pgv'])
if not hasattr(self, '_results'):
raise RuntimeError("Graph cannot be drawn before it is executed. "
"Try calling run() first.")

g = pgv.AGraph(directed=True)
g.node_attr['colorscheme'] = 'set312'

for elem in self._results:
if not hasattr(elem, 'history'):
continue
log = elem.history

while log:
# Configure nodes
source_from = log.parent[6] if log.parent else ''
s_node = hash((source_from, log[2]))
s_color = stim_list.index(log[2])
s_color = s_color % 12 + 1

t_node = hash((log[6], log[7]))
t_style = 'filled,' if color else ''
t_style += 'dotted' if log.implicit else ''
if log[6].endswith('Extractor'):
t_color = '#0082c8'
elif log[6].endswith('Filter'):
t_color = '#e6194b'
else:
t_color = '#3cb44b'

r_node = hash((log[6], log[5]))
r_color = stim_list.index(log[5])
r_color = r_color % 12 + 1

# Add nodes
if color:
g.add_node(s_node, label=log[2], shape='ellipse',
style='filled', fillcolor=s_color)
g.add_node(t_node, label=log[6], shape='box',
style=t_style, fillcolor=t_color)
g.add_node(r_node, label=log[5], shape='ellipse',
style='filled', fillcolor=r_color)
else:
g.add_node(s_node, label=log[2], shape='ellipse')
g.add_node(t_node, label=log[6], shape='box',
style=t_style)
g.add_node(r_node, label=log[5], shape='ellipse')

# Add edges
g.add_edge(s_node, t_node, style=t_style)
g.add_edge(t_node, r_node, style=t_style)
log = log.parent

g.draw(filename, prog='dot')
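A short sketch of the new drawing options (not from the diff; it requires pygraphviz, and the transformer names are illustrative). Note that draw() still has to be called after run(), since the plot is reconstructed from result histories:

from pliers.graph import Graph
from pliers.stimuli import TextStim

g = Graph()
g.add_children(['LengthExtractor', 'NumUniqueWordsExtractor'])
g.run(TextStim(text='A placeholder sentence.'))

g.draw('workflow.png')                   # colored by stim/transformer type
g.draw('workflow_bw.png', color=False)   # plain black-and-white rendering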

def to_json(self):
''' Returns the JSON representation of this graph. '''
roots = []
for r in self.roots:
roots.append(r.to_json())
return {'roots': roots}

def save(self, filename):
''' Writes the JSON representation of this graph to the provided
filename, such that the graph can be easily reconstructed using
Graph(spec=filename).
Args:
filename (str): Path at which to write out the json file.
'''
with open(filename, 'w') as outfile:
json.dump(self.to_json(), outfile)
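And a sketch of the save/reload round trip this enables (not part of the diff; the extractor name and file path are placeholders):

from pliers.graph import Graph

g = Graph()
g.add_chain(['LengthExtractor'])

g.save('length_graph.json')            # writes {'roots': [...]} to disk
g2 = Graph(spec='length_graph.json')   # reconstructs an equivalent graph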
15 changes: 9 additions & 6 deletions pliers/stimuli/base.py
@@ -166,7 +166,7 @@ def load_url(source):
return stims[0]


def _log_transformation(source, result, trans=None):
def _log_transformation(source, result, trans=None, implicit=False):

if result is None or not config.get_option('log_transformations') or \
(trans is not None and not trans._loggable):
@@ -191,12 +191,15 @@ def _log_transformation(source, result, trans=None):
string = str(parent) if parent else values[2]
string += '->%s/%s' % (values[6], values[5])
values.extend([string, parent])
values.append(implicit)
result.history = TransformationLog(*values)
return result


_trans_log = namedtuple('TransformationLog', "source_name source_file " +
"source_class result_name result_file result_class " +
" transformer_class transformer_params string parent")
" transformer_class transformer_params string " +
"parent implicit")


class TransformationLog(_trans_log):
Expand All @@ -210,9 +213,9 @@ def __str__(self):

def to_df(self):
def _append_row(rows, history):
rows.append(history[:-2])
if history[-1]:
_append_row(rows, history[-1])
rows.append(history[:-3])
if history.parent:
_append_row(rows, history.parent)
return rows
rows = _append_row([], self)[::-1]
return pd.DataFrame(rows, columns=self._fields[:-2])
return pd.DataFrame(rows, columns=self._fields[:-3])
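To see the effect of the trimmed log (example not in the diff; the extractor and stimulus are placeholders):

from pliers.extractors import LengthExtractor
from pliers.stimuli import TextStim

result = LengthExtractor().transform(TextStim(text='hello world'))

# One row per transformation step; the string, parent, and implicit
# bookkeeping fields are no longer included as columns.
result.history.to_df()
# columns: source_name, source_file, source_class, result_name, result_file,
#          result_class, transformer_class, transformer_params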
