Skip to content

Commit

Permalink
generate.py is a mess
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Feb 1, 2019
1 parent a855d38 commit 66f3167
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 25 deletions.
2 changes: 1 addition & 1 deletion lifelines/fitters/cox_time_varying_fitter.py
Expand Up @@ -683,7 +683,7 @@ def plot(self, columns=None, **errorbar_kwargs):
ax.vlines(0, -2, len(columns) + 1, linestyles="dashed", linewidths=1, alpha=0.65)
ax.set_ylim(best_ylim)

tick_labels = columns[order]
tick_labels = [columns[i] for i in order]

plt.yticks(yaxis_locations, tick_labels)
plt.xlabel("log(HR) (%g%% CI)" % (self.alpha * 100))
Expand Down
2 changes: 1 addition & 1 deletion lifelines/fitters/coxph_fitter.py
Expand Up @@ -1542,7 +1542,7 @@ def plot(self, columns=None, **errorbar_kwargs):
ax.vlines(0, -2, len(columns) + 1, linestyles="dashed", linewidths=1, alpha=0.65)
ax.set_ylim(best_ylim)

tick_labels = columns[order]
tick_labels = [columns[i] for i in order]

plt.yticks(yaxis_locations, tick_labels)
plt.xlabel("log(HR) (%g%% CI)" % (self.alpha * 100))
Expand Down
4 changes: 2 additions & 2 deletions lifelines/generate_datasets.py
Expand Up @@ -155,7 +155,7 @@ def time_varying_coefficients(d, timelines, constant=False, independent=0, randg
try:
a = np.arange(d)
random.shuffle(a)
independent_variables = a[:independent]
independent = a[:independent]
except IndexError:
pass

Expand All @@ -164,7 +164,7 @@ def time_varying_coefficients(d, timelines, constant=False, independent=0, randg
data_generators = []
for i in range(d):
f = FUNCS[random.randint(0, n_funcs)] if not constant else constant_
if i in independent_variables:
if i in independent:
beta = 0
else:
beta = randgen((1 - constant) * 0.5 / d)
Expand Down
4 changes: 2 additions & 2 deletions lifelines/plotting.py
Expand Up @@ -219,8 +219,8 @@ def plot_lifetimes(
if entry is None:
entry = np.zeros(N)

assert durations.shape == (N,)
assert event_observed.shape == (N,)
assert durations.shape[0] == N
assert event_observed.shape[0] == N

if sort_by_duration:
# order by length of lifetimes;
Expand Down
36 changes: 17 additions & 19 deletions tests/test_plotting.py
Expand Up @@ -23,6 +23,11 @@
from lifelines.generate_datasets import cumulative_integral


@pytest.fixture()
def waltons():
return load_waltons()[["T", "E"]].iloc[:50]


@pytest.mark.skipif("DISPLAY" not in os.environ, reason="requires display")
class TestPlotting:
@pytest.fixture
Expand Down Expand Up @@ -149,26 +154,22 @@ def test_naf_plotting_slice(self, block):
self.plt.show(block=block)
return

def test_plot_lifetimes_calendar(self, block):
t = np.linspace(0, 20, 1000)
hz, coef, covrt = generate_hazard_rates(1, 5, t)
N = 20
def test_plot_lifetimes_calendar(self, block, waltons):
T, E = waltons["T"], waltons["E"]
current = 10
birthtimes = current * np.random.uniform(size=(N,))
T, C = generate_random_lifetimes(hz, t, size=N, censor=current - birthtimes)
ax = plot_lifetimes(T, event_observed=C, entry=birthtimes)
birthtimes = current * np.random.uniform(size=(T.shape[0],))
ax = plot_lifetimes(T, event_observed=E, entry=birthtimes)
assert ax is not None
self.plt.title("test_plot_lifetimes_calendar")
self.plt.show(block=block)

def test_plot_lifetimes_left_truncation(self, block):
t = np.linspace(0, 20, 1000)
hz, coef, covrt = generate_hazard_rates(1, 5, t)
def test_plot_lifetimes_left_truncation(self, block, waltons):
T, E = waltons["T"], waltons["E"]
N = 20
current = 10
birthtimes = current * np.random.uniform(size=(N,))
T, C = generate_random_lifetimes(hz, t, size=N, censor=current - birthtimes)
ax = plot_lifetimes(T, event_observed=C, entry=birthtimes, left_truncated=True)

birthtimes = current * np.random.uniform(size=(T.shape[0],))
ax = plot_lifetimes(T, event_observed=E, entry=birthtimes, left_truncated=True)
assert ax is not None
self.plt.title("test_plot_lifetimes_left_truncation")
self.plt.show(block=block)
Expand All @@ -189,12 +190,9 @@ def test_MACS_data_with_plot_lifetimes(self, block):
self.plt.title("test_MACS_data_with_plot_lifetimes")
self.plt.show(block=block)

def test_plot_lifetimes_relative(self, block):
t = np.linspace(0, 20, 1000)
hz, coef, covrt = generate_hazard_rates(1, 5, t)
N = 20
T, C = generate_random_lifetimes(hz, t, size=N, censor=True)
ax = plot_lifetimes(T, event_observed=C)
def test_plot_lifetimes_relative(self, block, waltons):
T, E = waltons["T"], waltons["E"]
ax = plot_lifetimes(T, event_observed=E)
assert ax is not None
self.plt.title("test_plot_lifetimes_relative")
self.plt.show(block=block)
Expand Down

0 comments on commit 66f3167

Please sign in to comment.