Align ticks in at-risk count table #747

Merged
merged 10 commits into from Jun 19, 2019

Conversation

christopherahern
Contributor

Currently, add_at_risk_counts generates a set of xticks that are added to a separate axes object below the original axes. In some cases the ticks are misaligned because of how the xlim and xticks of the axes object passed in are used to generate the new axes object.

The documentation has a good example of this, which can be reproduced with the code below.

from lifelines.datasets import load_waltons
from lifelines import KaplanMeierFitter
from lifelines.plotting import add_at_risk_counts
import matplotlib.pyplot as plt

waltons = load_waltons()
T = waltons['T']
E = waltons['E']

ix = waltons['group'] == 'control'
kmf_control = KaplanMeierFitter()
kmf_control.fit(waltons.loc[ix]['T'], waltons.loc[ix]['E'], label='control')
kmf_exp = KaplanMeierFitter()
kmf_exp.fit(waltons.loc[~ix]['T'], waltons.loc[~ix]['E'], label='exp')

ax = plt.subplot(111)
ax = kmf_control.plot(ax=ax)
ax = kmf_exp.plot(ax=ax)
add_at_risk_counts(kmf_exp, kmf_control, ax=ax)

plt.show()

The at-risk counts have one more tick than the plot. This can be fixed by re-ordering how the xlim and xticks objects are created.

This MR does three things:

  1. Re-orders the creation of the axes attributes so that there are the same number of plot ticks and counts, and they line up
  2. Adds padding to the stratum label on the first tick so that labels are all right-aligned to the same point
  3. Allows any kwargs to be passed in to add_at_risk_counts that can be passed into the set_xticklabels method (e.g. fontsize)
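The padding in item 2 can be illustrated with a minimal sketch. It assumes the approach is to pad every count to the width of the largest count, so the label and count on the first tick end at the same column as the counts below other ticks. The helper name first_tick_label and the sample counts are hypothetical, not part of lifelines.

```python
def first_tick_label(labels, counts):
    """Build the multi-line label for the first tick (hypothetical helper).

    Each line is "<stratum label>   <count>", with the count right-aligned
    to the width of the largest count.
    """
    width = len(str(max(counts)))
    lines = [
        "{}   {:>{}d}".format(label, count, width)
        for label, count in zip(labels, counts)
    ]
    return "\n".join(lines)


# Made-up counts, for illustration only.
label_text = first_tick_label(["exp", "control"], [7, 120])
print(label_text)
```

Because both counts occupy three characters, the digits stay in the same column when the text is right-aligned.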

Before: [image: master]

After: [image: Updated]

@CamDavidsonPilon
Owner

This is pretty darn exciting! I'll review and test it locally now

@CamDavidsonPilon
Owner

Great work, I appreciate the detailed PR post as well!

In diagnosing this, I've turned up a few other bugs too, so thanks for drawing my attention to this!

@CamDavidsonPilon
Owner

Actually, I did find one bug. If you change the KaplanMeierFitter to WeibullFitter, the bottom labels no longer show up. Example:

from lifelines.datasets import load_waltons
from lifelines import KaplanMeierFitter, WeibullFitter
from lifelines.plotting import add_at_risk_counts
import matplotlib.pyplot as plt

waltons = load_waltons()
T = waltons['T']
E = waltons['E']

ix = waltons['group'] == 'control'
kmf_control = WeibullFitter()
kmf_control.fit(waltons.loc[ix]['T'], waltons.loc[ix]['E'], label='control')
kmf_exp = WeibullFitter()
kmf_exp.fit(waltons.loc[~ix]['T'], waltons.loc[~ix]['E'], label='exp')

ax = plt.subplot(111)
ax = kmf_control.plot_survival_function(ax=ax)
ax = kmf_exp.plot_survival_function(ax=ax)
add_at_risk_counts(kmf_exp, kmf_control, ax=ax)

plt.show()

master: [screenshot]

branch: [screenshot]

@CamDavidsonPilon
Owner

CamDavidsonPilon commented Jun 18, 2019

From what I can tell, it has to do with the ax.get_xlim() being different between the two.

Edit: the problem is that we write the labels on the first tick, but that tick is located at 0, and if the lower bound of xlim is greater than 0, the tick is never shown. A solution is to skip ticks that will never be shown and write the labels on the first shown tick (the first one not below the lower bound). Example code that works:

    # Add population size at times
    ticklabels = []
    has_written_labels = False

    for tick in ax2.get_xticks():
        lbl = ""

        if tick < ax2.get_xlim()[0]:
            ticklabels.append(lbl)
            continue

        # Get counts at tick
        counts = [f.durations[f.durations >= tick].shape[0] for f in fitters]
        # Create tick label
        for l, c in zip(labels, counts):
            # First shown tick is prepended with the label
            if not has_written_labels and l is not None:
                # Get length of largest count
                max_length = len(str(max(counts)))
                if is_latex_enabled():
                    s = "\n{}\\quad".format(l) + "{{:>{}d}}".format(max_length)
                else:
                    s = "\n{}   ".format(l) + "{{:>{}d}}".format(max_length)

            else:
                s = "\n{}"
            lbl += s.format(c)

        has_written_labels = True
        ticklabels.append(lbl.strip())
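The counting step in the snippet above (subjects whose duration is at least the tick value) can be tried in isolation. This is a minimal sketch using plain lists instead of fitter objects. The helper name at_risk_count and the durations are made up for illustration.

```python
def at_risk_count(durations, tick):
    """Number of subjects whose observed duration is at least `tick`."""
    return sum(1 for d in durations if d >= tick)


# Made-up durations standing in for each fitter's `durations` attribute.
durations_control = [6, 13, 13, 19, 26, 33, 41, 58, 60, 75]
durations_exp = [6, 7, 9, 15, 22, 30, 45]
ticks = [0, 20, 40, 60]

table = {
    "control": [at_risk_count(durations_control, t) for t in ticks],
    "exp": [at_risk_count(durations_exp, t) for t in ticks],
}
print(table)
```

The counts can only stay level or decrease from left to right, which is a quick sanity check on any at-risk table.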


@christopherahern
Contributor Author

Hey, thanks for taking a look so quickly!

The Weibull fitter example is really useful. I think a good way to cover both cases is to take the xticks from the plot axis and filter them down to those that fall within the xlim of the plot axis. I think that will work with any subset of times plotted (e.g. automatically or via a loc that gets passed in). I'll test that out and make sure it covers Kaplan-Meier and Weibull.
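A rough sketch of that filtering idea, written as a pure function so it can be checked without a figure. In practice the two arguments would come from ax.get_xticks() and ax.get_xlim(). The function name visible_ticks is hypothetical, and the numbers are illustrative.

```python
def visible_ticks(ticks, xlim):
    """Keep only the tick locations that fall inside the axis limits.

    `xlim` is a (left, right) pair; min/max are used so an inverted
    axis is handled as well.
    """
    lower, upper = min(xlim), max(xlim)
    return [t for t in ticks if lower <= t <= upper]


# Illustrative values: the locator proposes ticks at 0 and 80,
# but the axis limits cut both of them off.
all_ticks = [0, 10, 20, 30, 40, 50, 60, 70, 80]
shown = visible_ticks(all_ticks, (2.55, 78.45))
print(shown)
```

Filtering on both bounds, rather than only the lower one, also covers the case where the rightmost proposed tick lies past the end of the plot.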

Any idea why the xlim for the Weibull fitter starts at 6? The same limit is generated by plot and plot_survival_function, so there's something in _plot_estimate that's generating that. If the Weibull plot should start at zero, then that would be where it's happening.

@CamDavidsonPilon
Owner

> Any idea why the xlim for the Weibull fitter starts at 6?

The xlim is set automatically by mpl after we give it the data to plot. In all cases, the x-variable is the timeline attribute (which can be user-provided in the call to fit). In the KaplanMeier case, the timeline attribute starts at 0, and for parametric models, the timeline starts at the minimum observation.
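The effect of the timeline's starting point on the autoscaled limits can be seen with plain matplotlib, without lifelines. This is a small sketch assuming default rcParams (in particular the default axis margins). The helper autoscaled_xlim and the timelines are made up for illustration.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def autoscaled_xlim(timeline):
    """Plot a dummy decreasing curve over `timeline` and return the xlim mpl picks."""
    fig, ax = plt.subplots()
    n = len(timeline)
    ax.plot(timeline, [1.0 - i / n for i in range(n)])
    xlim = ax.get_xlim()
    plt.close(fig)
    return xlim


# Timeline starting at 0 (as in the Kaplan-Meier case) versus one starting
# at the minimum observation (as for parametric models); values are made up.
km_like = autoscaled_xlim([0, 10, 20, 40, 75])
parametric_like = autoscaled_xlim([6, 10, 20, 40, 75])
print(km_like, parametric_like)
```

With a timeline that starts above 0, the lower limit stays above 0 too, so a tick placed at 0 is outside the visible range.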

@CamDavidsonPilon
Owner

> I think a good way to cover both cases is to take the xticks from the plot axis and filter them down to those that fall within the xlim of the plot axis. I think that will work with any subset of times plotted (e.g. automatically or via a loc that gets passed in).

👍 seems like a good solution

@christopherahern
Contributor Author

Ok, I changed how the xticks for the at-risk counts are generated. They can now be passed in as a kwarg; otherwise they default to the visible ticks of the plot axes object.

Here is boilerplate code for the plots below.

from lifelines.datasets import load_waltons
from lifelines import KaplanMeierFitter, WeibullFitter
from lifelines.plotting import add_at_risk_counts
import matplotlib.pyplot as plt

waltons = load_waltons()
T = waltons['T']
E = waltons['E']

ix = waltons['group'] == 'control'

kmf_control = KaplanMeierFitter()
kmf_control.fit(waltons.loc[ix]['T'], waltons.loc[ix]['E'], label='control')
kmf_exp = KaplanMeierFitter()
kmf_exp.fit(waltons.loc[~ix]['T'], waltons.loc[~ix]['E'], label='exp')

wbl_control = WeibullFitter()
wbl_control.fit(waltons.loc[ix]['T'], waltons.loc[ix]['E'], label='control')
wbl_exp = WeibullFitter()
wbl_exp.fit(waltons.loc[~ix]['T'], waltons.loc[~ix]['E'], label='exp')

Default

ax = plt.subplot(111)
ax = kmf_control.plot(ax=ax)
ax = kmf_exp.plot(ax=ax)
add_at_risk_counts(kmf_exp, kmf_control, ax=ax)
plt.show()

default_kmf

ax = plt.subplot(111)
ax = wbl_control.plot_survival_function(ax=ax)
ax = wbl_exp.plot_survival_function(ax=ax)
add_at_risk_counts(wbl_exp, wbl_control, ax=ax)
plt.show()

default_wbl

Setting xlim

Works when setting xlim outside of the fitter plot methods.

ax = plt.subplot(111)
ax = kmf_control.plot(ax=ax)
ax = kmf_exp.plot(ax=ax)
ax.set_xlim(0, 30)
add_at_risk_counts(kmf_exp, kmf_control, ax=ax)
plt.show()

xlim_kmf

ax = plt.subplot(111)
ax = wbl_control.plot_survival_function(ax=ax)
ax = wbl_exp.plot_survival_function(ax=ax)
ax.set_xlim(0, 80)
add_at_risk_counts(wbl_exp, wbl_control, ax=ax)
plt.show()

xlim_wbl

Works when plotting subset of time

I'm not entirely sure what the plots themselves should look like when loc is used. The at-risk counts do appear only at the visible ticks, though.

ax = plt.subplot(111)
ax = kmf_control.plot(ax=ax, loc=slice(10, 60))
ax = kmf_exp.plot(ax=ax, loc=slice(10, 60))
add_at_risk_counts(kmf_exp, kmf_control, ax=ax)
plt.show()

loc_kmf

ax = plt.subplot(111)
ax = kmf_control.plot(ax=ax, loc=slice(30, 40))
ax = kmf_exp.plot(ax=ax, loc=slice(30, 40))
add_at_risk_counts(kmf_exp, kmf_control, ax=ax)
plt.show()

loc_wbl

Supply xticks as kwarg

ax = plt.subplot(111)
ax = kmf_control.plot(ax=ax)
ax = kmf_exp.plot(ax=ax)
add_at_risk_counts(kmf_exp, kmf_control, ax=ax, xticks=[15, 23, 55])
plt.show()

xticks_kmf

ax = plt.subplot(111)
ax = kmf_control.plot(ax=ax)
ax = kmf_exp.plot(ax=ax)
add_at_risk_counts(kmf_exp, kmf_control, ax=ax, xticks=[17, 30, 60])
plt.show()

xticks_wbl

Modify at-risk counts text

ax = plt.subplot(111)
ax = kmf_control.plot(ax=ax)
ax = kmf_exp.plot(ax=ax)
add_at_risk_counts(kmf_exp, kmf_control, ax=ax, fontsize=8, color='b', fontstyle='italic')
plt.show()

text_kmf
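The text styling above relies on matplotlib's set_xticklabels accepting text properties as keyword arguments. A minimal standalone sketch of that call, independent of lifelines, with made-up tick labels:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Fix the tick locations first, then attach labels with text properties.
ax.set_xticks([0, 20, 40, 60])
texts = ax.set_xticklabels(
    ["163", "120", "47", "8"],  # made-up counts, for illustration only
    fontsize=8,
    color="b",
    fontstyle="italic",
)

# Each returned Text object carries the requested properties.
first = texts[0]
style = (first.get_fontsize(), first.get_color(), first.get_fontstyle())
print(style)
plt.close(fig)
```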

Let me know if that looks good.

@CamDavidsonPilon
Owner

> Let me know if that looks good.

looks great! Thanks for this attention to detail!

@CamDavidsonPilon CamDavidsonPilon merged commit 52fba6e into CamDavidsonPilon:master Jun 19, 2019
@christopherahern
Contributor Author

Great! Thanks for your help and great work!
