Skip to content

Commit

Permalink
plotting tets
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Jan 10, 2019
1 parent 1b297f6 commit 98a68fb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 17 deletions.
6 changes: 4 additions & 2 deletions lifelines/fitters/aalen_additive_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,10 @@ def create_df_slicer(loc, iloc):
else:
columns = _to_list(columns)

ax = kwargs.setdefault("ax", plt.figure().add_subplot(111))
del kwargs["ax"]
if 'ax' in kwargs:
ax = kwargs.pop('ax')
else:
ax = plt.figure().add_subplot(111)

x = subset_df(self.cumulative_hazards_).index.values.astype(float)

Expand Down
21 changes: 6 additions & 15 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_naf_plotting_with_custom_colours(self, block):
naf.fit(data1)
ax = naf.plot(color="r")
naf.fit(data2)
naf.plot(ax=ax, c="k")
naf.plot(ax=ax, color="k")
self.plt.title("test_naf_plotting_with_custom_coloirs")
self.plt.show(block=block)
return
Expand Down Expand Up @@ -297,16 +297,6 @@ def test_kmf_left_censorship_plots(self, block):
self.plt.show(block=block)
return

def test_aaf_panel_dataset(self, block):

panel_dataset = load_panel_test()
aaf = AalenAdditiveFitter()
aaf.fit(panel_dataset, id_col="id", duration_col="t", event_col="E")
aaf.plot()
self.plt.title("test_aaf_panel_dataset")
self.plt.show(block=block)
return

def test_aalen_additive_fit_no_censor(self, block):
n = 2500
d = 6
Expand All @@ -319,14 +309,15 @@ def test_aalen_additive_fit_no_censor(self, block):
T = generate_random_lifetimes(hz, timeline)
X["T"] = T
X["E"] = np.random.binomial(1, 1, n)
X[np.isinf(X)] = 10
aaf = AalenAdditiveFitter()
aaf.fit(X, "T", "E")

for i in range(d + 1):
ax = self.plt.subplot(d + 1, 1, i + 1)
col = cumulative_hazards.columns[i]
ax = cumulative_hazards[col].loc[:15].plot(legend=False, ax=ax)
ax = aaf.plot(loc=slice(0, 15), ax=ax, columns=[col], legend=False)
ax = cumulative_hazards[col].loc[:15].plot(ax=ax)
ax = aaf.plot(loc=slice(0, 15), ax=ax, columns=[col])
self.plt.title("test_aalen_additive_fit_no_censor")
self.plt.show(block=block)
return
Expand All @@ -351,8 +342,8 @@ def test_aalen_additive_fit_with_censor(self, block):
for i in range(d + 1):
ax = self.plt.subplot(d + 1, 1, i + 1)
col = cumulative_hazards.columns[i]
ax = cumulative_hazards[col].loc[:15].plot(legend=False, ax=ax)
ax = aaf.plot(loc=slice(0, 15), ax=ax, columns=[col], legend=False)
ax = cumulative_hazards[col].loc[:15].plot(ax=ax)
ax = aaf.plot(loc=slice(0, 15), ax=ax, columns=[col])
self.plt.title("test_aalen_additive_fit_with_censor")
self.plt.show(block=block)
return

0 comments on commit 98a68fb

Please sign in to comment.