Skip to content

Commit

Permalink
Merge pull request #724 from weixuanfu/tpdm_jupyter
Browse files Browse the repository at this point in the history
Detect whether TPOT is running in a notebook
  • Loading branch information
weixuanfu committed Jul 16, 2018
2 parents b1ccc25 + 41cf783 commit f66b893
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
8 changes: 7 additions & 1 deletion tests/tpot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"""

from tpot import TPOTClassifier, TPOTRegressor
from tpot.base import TPOTBase
from tpot.base import TPOTBase, is_notebook
from tpot.driver import float_range
from tpot.gp_types import Output_Array
from tpot.gp_deap import mutNodeReplacement, _wrapped_cross_val_score, pick_two_individuals_eligible_for_crossover, cxOnePoint, varOr, initialize_stats_dict
Expand Down Expand Up @@ -148,6 +148,12 @@ def test_init_custom_parameters():
assert not (tpot_obj._toolbox is None)


def test_is_notebook():
"""Assert that isnotebook function works as expected."""
ret = is_notebook()
assert not ret


def test_init_default_scoring():
"""Assert that TPOT intitializes with the correct default scoring function."""
tpot_obj = TPOTRegressor()
Expand Down
27 changes: 23 additions & 4 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from scipy import sparse
import deap
from deap import base, creator, tools, gp
from tqdm import tqdm
from copy import copy, deepcopy

from sklearn.base import BaseEstimator
Expand Down Expand Up @@ -94,6 +93,26 @@ def handler(dwCtrlType, hook_sigint=_thread.interrupt_main):

win32api.SetConsoleCtrlHandler(handler, 1)

def is_notebook():
"""Check if TPOT is running in Jupyter notebook.
Returns
-------
True: TPOT is running in Jupyter notebook
False: TPOT is running in other terminals
"""
try:
from IPython import get_ipython
shell = get_ipython().__class__.__name__
# if shell == 'TerminalInteractiveShell', then Terminal running IPython
return shell == 'ZMQInteractiveShell'
except:
return False

if is_notebook():
from tqdm import tqdm_notebook as tqdm
else:
from tqdm import tqdm


class TPOTBase(BaseEstimator):
"""Automatically creates and optimizes machine learning pipelines using GP."""
Expand Down Expand Up @@ -526,9 +545,9 @@ def fit(self, features, target, sample_weight=None, groups=None):
target: array-like {n_samples}
List of class labels for prediction
sample_weight: array-like {n_samples}, optional
Per-sample weights. Higher weights indicate more importance. If specified,
sample_weight will be passed to any pipeline element whose fit() function accepts
a sample_weight argument. By default, using sample_weight does not affect tpot's
Per-sample weights. Higher weights indicate more importance. If specified,
sample_weight will be passed to any pipeline element whose fit() function accepts
a sample_weight argument. By default, using sample_weight does not affect tpot's
scoring functions, which determine preferences between pipelines.
groups: array-like, with shape {n_samples, }, optional
Group labels for the samples used when performing cross-validation.
Expand Down

0 comments on commit f66b893

Please sign in to comment.