Skip to content

Commit

Permalink
(draw - use matplotlib OO API rather than pyplot since fewer problems…
Browse files Browse the repository at this point in the history
… with distributing)

Ignore-this: 3ff8c4a8dce6fa01de76380422d8ec1f

darcs-hash:20130718055606-3a4db-99ad2c70fb85d02d3b18cdf91fa872b41a153d6b
  • Loading branch information
MattShannon committed Jul 18, 2013
1 parent b23f0a7 commit 411e4d1
Showing 1 changed file with 59 additions and 56 deletions.
115 changes: 59 additions & 56 deletions armspeech/speech/draw.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@
import math import math
import numpy as np import numpy as np
import armspeech.numpy_settings import armspeech.numpy_settings
import matplotlib from matplotlib.figure import Figure
import matplotlib.transforms as transforms from matplotlib import cm
# FIXME : below is necessary to avoid errors when drawing graphs when an from matplotlib import transforms
# X server is not defined, for example when running distributed jobs. from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
# This is dissatisfying, however, since stateful stuff like this should not
# be set while importing a module (or even on a per-function-call basis). # (FIXME : not sure that the matplotlib OO API is completely thread-safe.
# The best solution is probably to use the matplotlib API rather than the # It is conceivable that we should really be using explicit locks below.)
# pyplot interface, in which case you explicitly specify the backend to use.
matplotlib.use('Agg')


@codeDeps() @codeDeps()
def partitionSeq(xs, numPartitions): def partitionSeq(xs, numPartitions):
Expand All @@ -37,8 +35,6 @@ def drawLabelledSeq(dataSeqs, labelSeqs, outPdf, figSizeRate = None,
xlabel = None, ylabel = None, legend = None, xlabel = None, ylabel = None, legend = None,
lineStyles = None, lineStyles = None,
labelColors = ['red', 'purple', 'orange', 'blue']): labelColors = ['red', 'purple', 'orange', 'blue']):
import matplotlib.pyplot as plt

if xmin is None: if xmin is None:
xmin = min([ dataSeq[0][0] for dataSeq in dataSeqs ] + xmin = min([ dataSeq[0][0] for dataSeq in dataSeqs ] +
[ labelSeq[0][0] for labelSeq in labelSeqs ]) [ labelSeq[0][0] for labelSeq in labelSeqs ])
Expand All @@ -47,10 +43,9 @@ def drawLabelledSeq(dataSeqs, labelSeqs, outPdf, figSizeRate = None,
[ labelSeq[-1][1] for labelSeq in labelSeqs ]) [ labelSeq[-1][1] for labelSeq in labelSeqs ])


if figSizeRate is None: if figSizeRate is None:
fig = plt.figure() fig = Figure()
else: else:
fig = plt.figure(figsize = ((xmax - xmin) * figSizeRate, 6.0), fig = Figure(figsize = ((xmax - xmin) * figSizeRate, 6.0), dpi = 300.0)
dpi = 300.0)
ax = fig.add_subplot(1, 1, 1) ax = fig.add_subplot(1, 1, 1)


