Permalink
Browse files

(draw - use matplotlib OO API rather than pyplot since fewer problems…

… with distributing)

Ignore-this: 3ff8c4a8dce6fa01de76380422d8ec1f

darcs-hash:20130718055606-3a4db-99ad2c70fb85d02d3b18cdf91fa872b41a153d6b
  • Loading branch information...
1 parent b23f0a7 commit 411e4d1dffa7cd4f4c31082eb98757f57dc3791c @MattShannon committed Jul 18, 2013
Showing with 59 additions and 56 deletions.
  1. +59 −56 armspeech/speech/draw.py
@@ -14,15 +14,13 @@
import math
import numpy as np
import armspeech.numpy_settings
-import matplotlib
-import matplotlib.transforms as transforms
-# FIXME : below is necessary to avoid errors when drawing graphs when an
-# X server is not defined, for example when running distributed jobs.
-# This is dissatisfying, however, since stateful stuff like this should not
-# be set while importing a module (or even on a per-function-call basis).
-# The best solution is probably to use the matplotlib API rather than the
-# pyplot interface, in which case you explicitly specify the backend to use.
-matplotlib.use('Agg')
+from matplotlib.figure import Figure
+from matplotlib import cm
+from matplotlib import transforms
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+
+# (FIXME : not sure that the matplotlib OO API is completely thread-safe.
+# It is conceivable that we should really be using explicit locks below.)
@codeDeps()
def partitionSeq(xs, numPartitions):
@@ -37,8 +35,6 @@ def drawLabelledSeq(dataSeqs, labelSeqs, outPdf, figSizeRate = None,
xlabel = None, ylabel = None, legend = None,
lineStyles = None,
labelColors = ['red', 'purple', 'orange', 'blue']):
- import matplotlib.pyplot as plt
-
if xmin is None:
xmin = min([ dataSeq[0][0] for dataSeq in dataSeqs ] +
[ labelSeq[0][0] for labelSeq in labelSeqs ])
@@ -47,10 +43,9 @@ def drawLabelledSeq(dataSeqs, labelSeqs, outPdf, figSizeRate = None,
[ labelSeq[-1][1] for labelSeq in labelSeqs ])
if figSizeRate is None:
- fig = plt.figure()
+ fig = Figure()
else:
- fig = plt.figure(figsize = ((xmax - xmin) * figSizeRate, 6.0),
- dpi = 300.0)
+ fig = Figure(figsize = ((xmax - xmin) * figSizeRate, 6.0), dpi = 300.0)
ax = fig.add_subplot(1, 1, 1)
for x, y1, y2 in fillBetween:
@@ -82,54 +77,59 @@ def drawLabelledSeq(dataSeqs, labelSeqs, outPdf, figSizeRate = None,
ax.set_ylim(*ylims)
if xlabel is not None:
- plt.xlabel(xlabel)
+ ax.set_xlabel(xlabel)
if ylabel is not None:
- plt.ylabel(ylabel)
+ ax.set_ylabel(ylabel)
if legend is not None:
ax.legend(legend)
- plt.savefig(outPdf)
+ canvas = FigureCanvas(fig)
+ canvas.print_figure(outPdf)
@codeDeps()
def drawWarping(transformList, outPdf, xlims, ylims = None, title = None):
- import matplotlib.pyplot as plt
-
xmin, xmax = xlims
- plt.figure()
+
+ fig = Figure()
+ ax = fig.add_subplot(1, 1, 1)
+
xs = np.linspace(xmin, xmax, 101)
if len(transformList) == 0:
print ('NOTE: no transforms being drawn for drawWarping with'
' outPdf = %s' % outPdf)
for transform in transformList:
ys = [ transform(x) for x in xs ]
- plt.plot(xs, ys, '-')
+ ax.plot(xs, ys, '-')
if ylims is not None:
- plt.ylim(*ylims)
+ ax.set_ylim(*ylims)
if title is not None:
- plt.title(title)
+ ax.set_title(title)
- plt.savefig(outPdf)
+ canvas = FigureCanvas(fig)
+ canvas.print_figure(outPdf)
@codeDeps()
def drawLogPdf(outputs, bins, outPdf, fns = [], ylims = None, title = None):
- import matplotlib.pyplot as plt
-
outputs = np.array(outputs)
assert len(np.shape(outputs)) == 1
counts, bins = np.histogram(outputs, bins = bins)
avgPdfValues = counts * 1.0 / len(outputs) / np.diff(bins)
binCentres = bins[:-1] + 0.5 * np.diff(bins)
- plt.figure()
- plt.plot(binCentres, np.log(avgPdfValues))
+ fig = Figure()
+ ax = fig.add_subplot(1, 1, 1)
+
+ ax.plot(binCentres, np.log(avgPdfValues))
for f in fns:
- plt.plot(binCentres, [ f(x) for x in binCentres ])
+ ax.plot(binCentres, [ f(x) for x in binCentres ])
if title is not None:
- plt.title(title)
- plt.xlim(bins[0], bins[-1])
+ ax.set_title(title)
+ ax.set_xlim(bins[0], bins[-1])
if ylims is not None:
- plt.ylim(*ylims)
- plt.savefig(outPdf)
+ ax.set_ylim(*ylims)
+
+ canvas = FigureCanvas(fig)
+ canvas.print_figure(outPdf)
@codeDeps(d.SynthMethod)
def drawFor1DInput(debugAcc, subDist, outPdf, xlims, ylims, title = None,
@@ -139,9 +139,7 @@ def drawFor1DInput(debugAcc, subDist, outPdf, xlims, ylims, title = None,
For both debugAcc and subDist, input should be a 1D vector and output
should be a scalar.
"""
- import matplotlib.pyplot as plt
-
- def subDrawScatter(inputs, outputs):
+ def subDrawScatter(ax, inputs, outputs):
for input in inputs:
if len(input) != 1:
raise RuntimeError('input should be a vector of length 1,'
@@ -155,17 +153,18 @@ def subDrawScatter(inputs, outputs):
else:
xs = np.array(inputs)[:, 0]
ys = np.array(outputs)
- plt.plot(xs, ys, '.', markersize = 0.2)
+ ax.plot(xs, ys, '.', markersize = 0.2)
- def subDrawMeanish(subDist, xlims, nx = 50, drawPlusMinusTwoStdev = False,
+ def subDrawMeanish(ax, subDist, xlims, nx = 50,
+ drawPlusMinusTwoStdev = False,
numSamples = 200):
xmin, xmax = xlims
xs = np.linspace(xmin, xmax, nx + 1)
ysMeanish = [
subDist.synth(np.array([x]), method = d.SynthMethod.Meanish)
for x in xs
]
- plt.plot(xs, ysMeanish, '-')
+ ax.plot(xs, ysMeanish, '-')
if drawPlusMinusTwoStdev:
ysSampleLow = []
ysSampleMiddle = []
@@ -180,37 +179,41 @@ def subDrawMeanish(subDist, xlims, nx = 50, drawPlusMinusTwoStdev = False,
ysSampleLow.append(m - 2.0 * sd)
ysSampleMiddle.append(m)
ysSampleHigh.append(m + 2.0 * sd)
- plt.plot(xs, ysSampleMiddle, '-')
- plt.plot(xs, ysSampleLow, '-')
- plt.plot(xs, ysSampleHigh, '-')
+ ax.plot(xs, ysSampleMiddle, '-')
+ ax.plot(xs, ysSampleLow, '-')
+ ax.plot(xs, ysSampleHigh, '-')
- def subDrawPdfImage(subDist, xlims, ylims, nx = 100, ny = 100):
+ def subDrawPdfImage(ax, subDist, xlims, ylims, nx = 100, ny = 100):
xmin, xmax = xlims
ymin, ymax = ylims
dx = (xmax - xmin) / nx
dy = (ymax - ymin) / ny
pdfValues = [ [ math.exp(subDist.logProb(np.array([x]), y))
for x in np.linspace(xmin, xmax, nx + 1) ]
for y in np.linspace(ymin, ymax, ny + 1) ]
- plt.imshow(
+ ax.imshow(
pdfValues,
extent = [xmin - 0.5 * dx, xmax - 0.5 * dx,
ymin - 0.5 * dy, ymax - 0.5 * dy],
origin = 'lower',
- cmap = plt.cm.gray,
+ cmap = cm.gray,
interpolation = 'nearest'
)
- plt.figure()
- subDrawScatter(debugAcc.memo.inputs, debugAcc.memo.outputs)
- subDrawMeanish(subDist, xlims,
+ fig = Figure()
+ ax = fig.add_subplot(1, 1, 1)
+
+ subDrawScatter(ax, debugAcc.memo.inputs, debugAcc.memo.outputs)
+ subDrawMeanish(ax, subDist, xlims,
drawPlusMinusTwoStdev = drawPlusMinusTwoStdev)
- subDrawPdfImage(subDist, xlims, ylims)
- plt.xlim(*xlims)
- plt.ylim(*ylims)
- plt.xlabel('input')
- plt.ylabel('output')
- plt.grid(True)
+ subDrawPdfImage(ax, subDist, xlims, ylims)
+ ax.set_xlim(*xlims)
+ ax.set_ylim(*ylims)
+ ax.set_xlabel('input')
+ ax.set_ylabel('output')
+ ax.grid(True)
if title is not None:
- plt.title(title)
- plt.savefig(outPdf)
+ ax.set_title(title)
+
+ canvas = FigureCanvas(fig)
+ canvas.print_figure(outPdf)

0 comments on commit 411e4d1

Please sign in to comment.