Skip to content
This repository has been archived by the owner on Jul 12, 2022. It is now read-only.

Commit

Permalink
metrics: Colormap on arrow-graphs use direction (instead of Y-sign).
Browse files Browse the repository at this point in the history
  • Loading branch information
ankostis committed Jan 16, 2015
1 parent cd6057c commit 345ec63
Showing 1 changed file with 23 additions and 35 deletions.
58 changes: 23 additions & 35 deletions wltp/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from numpy import polyfit, polyval
from wltp import model

import math
import numpy as np


Expand All @@ -29,7 +30,7 @@ class MidPointNorm(Normalize):
imshow(X, norm=norm)
"""
def __init__(self, midpoint=0, vmin=None, vmax=None, clip=False):
Normalize.__init__(self,vmin, vmax, clip)
Normalize.__init__(self, vmin, vmax, clip)
self.midpoint = midpoint

def __call__(self, value, clip=None):
Expand All @@ -42,11 +43,11 @@ def __call__(self, value, clip=None):
vmin, vmax, midpoint = self.vmin, self.vmax, self.midpoint

if not (vmin < midpoint < vmax):
raise ValueError("midpoint must be between maxvalue and minvalue.")
raise ValueError("Midpoint(%s) must be between minvalue(%s) and maxvalue(%s)!"%(midpoint, vmin, vmax))
elif vmin == vmax:
result.fill(0) # Or should it be all masked? Or 0.5?
elif vmin > vmax:
raise ValueError("maxvalue must be bigger than minvalue")
raise ValueError("Maxvalue(%s) must be bigger than minvalue(%s)!"%(vmin, vmax))
else:
vmin = float(vmin)
vmax = float(vmax)
Expand Down Expand Up @@ -139,15 +140,21 @@ def plot_class_limits(axis, y):
# bbox=bbox, size=8)


def calc_2D_diff_on_Y(X1, Y1, X2, Y2):
## From http://stackoverflow.com/questions/20924085/python-conversion-between-coordinates
#
def cart2pol(x, y):
rho = np.sqrt(x**2 + y**2)
phi = np.arctan2(y, x)
return(rho, phi)

def cartesians_to_polarDiffs(X1, Y1, X2, Y2):
"""
Given 2 sets of 2D-points calcs the euclidean distance from 2nd to 1st with sign based on the Y axis.
Given 2 sets of 2D-points calcs the polar euclidean-distance and angle from 2nd-pair of points to 1st-pair.
"""
U = X2 - X1
V = Y2 - Y1
DIFF = np.sqrt(U ** 2 + V ** 2)
DIFF[V < 0] = -DIFF[V < 0]
return U, V, DIFF
DIFF, ANGLE = cart2pol(U, V)
return U, V, DIFF, ANGLE


#############
Expand All @@ -161,7 +168,7 @@ def plot_xy_diffs_scatter(X, Y, X_REF, Y_REF, ref_label, data_label, diff_label=
color_diff = 'r'
alpha = 0.8

_, _, DIFF = calc_2D_diff_on_Y(X_REF, Y_REF, X, Y)
_, _, DIFF, _ = cartesians_to_polarDiffs(X_REF, Y_REF, X, Y)
if axes_tuple:
(axes, twin_axis) = axes_tuple
else:
Expand Down Expand Up @@ -204,23 +211,22 @@ def plot_xy_diffs_scatter(X, Y, X_REF, Y_REF, ref_label, data_label, diff_label=

def plot_xy_diffs_arrows(X, Y, X_REF, Y_REF, data_label, ref_label=None,
data_fmt="+k", data_kws={},
diff_label=None, diff_fmt="-r", diff_cmap=cm.PiYG, diff_kws={}, #@UndefinedVariable
diff_label=None, diff_fmt="-r", diff_cmap=cm.hsv, diff_kws={}, #@UndefinedVariable cm.PiYG
title=None, x_label=None, y_label=None,
axes_tuple=None,
mark_sections=None):
color_diff = 'r'
alpha = 0.9
cm_norm = MidPointNorm()

U, V, DIFF = calc_2D_diff_on_Y(X_REF, Y_REF, X, Y)
U, V, DIFF, ANGLE = cartesians_to_polarDiffs(X_REF, Y_REF, X, Y)

if axes_tuple:
(axes, twin_axis, cbar_axes) = axes_tuple
(axes, twin_axis) = axes_tuple
else:
bottom = 0.1
height = 0.8
axes = plt.axes([0.1, bottom, 0.80, height])
cbar_axes = plt.axes([0.90, bottom, 0.12, height])

## Prepare axes
#
Expand All @@ -235,7 +241,7 @@ def plot_xy_diffs_arrows(X, Y, X_REF, Y_REF, data_label, ref_label=None,
twin_axis.yaxis.grid(True, color=color_diff)

plt.title(title, axes=axes)
axes_tuple = (axes, twin_axis, cbar_axes)
axes_tuple = (axes, twin_axis)


if mark_sections == 'classes':
Expand All @@ -246,7 +252,7 @@ def plot_xy_diffs_arrows(X, Y, X_REF, Y_REF, data_label, ref_label=None,
## Plot data
#
l_ref = axes.quiver(X, Y, U, V,
DIFF, cmap=diff_cmap, norm=cm_norm,
ANGLE, cmap=diff_cmap, norm=cm_norm,
scale_units='xy', angles='xy', scale=1,
width=0.004, alpha=alpha,
pivot='tip'
Expand All @@ -255,29 +261,11 @@ def plot_xy_diffs_arrows(X, Y, X_REF, Y_REF, data_label, ref_label=None,
l_data, = axes.plot(X, Y, data_fmt, label=data_label, **data_kws)
l_data.set_picker(3)

l_diff = twin_axis.plot(X, DIFF, '.', color=color_diff, markersize=0.7)
line_points, regress_poly = fit_straight_line(X, DIFF)
l_diff = twin_axis.plot(X, V, '.', color=color_diff, markersize=0.7)
line_points, regress_poly = fit_straight_line(X, V)
l_diff_fitted, = twin_axis.plot(line_points, polyval(regress_poly, line_points), diff_fmt,
label=diff_label, **diff_kws)

## Colormap legend
#
min_DIFF = DIFF.min()
max_DIFF = DIFF.max()
nsamples = 20
m = np.linspace(min_DIFF, max_DIFF, nsamples)
m.resize((nsamples, 1))
extent = max_DIFF - min_DIFF
cbar_axes.imshow(m, cmap=diff_cmap, norm=cm_norm, aspect=600/extent, origin="lower",
extent=[0, 12, min_DIFF, max_DIFF])
cbar_axes.xaxis.set_visible(False)
cbar_axes.yaxis.set_ticks_position('left')
cbar_axes.tick_params(axis='y', colors=color_diff, direction='inout', pad=0, labelsize=10)

twin_axis.set_ylim([min_DIFF, max_DIFF]) ## Sync diff axes.
###twin_axis.set_axis_off() ## NOTE: Hiding axis, hides grids!
twin_axis.yaxis.set_ticklabels([]) ## NOTE: Turn it on, to check axes are indeed synced

return axes_tuple, (l_data, l_ref, l_diff, l_diff_fitted)

## http://stackoverflow.com/questions/13306519/get-data-from-plot-with-matplotlib
Expand Down

0 comments on commit 345ec63

Please sign in to comment.