for x, y1, y2 in fillBetween: for x, y1, y2 in fillBetween:
Expand Down Expand Up @@ -82,54 +77,59 @@ def drawLabelledSeq(dataSeqs, labelSeqs, outPdf, figSizeRate = None,
ax.set_ylim(*ylims) ax.set_ylim(*ylims)


if xlabel is not None: if xlabel is not None:
plt.xlabel(xlabel) ax.set_xlabel(xlabel)
if ylabel is not None: if ylabel is not None:
plt.ylabel(ylabel) ax.set_ylabel(ylabel)
if legend is not None: if legend is not None:
ax.legend(legend) ax.legend(legend)


plt.savefig(outPdf) canvas = FigureCanvas(fig)
canvas.print_figure(outPdf)


@codeDeps() @codeDeps()
def drawWarping(transformList, outPdf, xlims, ylims = None, title = None): def drawWarping(transformList, outPdf, xlims, ylims = None, title = None):
import matplotlib.pyplot as plt

xmin, xmax = xlims xmin, xmax = xlims
plt.figure()
fig = Figure()
ax = fig.add_subplot(1, 1, 1)

xs = np.linspace(xmin, xmax, 101) xs = np.linspace(xmin, xmax, 101)
if len(transformList) == 0: if len(transformList) == 0:
print ('NOTE: no transforms being drawn for drawWarping with' print ('NOTE: no transforms being drawn for drawWarping with'
' outPdf = %s' % outPdf) ' outPdf = %s' % outPdf)
for transform in transformList: for transform in transformList:
ys = [ transform(x) for x in xs ] ys = [ transform(x) for x in xs ]
plt.plot(xs, ys, '-') ax.plot(xs, ys, '-')
if ylims is not None: if ylims is not None:
plt.ylim(*ylims) ax.set_ylim(*ylims)
if title is not None: if title is not None:
plt.title(title) ax.set_title(title)


plt.savefig(outPdf) canvas = FigureCanvas(fig)
canvas.print_figure(outPdf)


@codeDeps() @codeDeps()
def drawLogPdf(outputs, bins, outPdf, fns = [], ylims = None, title = None): def drawLogPdf(outputs, bins, outPdf, fns = [], ylims = None, title = None):
import matplotlib.pyplot as plt

outputs = np.array(outputs) outputs = np.array(outputs)
assert len(np.shape(outputs)) == 1 assert len(np.shape(outputs)) == 1
counts, bins = np.histogram(outputs, bins = bins) counts, bins = np.histogram(outputs, bins = bins)
avgPdfValues = counts * 1.0 / len(outputs) / np.diff(bins) avgPdfValues = counts * 1.0 / len(outputs) / np.diff(bins)
binCentres = bins[:-1] + 0.5 * np.diff(bins) binCentres = bins[:-1] + 0.5 * np.diff(bins)


plt.figure() fig = Figure()
plt.plot(binCentres, np.log(avgPdfValues)) ax = fig.add_subplot(1, 1, 1)

ax.plot(binCentres, np.log(avgPdfValues))
for f in fns: for f in fns:
plt.plot(binCentres, [ f(x) for x in binCentres ]) ax.plot(binCentres, [ f(x) for x in binCentres ])
if title is not None: if title is not None:
plt.title(title) ax.set_title(title)
plt.xlim(bins[0], bins[-1]) ax.set_xlim(bins[0], bins[-1])
if ylims is not None: if ylims is not None:
plt.ylim(*ylims) ax.set_ylim(*ylims)
plt.savefig(outPdf)
canvas = FigureCanvas(fig)
canvas.print_figure(outPdf)


@codeDeps(d.SynthMethod) @codeDeps(d.SynthMethod)
def drawFor1DInput(debugAcc, subDist, outPdf, xlims, ylims, title = None, def drawFor1DInput(debugAcc, subDist, outPdf, xlims, ylims, title = None,
Expand All @@ -139,9 +139,7 @@ def drawFor1DInput(debugAcc, subDist, outPdf, xlims, ylims, title = None,
For both debugAcc and subDist, input should be a 1D vector and output For both debugAcc and subDist, input should be a 1D vector and output
should be a scalar. should be a scalar.
""" """
import matplotlib.pyplot as plt def subDrawScatter(ax, inputs, outputs):

def subDrawScatter(inputs, outputs):
for input in inputs: for input in inputs:
if len(input) != 1: if len(input) != 1:
raise RuntimeError('input should be a vector of length 1,' raise RuntimeError('input should be a vector of length 1,'
Expand All @@ -155,17 +153,18 @@ def subDrawScatter(inputs, outputs):
else: else:
xs = np.array(inputs)[:, 0] xs = np.array(inputs)[:, 0]
ys = np.array(outputs) ys = np.array(outputs)
plt.plot(xs, ys, '.', markersize = 0.2) ax.plot(xs, ys, '.', markersize = 0.2)


def subDrawMeanish(subDist, xlims, nx = 50, drawPlusMinusTwoStdev = False, def subDrawMeanish(ax, subDist, xlims, nx = 50,
drawPlusMinusTwoStdev = False,
numSamples = 200): numSamples = 200):
xmin, xmax = xlims xmin, xmax = xlims
xs = np.linspace(xmin, xmax, nx + 1) xs = np.linspace(xmin, xmax, nx + 1)
ysMeanish = [ ysMeanish = [
subDist.synth(np.array([x]), method = d.SynthMethod.Meanish) subDist.synth(np.array([x]), method = d.SynthMethod.Meanish)
for x in xs for x in xs
] ]
plt.plot(xs, ysMeanish, '-') ax.plot(xs, ysMeanish, '-')
if drawPlusMinusTwoStdev: if drawPlusMinusTwoStdev:
ysSampleLow = [] ysSampleLow = []
ysSampleMiddle = [] ysSampleMiddle = []
Expand All @@ -180,37 +179,41 @@ def subDrawMeanish(subDist, xlims, nx = 50, drawPlusMinusTwoStdev = False,
ysSampleLow.append(m - 2.0 * sd) ysSampleLow.append(m - 2.0 * sd)
ysSampleMiddle.append(m) ysSampleMiddle.append(m)
ysSampleHigh.append(m + 2.0 * sd) ysSampleHigh.append(m + 2.0 * sd)
plt.plot(xs, ysSampleMiddle, '-') ax.plot(xs, ysSampleMiddle, '-')
plt.plot(xs, ysSampleLow, '-') ax.plot(xs, ysSampleLow, '-')
plt.plot(xs, ysSampleHigh, '-') ax.plot(xs, ysSampleHigh, '-')


def subDrawPdfImage(subDist, xlims, ylims, nx = 100, ny = 100): def subDrawPdfImage(ax, subDist, xlims, ylims, nx = 100, ny = 100):
xmin, xmax = xlims xmin, xmax = xlims
ymin, ymax = ylims ymin, ymax = ylims
dx = (xmax - xmin) / nx dx = (xmax - xmin) / nx
dy = (ymax - ymin) / ny dy = (ymax - ymin) / ny
pdfValues = [ [ math.exp(subDist.logProb(np.array([x]), y)) pdfValues = [ [ math.exp(subDist.logProb(np.array([x]), y))
for x in np.linspace(xmin, xmax, nx + 1) ] for x in np.linspace(xmin, xmax, nx + 1) ]
for y in np.linspace(ymin, ymax, ny + 1) ] for y in np.linspace(ymin, ymax, ny + 1) ]
plt.imshow( ax.imshow(
pdfValues, pdfValues,
extent = [xmin - 0.5 * dx, xmax - 0.5 * dx, extent = [xmin - 0.5 * dx, xmax - 0.5 * dx,
ymin - 0.5 * dy, ymax - 0.5 * dy], ymin - 0.5 * dy, ymax - 0.5 * dy],
origin = 'lower', origin = 'lower',
cmap = plt.cm.gray, cmap = cm.gray,
interpolation = 'nearest' interpolation = 'nearest'
) )


plt.figure() fig = Figure()
subDrawScatter(debugAcc.memo.inputs, debugAcc.memo.outputs) ax = fig.add_subplot(1, 1, 1)
subDrawMeanish(subDist, xlims,
subDrawScatter(ax, debugAcc.memo.inputs, debugAcc.memo.outputs)
subDrawMeanish(ax, subDist, xlims,
drawPlusMinusTwoStdev = drawPlusMinusTwoStdev) drawPlusMinusTwoStdev = drawPlusMinusTwoStdev)
subDrawPdfImage(subDist, xlims, ylims) subDrawPdfImage(ax, subDist, xlims, ylims)
plt.xlim(*xlims) ax.set_xlim(*xlims)
plt.ylim(*ylims) ax.set_ylim(*ylims)
plt.xlabel('input') ax.set_xlabel('input')
plt.ylabel('output') ax.set_ylabel('output')
plt.grid(True) ax.grid(True)
if title is not None: if title is not None:
plt.title(title) ax.set_title(title)
plt.savefig(outPdf)
canvas = FigureCanvas(fig)
canvas.print_figure(outPdf)

0 comments on commit 411e4d1

Please sign in to comment.