Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inset axes #73

Merged
merged 7 commits into from
Oct 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion dabest/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,7 @@ def plot(self, color_col=None,

fig_size=None,
dpi=100,
ax=None,

swarmplot_kwargs=None,
violinplot_kwargs=None,
Expand Down Expand Up @@ -1176,6 +1177,9 @@ def plot(self, color_col=None,
The desired dimensions of the figure as a (length, width) tuple.
dpi : int, default 100
The dots per inch of the resulting figure.
ax : matplotlib.Axes, default None
Provide an existing axes for the plots to be created. If no axes
specified, a new figure will be created with the plot.
swarmplot_kwargs : dict, default None
Pass any keyword arguments accepted by the seaborn `swarmplot`
command here, as a dict. If None, the following keywords are
Expand Down Expand Up @@ -1344,4 +1348,4 @@ def dabest_obj(self):
Returns the `dabest` object that invoked the current EffectSizeDataFrame
class.
"""
return self.__dabest_obj
return self.__dabest_obj
78 changes: 60 additions & 18 deletions dabest/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):

fig_size=None,
dpi=100,
ax=None,

swarmplot_kwargs=None,
violinplot_kwargs=None,
Expand Down Expand Up @@ -256,26 +257,67 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
# sns.set(context="talk", style='ticks')
init_fig_kwargs = dict(figsize=fig_size, dpi=plot_kwargs["dpi"])

# Here, we hardcode some figure parameters.
if float_contrast is True:
fig, axx = plt.subplots(ncols=2,
gridspec_kw={"width_ratios": [2.5, 1],
"wspace": 0},
**init_fig_kwargs)
width_ratios_ga = [2.5, 1]
h_scpace_cummings = 0.3
if plot_kwargs["ax"] is not None:
ax = plot_kwargs["ax"]
fig = ax.get_figure()
ax_position = ax.get_position() # [[x0, y0], [x1, y1]]
rawdata_axes = ax
if float_contrast is True:
axins = rawdata_axes.inset_axes(
[1, 0,
width_ratios_ga[1]/width_ratios_ga[0], 1])
rawdata_axes.set_position( # [l, b, w, h]
[ax_position.x0,
ax_position.y0,
(ax_position.x1 - ax_position.x0) * (width_ratios_ga[0] /
sum(width_ratios_ga)),
(ax_position.y1 - ax_position.y0)])

contrast_axes = axins

else:
axins = rawdata_axes.inset_axes([0, -1 - h_scpace_cummings, 1, 1])
plot_height = ((ax_position.y1 - ax_position.y0) /
(2 + h_scpace_cummings))
rawdata_axes.set_position(
[ax_position.x0,
ax_position.y0 + (1 + h_scpace_cummings) * plot_height,
(ax_position.x1 - ax_position.x0),
plot_height])

# If the contrast axes are NOT floating, create lists to store
# raw ylims and raw tick intervals, so that I can normalize
# their ylims later.
contrast_ax_ylim_low = list()
contrast_ax_ylim_high = list()
contrast_ax_ylim_tickintervals = list()
contrast_axes = axins
ax.contrast_axes = axins

else:
fig, axx = plt.subplots(nrows=2,
gridspec_kw={"hspace": 0.3},
**init_fig_kwargs)

# If the contrast axes are NOT floating, create lists to store raw ylims
# and raw tick intervals, so that I can normalize their ylims later.
contrast_ax_ylim_low = list()
contrast_ax_ylim_high = list()
contrast_ax_ylim_tickintervals = list()

rawdata_axes = axx[0]
contrast_axes = axx[1]
# Here, we hardcode some figure parameters.
if float_contrast is True:
fig, axx = plt.subplots(
ncols=2,
gridspec_kw={"width_ratios": width_ratios_ga,
"wspace": 0},
**init_fig_kwargs)

else:
fig, axx = plt.subplots(nrows=2,
gridspec_kw={"hspace": 0.3},
**init_fig_kwargs)
# If the contrast axes are NOT floating, create lists to store
# raw ylims and raw tick intervals, so that I can normalize
# their ylims later.
contrast_ax_ylim_low = list()
contrast_ax_ylim_high = list()
contrast_ax_ylim_tickintervals = list()

rawdata_axes = axx[0]
contrast_axes = axx[1]

rawdata_axes.set_frame_on(False)
contrast_axes.set_frame_on(False)
Expand Down
146 changes: 146 additions & 0 deletions dabest/tests/test_04_inset_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 26 16:57:11 2019

@author: adam nekimken
"""

import pytest
import matplotlib as mpl

import numpy as np
import seaborn as sns
from .._api import load
from .utils import create_dummy_dataset, get_swarm_yspans
import matplotlib.pyplot as plt
mpl.use('Agg')


def test_gardner_altman_inset_plot():

base_mean = np.random.randint(10, 101)
seed, ptp, df = create_dummy_dataset(base_mean=base_mean)
print('\nSeed = {}; base mean = {}'.format(seed, base_mean))

for c in df.columns[1:-1]:
print('{}...'.format(c))

# Create Gardner-Altman plot with specified axes
f1, ax = plt.subplots(1)
rand_swarm_ylim = (np.random.uniform(base_mean-10, base_mean, 1),
np.random.uniform(base_mean, base_mean+10, 1))
two_group_unpaired = load(df, idx=(df.columns[0], c))
f1 = two_group_unpaired.mean_diff.plot(swarm_ylim=rand_swarm_ylim,
swarm_label="Raw swarmplot...",
contrast_label="Contrast!",
ax=ax)

rawswarm_axes = ax
contrast_axes = ax.contrast_axes

# Check ylims match the desired ones.
assert rawswarm_axes.get_ylim()[0] == pytest.approx(rand_swarm_ylim[0])
assert rawswarm_axes.get_ylim()[1] == pytest.approx(rand_swarm_ylim[1])

# Check each swarmplot group matches canonical seaborn swarmplot.
_, swarmplt = plt.subplots(1)
swarmplt.set_ylim(rand_swarm_ylim)
sns.swarmplot(data=df[[df.columns[0], c]], ax=swarmplt)
sns_yspans = []
for coll in swarmplt.collections:
sns_yspans.append(get_swarm_yspans(coll))
dabest_yspans = [get_swarm_yspans(coll)
for coll in rawswarm_axes.collections]
for j, span in enumerate(sns_yspans):
assert span == pytest.approx(dabest_yspans[j])

# Check xtick labels.
swarm_xticks = [a.get_text() for a in rawswarm_axes.get_xticklabels()]
assert swarm_xticks[0] == "{}\nN = 30".format(df.columns[0])
assert swarm_xticks[1] == "{}\nN = 30".format(c)

contrast_xticks = [a.get_text()
for a in contrast_axes.get_xticklabels()]
assert contrast_xticks[1] == "{}\nminus\n{}".format(c, df.columns[0])

# Check ylabels.
assert rawswarm_axes.get_ylabel() == "Raw swarmplot..."
assert contrast_axes.get_ylabel() == "Contrast!"





#def test_cummings_unpaired():
# base_mean = np.random.randint(-5, 5)
# seed, ptp, df = create_dummy_dataset(base_mean=base_mean, expt_groups=7)
# print('\nSeed = {}; base mean = {}'.format(seed, base_mean))
#
# IDX = (('0','5'), ('3','2'), ('4', '1', '6'))
# multi_2group_unpaired = load(df, idx=IDX)
#
# rand_swarm_ylim = (np.random.uniform(base_mean-10, base_mean, 1),
# np.random.uniform(base_mean, base_mean+10, 1))
#
# if base_mean == 0:
# # Have to set the contrast ylim, because the way I dynamically generate
# # the contrast ylims will flunk out with base_mean = 0.
# rand_contrast_ylim = (-0.5, 0.5)
# else:
# rand_contrast_ylim = (-base_mean/3, base_mean/3)
#
# f1 = multi_2group_unpaired.mean_diff.plot(swarm_ylim=rand_swarm_ylim,
# contrast_ylim=rand_contrast_ylim,
# swarm_label="Raw swarmplot!",
# contrast_label="Contrast...")
#
# rawswarm_axes = f1.axes[0]
# contrast_axes = f1.axes[1]
#
# # Check swarm ylims match the desired ones.
# assert rawswarm_axes.get_ylim()[0] == pytest.approx(rand_swarm_ylim[0])
# assert rawswarm_axes.get_ylim()[1] == pytest.approx(rand_swarm_ylim[1])
#
# # Check contrast ylims match the desired ones.
# assert contrast_axes.get_ylim()[0] == pytest.approx(rand_contrast_ylim[0])
# assert contrast_axes.get_ylim()[1] == pytest.approx(rand_contrast_ylim[1])
#
# # Check xtick labels.
# idx_flat = [g for t in IDX for g in t]
# swarm_xticks = [a.get_text() for a in rawswarm_axes.get_xticklabels()]
# for j, xtick in enumerate(swarm_xticks):
# assert xtick == "{}\nN = 30".format(idx_flat[j])
#
# contrast_xticks = [a.get_text() for a in contrast_axes.get_xticklabels()]
# assert contrast_xticks[1] == "5\nminus\n0"
# assert contrast_xticks[3] == "2\nminus\n3"
# assert contrast_xticks[5] == "1\nminus\n4"
# assert contrast_xticks[6] == "6\nminus\n4"
#
# # Check ylabels.
# assert rawswarm_axes.get_ylabel() == "Raw swarmplot!"
# assert contrast_axes.get_ylabel() == "Contrast..."
#
#
#
#
#
#def test_gardner_altman_paired():
# base_mean = np.random.randint(-5, 5)
# seed, ptp, df = create_dummy_dataset(base_mean=base_mean)
#
#
# # Check that the plot data matches the raw data.
# two_group_paired = load(df, idx=("1", "2"), id_col="idcol", paired=True)
# f1 = two_group_paired.mean_diff.plot()
# rawswarm_axes = f1.axes[0]
# contrast_axes = f1.axes[1]
# assert df['1'].tolist() == [l.get_ydata()[0] for l in rawswarm_axes.lines]
# assert df['2'].tolist() == [l.get_ydata()[1] for l in rawswarm_axes.lines]
#
#
# # Check that id_col must be specified.
# err_to_catch = "`id_col` must be specified if `is_paired` is set to True."
# with pytest.raises(IndexError, match=err_to_catch):
# this_will_not_work = load(df, idx=("1", "2"), paired=True)