Skip to content

Commit

Permalink
Merge pull request #73 from anekimken/inset_axes
Browse files Browse the repository at this point in the history
Inset axes
  • Loading branch information
josesho committed Oct 2, 2019
2 parents 51dbe11 + 5b25e46 commit 2c2d759
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 19 deletions.
6 changes: 5 additions & 1 deletion dabest/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,7 @@ def plot(self, color_col=None,

fig_size=None,
dpi=100,
ax=None,

swarmplot_kwargs=None,
violinplot_kwargs=None,
Expand Down Expand Up @@ -1176,6 +1177,9 @@ def plot(self, color_col=None,
The desired dimensions of the figure as a (length, width) tuple.
dpi : int, default 100
The dots per inch of the resulting figure.
ax : matplotlib.Axes, default None
Provide an existing axes for the plots to be created. If no axes
specified, a new figure will be created with the plot.
swarmplot_kwargs : dict, default None
Pass any keyword arguments accepted by the seaborn `swarmplot`
command here, as a dict. If None, the following keywords are
Expand Down Expand Up @@ -1344,4 +1348,4 @@ def dabest_obj(self):
Returns the `dabest` object that invoked the current EffectSizeDataFrame
class.
"""
return self.__dabest_obj
return self.__dabest_obj
78 changes: 60 additions & 18 deletions dabest/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
fig_size=None,
dpi=100,
ax=None,
swarmplot_kwargs=None,
violinplot_kwargs=None,
Expand Down Expand Up @@ -256,26 +257,67 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
# sns.set(context="talk", style='ticks')
init_fig_kwargs = dict(figsize=fig_size, dpi=plot_kwargs["dpi"])

# Here, we hardcode some figure parameters.
if float_contrast is True:
fig, axx = plt.subplots(ncols=2,
gridspec_kw={"width_ratios": [2.5, 1],
"wspace": 0},
**init_fig_kwargs)
width_ratios_ga = [2.5, 1]
h_scpace_cummings = 0.3
if plot_kwargs["ax"] is not None:
ax = plot_kwargs["ax"]
fig = ax.get_figure()
ax_position = ax.get_position() # [[x0, y0], [x1, y1]]
rawdata_axes = ax
if float_contrast is True:
axins = rawdata_axes.inset_axes(
[1, 0,
width_ratios_ga[1]/width_ratios_ga[0], 1])
rawdata_axes.set_position( # [l, b, w, h]
[ax_position.x0,
ax_position.y0,
(ax_position.x1 - ax_position.x0) * (width_ratios_ga[0] /
sum(width_ratios_ga)),
(ax_position.y1 - ax_position.y0)])

contrast_axes = axins

else:
axins = rawdata_axes.inset_axes([0, -1 - h_scpace_cummings, 1, 1])
plot_height = ((ax_position.y1 - ax_position.y0) /
(2 + h_scpace_cummings))
rawdata_axes.set_position(
[ax_position.x0,
ax_position.y0 + (1 + h_scpace_cummings) * plot_height,
(ax_position.x1 - ax_position.x0),
plot_height])

# If the contrast axes are NOT floating, create lists to store
# raw ylims and raw tick intervals, so that I can normalize
# their ylims later.
contrast_ax_ylim_low = list()
contrast_ax_ylim_high = list()
contrast_ax_ylim_tickintervals = list()
contrast_axes = axins
ax.contrast_axes = axins

else:
fig, axx = plt.subplots(nrows=2,
gridspec_kw={"hspace": 0.3},
**init_fig_kwargs)

# If the contrast axes are NOT floating, create lists to store raw ylims
# and raw tick intervals, so that I can normalize their ylims later.
contrast_ax_ylim_low = list()
contrast_ax_ylim_high = list()
contrast_ax_ylim_tickintervals = list()

rawdata_axes = axx[0]
contrast_axes = axx[1]
# Here, we hardcode some figure parameters.
if float_contrast is True:
fig, axx = plt.subplots(
ncols=2,
gridspec_kw={"width_ratios": width_ratios_ga,
"wspace": 0},
**init_fig_kwargs)

else:
fig, axx = plt.subplots(nrows=2,
gridspec_kw={"hspace": 0.3},
**init_fig_kwargs)
# If the contrast axes are NOT floating, create lists to store
# raw ylims and raw tick intervals, so that I can normalize
# their ylims later.
contrast_ax_ylim_low = list()
contrast_ax_ylim_high = list()
contrast_ax_ylim_tickintervals = list()

rawdata_axes = axx[0]
contrast_axes = axx[1]

rawdata_axes.set_frame_on(False)
contrast_axes.set_frame_on(False)
Expand Down
146 changes: 146 additions & 0 deletions dabest/tests/test_04_inset_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 26 16:57:11 2019
@author: adam nekimken
"""

import pytest
import matplotlib as mpl

import numpy as np
import seaborn as sns
from .._api import load
from .utils import create_dummy_dataset, get_swarm_yspans
import matplotlib.pyplot as plt
mpl.use('Agg')


def test_gardner_altman_inset_plot():

base_mean = np.random.randint(10, 101)
seed, ptp, df = create_dummy_dataset(base_mean=base_mean)
print('\nSeed = {}; base mean = {}'.format(seed, base_mean))

for c in df.columns[1:-1]:
print('{}...'.format(c))

# Create Gardner-Altman plot with specified axes
f1, ax = plt.subplots(1)
rand_swarm_ylim = (np.random.uniform(base_mean-10, base_mean, 1),
np.random.uniform(base_mean, base_mean+10, 1))
two_group_unpaired = load(df, idx=(df.columns[0], c))
f1 = two_group_unpaired.mean_diff.plot(swarm_ylim=rand_swarm_ylim,
swarm_label="Raw swarmplot...",
contrast_label="Contrast!",
ax=ax)

rawswarm_axes = ax
contrast_axes = ax.contrast_axes

# Check ylims match the desired ones.
assert rawswarm_axes.get_ylim()[0] == pytest.approx(rand_swarm_ylim[0])
assert rawswarm_axes.get_ylim()[1] == pytest.approx(rand_swarm_ylim[1])

# Check each swarmplot group matches canonical seaborn swarmplot.
_, swarmplt = plt.subplots(1)
swarmplt.set_ylim(rand_swarm_ylim)
sns.swarmplot(data=df[[df.columns[0], c]], ax=swarmplt)
sns_yspans = []
for coll in swarmplt.collections:
sns_yspans.append(get_swarm_yspans(coll))
dabest_yspans = [get_swarm_yspans(coll)
for coll in rawswarm_axes.collections]
for j, span in enumerate(sns_yspans):
assert span == pytest.approx(dabest_yspans[j])

# Check xtick labels.
swarm_xticks = [a.get_text() for a in rawswarm_axes.get_xticklabels()]
assert swarm_xticks[0] == "{}\nN = 30".format(df.columns[0])
assert swarm_xticks[1] == "{}\nN = 30".format(c)

contrast_xticks = [a.get_text()
for a in contrast_axes.get_xticklabels()]
assert contrast_xticks[1] == "{}\nminus\n{}".format(c, df.columns[0])

# Check ylabels.
assert rawswarm_axes.get_ylabel() == "Raw swarmplot..."
assert contrast_axes.get_ylabel() == "Contrast!"





#def test_cummings_unpaired():
# base_mean = np.random.randint(-5, 5)
# seed, ptp, df = create_dummy_dataset(base_mean=base_mean, expt_groups=7)
# print('\nSeed = {}; base mean = {}'.format(seed, base_mean))
#
# IDX = (('0','5'), ('3','2'), ('4', '1', '6'))
# multi_2group_unpaired = load(df, idx=IDX)
#
# rand_swarm_ylim = (np.random.uniform(base_mean-10, base_mean, 1),
# np.random.uniform(base_mean, base_mean+10, 1))
#
# if base_mean == 0:
# # Have to set the contrast ylim, because the way I dynamically generate
# # the contrast ylims will flunk out with base_mean = 0.
# rand_contrast_ylim = (-0.5, 0.5)
# else:
# rand_contrast_ylim = (-base_mean/3, base_mean/3)
#
# f1 = multi_2group_unpaired.mean_diff.plot(swarm_ylim=rand_swarm_ylim,
# contrast_ylim=rand_contrast_ylim,
# swarm_label="Raw swarmplot!",
# contrast_label="Contrast...")
#
# rawswarm_axes = f1.axes[0]
# contrast_axes = f1.axes[1]
#
# # Check swarm ylims match the desired ones.
# assert rawswarm_axes.get_ylim()[0] == pytest.approx(rand_swarm_ylim[0])
# assert rawswarm_axes.get_ylim()[1] == pytest.approx(rand_swarm_ylim[1])
#
# # Check contrast ylims match the desired ones.
# assert contrast_axes.get_ylim()[0] == pytest.approx(rand_contrast_ylim[0])
# assert contrast_axes.get_ylim()[1] == pytest.approx(rand_contrast_ylim[1])
#
# # Check xtick labels.
# idx_flat = [g for t in IDX for g in t]
# swarm_xticks = [a.get_text() for a in rawswarm_axes.get_xticklabels()]
# for j, xtick in enumerate(swarm_xticks):
# assert xtick == "{}\nN = 30".format(idx_flat[j])
#
# contrast_xticks = [a.get_text() for a in contrast_axes.get_xticklabels()]
# assert contrast_xticks[1] == "5\nminus\n0"
# assert contrast_xticks[3] == "2\nminus\n3"
# assert contrast_xticks[5] == "1\nminus\n4"
# assert contrast_xticks[6] == "6\nminus\n4"
#
# # Check ylabels.
# assert rawswarm_axes.get_ylabel() == "Raw swarmplot!"
# assert contrast_axes.get_ylabel() == "Contrast..."
#
#
#
#
#
#def test_gardner_altman_paired():
# base_mean = np.random.randint(-5, 5)
# seed, ptp, df = create_dummy_dataset(base_mean=base_mean)
#
#
# # Check that the plot data matches the raw data.
# two_group_paired = load(df, idx=("1", "2"), id_col="idcol", paired=True)
# f1 = two_group_paired.mean_diff.plot()
# rawswarm_axes = f1.axes[0]
# contrast_axes = f1.axes[1]
# assert df['1'].tolist() == [l.get_ydata()[0] for l in rawswarm_axes.lines]
# assert df['2'].tolist() == [l.get_ydata()[1] for l in rawswarm_axes.lines]
#
#
# # Check that id_col must be specified.
# err_to_catch = "`id_col` must be specified if `is_paired` is set to True."
# with pytest.raises(IndexError, match=err_to_catch):
# this_will_not_work = load(df, idx=("1", "2"), paired=True)

0 comments on commit 2c2d759

Please sign in to comment.