Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph plotting #130

Merged
merged 3 commits into from Feb 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 4 additions & 3 deletions .travis.yml
Expand Up @@ -23,8 +23,8 @@ matrix:
before_install:
- sudo add-apt-repository ppa:joyard-nicolas/ffmpeg -y
- sudo apt-get -qq update
- sudo apt-get install ffmpeg
- sudo apt-get install tesseract-ocr
- sudo apt-get install ffmpeg tesseract-ocr graphviz


- export MINICONDA=$HOME/miniconda
- export PATH="$MINICONDA/bin:$PATH"
Expand All @@ -42,7 +42,8 @@ before_install:
install:
- pip install -r requirements.txt
- pip install --upgrade --ignore-installed setuptools
- pip install --upgrade coveralls pytest-cov pysrt xlrd clarifai seaborn pytesseract matplotlib SpeechRecognition IndicoIo tensorflow
- pip install --upgrade coveralls pytest-cov pysrt xlrd clarifai seaborn pytesseract
- pip install --upgrade matplotlib SpeechRecognition IndicoIo tensorflow pygraphviz
- pip install 'numpy<1.12.0'

before_script:
Expand Down
1 change: 1 addition & 0 deletions optional-dependencies.txt
Expand Up @@ -2,6 +2,7 @@ clarifai
cv2
google-api-python-client
matplotlib
pygraphviz
pysrt
seaborn
SpeechRecognition>=3.6.0
Expand Down
2 changes: 1 addition & 1 deletion pliers/extractors/audio.py
Expand Up @@ -51,7 +51,7 @@ def _stft(self, stim):
w = np.hanning(framesamp)
X = np.array([fft(w*x[i:(i+framesamp)])
for i in range(0, len(x)-framesamp, hopsamp)])
nyquist_lim = X.shape[1]//2
nyquist_lim = int(X.shape[1]//2)
X = np.log(X[:, :nyquist_lim])
X = np.absolute(X)
if self.spectrogram:
Expand Down
53 changes: 53 additions & 0 deletions pliers/graph.py
Expand Up @@ -6,6 +6,11 @@
from six import string_types
from collections import OrderedDict

try:
import pygraphviz as pgv
except:
pgv = None


class Node(object):

Expand Down Expand Up @@ -66,9 +71,57 @@ def add_node(self, transformer, name=None, children=None, parent=None,
if return_node:
return node

def draw(self, filename):
''' Render a plot of the graph via pygraphviz.
Args:
filename (str): Path to save the generated image to.
'''
if pgv is None:
raise ImportError("pygraphviz is required in order to plot graphs,"
" but could not be successfully imported. Please"
" make sure it is installed.")
if not hasattr(self, '_results'):
raise RuntimeError("Graph cannot be drawn before it is executed. "
"Try calling run() first.")

g = pgv.AGraph(directed=True)
node_list = {}

for elem in self._results:
if not hasattr(elem, 'history'):
continue
log = elem.history

has_parent = True

while has_parent:

# Add nodes
source_from = log.parent[6] if log.parent else ''
s_node = hash((source_from, log[2]))
if s_node not in node_list:
g.add_node(s_node, label=log[2], shape='ellipse')

t_node = hash((log[6], log[7]))
if t_node not in node_list:
g.add_node(t_node, label=log[6], shape='box')

r_node = hash((log[6], log[5]))
if r_node not in node_list:
g.add_node(r_node, label=log[5], shape='ellipse')

# Add edges
g.add_edge(s_node, t_node)
g.add_edge(t_node, r_node)
has_parent = log.parent
log = log.parent

g.draw(filename, prog='dot')

def run(self, stim, merge=True):
results = list(chain(*[self.run_node(n, stim) for n in self.roots]))
results = list(flatten(results))
self._results = results # For use in plotting
return merge_results(results) if merge else results

def run_node(self, node, stim):
Expand Down
1 change: 1 addition & 0 deletions pliers/stimuli/base.py
Expand Up @@ -141,6 +141,7 @@ def _log_transformation(source, result, trans=None):
result.history = TransformationLog(*values)
return result


class TransformationLog(namedtuple('TransformationLog', "source_name source_file " +
"source_class result_name result_file result_class " +
" transformer_class transformer_params string parent")):
Expand Down
10 changes: 9 additions & 1 deletion pliers/tests/test_graph.py
Expand Up @@ -6,9 +6,11 @@
LengthExtractor, merge_results)
from pliers.stimuli import (ImageStim, VideoStim)
from .utils import get_test_data_path, DummyExtractor
from os.path import join
from os.path import join, exists
import numpy as np
from numpy.testing import assert_almost_equal
import tempfile
import os


def test_node_init():
Expand Down Expand Up @@ -84,6 +86,7 @@ def test_small_pipeline():

@pytest.mark.skipif("'WIT_AI_API_KEY' not in os.environ")
def test_big_pipeline():
pytest.importorskip('pygraphviz')
filename = join(get_test_data_path(), 'video', 'obama_speech.mp4')
video = VideoStim(filename)
visual_nodes = [(FrameSamplingConverter(every=15), [
Expand All @@ -97,6 +100,11 @@ def test_big_pipeline():
graph.add_nodes(visual_nodes)
graph.add_nodes(audio_nodes)
result = graph.run(video)
# Test that pygraphviz outputs a file
drawfile = next(tempfile._get_candidate_names())
graph.draw(drawfile)
assert exists(drawfile)
os.remove(drawfile)
assert ('LengthExtractor', 'text_length') in result.columns
assert ('VibranceExtractor', 'vibrance') in result.columns
# assert not result[('onset', '')].isnull().any()
Expand Down