Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ recursive-include autosklearn/metalearning/files *.txt
include autosklearn/util/logging.yaml
recursive-include autosklearn *.pyx
include requirements.txt
recursive-include autosklearn/experimental/askl2_portfolios *.json
include autosklearn/experimental/askl2_training_data.json
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,28 @@ auto-sklearn is an automated machine learning toolkit and a drop-in replacement

Find the documentation [here](http://automl.github.io/auto-sklearn/)

## Automated Machine Learning in four lines of code

```python
import autosklearn.classification
cls = autosklearn.classification.AutoSklearnClassifier()
cls.fit(X_train, y_train)
predictions = cls.predict(X_test)
```

## Relevant publications

Efficient and Robust Automated Machine Learning
Matthias Feurer, Aaron Klein, Katharina Eggensperger, Jost Springenberg, Manuel Blum and Frank Hutter
Advances in Neural Information Processing Systems 28 (2015)
http://papers.nips.cc/paper/5872-efficient-and-robust-automated-machine-learning.pdf

Auto-Sklearn 2.0: The Next Generation
Authors: Matthias Feurer, Katharina Eggensperger, Stefan Falkner, Marius Lindauer and Frank Hutter
To appear

## Status

Status for master branch

[![Build Status](https://travis-ci.org/automl/auto-sklearn.svg?branch=master)](https://travis-ci.org/automl/auto-sklearn)
Expand Down
2 changes: 1 addition & 1 deletion autosklearn/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Version information."""

# The following line *must* be the last in the module, exactly as formatted:
__version__ = "0.7.1"
__version__ = "0.8.0"
25 changes: 17 additions & 8 deletions autosklearn/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from autosklearn.util.hash import hash_array_or_matrix
from autosklearn.metrics import f1_macro, accuracy, r2
from autosklearn.constants import MULTILABEL_CLASSIFICATION, MULTICLASS_CLASSIFICATION, \
REGRESSION_TASKS, REGRESSION, BINARY_CLASSIFICATION
REGRESSION_TASKS, REGRESSION, BINARY_CLASSIFICATION, MULTIOUTPUT_REGRESSION


def _model_predict(model, X, batch_size, logger, task):
Expand Down Expand Up @@ -936,13 +936,16 @@ def _check_X(self, X):

def _check_y(self, y):
y = sklearn.utils.check_array(y, ensure_2d=False)

y = np.atleast_1d(y)
if y.ndim == 2 and y.shape[1] == 1:

if y.ndim == 1:
return y
elif y.ndim == 2 and y.shape[1] == 1:
warnings.warn("A column-vector y was passed when a 1d array was"
" expected. Will change shape via np.ravel().",
sklearn.utils.DataConversionWarning, stacklevel=2)
y = np.ravel(y)
return y

return y

Expand Down Expand Up @@ -1097,6 +1100,9 @@ def predict_proba(self, X, batch_size=None, n_jobs=1):
class AutoMLRegressor(BaseAutoML):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._task_mapping = {'continuous-multioutput': MULTIOUTPUT_REGRESSION,
'continuous': REGRESSION,
'multiclass': REGRESSION}

def fit(
self,
Expand All @@ -1110,17 +1116,20 @@ def fit(
load_models: bool = True,
):
X, y = super()._perform_input_checks(X, y)
_n_outputs = 1 if len(y.shape) == 1 else y.shape[1]
if _n_outputs > 1:
raise NotImplementedError(
'Multi-output regression is not implemented.')
y_task = type_of_target(y)
task = self._task_mapping.get(y_task)
if task is None:
raise ValueError('Cannot work on data of type %s' % y_task)

if self._metric is None:
self._metric = r2

self._n_outputs = 1 if len(y.shape) == 1 else y.shape[1]
return super().fit(
X, y,
X_test=X_test,
y_test=y_test,
task=REGRESSION,
task=task,
feat_type=feat_type,
dataset_name=dataset_name,
only_return_configuration_space=only_return_configuration_space,
Expand Down
9 changes: 6 additions & 3 deletions autosklearn/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
MULTICLASS_CLASSIFICATION = 2
MULTILABEL_CLASSIFICATION = 3
REGRESSION = 4
MULTIOUTPUT_REGRESSION = 5

REGRESSION_TASKS = [REGRESSION]
REGRESSION_TASKS = [REGRESSION, MULTIOUTPUT_REGRESSION]
CLASSIFICATION_TASKS = [BINARY_CLASSIFICATION, MULTICLASS_CLASSIFICATION,
MULTILABEL_CLASSIFICATION]

Expand All @@ -15,10 +16,12 @@
{BINARY_CLASSIFICATION: 'binary.classification',
MULTICLASS_CLASSIFICATION: 'multiclass.classification',
MULTILABEL_CLASSIFICATION: 'multilabel.classification',
REGRESSION: 'regression'}
REGRESSION: 'regression',
MULTIOUTPUT_REGRESSION: 'multioutput.regression'}

STRING_TO_TASK_TYPES = \
{'binary.classification': BINARY_CLASSIFICATION,
'multiclass.classification': MULTICLASS_CLASSIFICATION,
'multilabel.classification': MULTILABEL_CLASSIFICATION,
'regression': REGRESSION}
'regression': REGRESSION,
'multioutput.regression': MULTIOUTPUT_REGRESSION}
3 changes: 2 additions & 1 deletion autosklearn/data/xy_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from scipy import sparse

from autosklearn.constants import STRING_TO_TASK_TYPES, REGRESSION, BINARY_CLASSIFICATION, \
MULTICLASS_CLASSIFICATION, MULTILABEL_CLASSIFICATION
MULTICLASS_CLASSIFICATION, MULTILABEL_CLASSIFICATION, MULTIOUTPUT_REGRESSION
from autosklearn.data.abstract_data_manager import AbstractDataManager


Expand All @@ -27,6 +27,7 @@ def __init__(self, X, y, X_test, y_test, task, feat_type, dataset_name):
label_num = {
REGRESSION: 1,
BINARY_CLASSIFICATION: 2,
MULTIOUTPUT_REGRESSION: y.shape[-1],
MULTICLASS_CLASSIFICATION: len(np.unique(y)),
MULTILABEL_CLASSIFICATION: y.shape[-1]
}
Expand Down
5 changes: 2 additions & 3 deletions autosklearn/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,15 +766,15 @@ def fit(self, X, y,
X : array-like or sparse matrix of shape = [n_samples, n_features]
The training input samples.

y : array-like, shape = [n_samples]
y : array-like, shape = [n_samples] or [n_samples, n_targets]
The regression target.

X_test : array-like or sparse matrix of shape = [n_samples, n_features]
Test data input samples. Will be used to save test predictions for
all models. This allows to evaluate the performance of Auto-sklearn
over time.

y_test : array-like, shape = [n_samples]
y_test : array-like, shape = [n_samples] or [n_samples, n_targets]
The regression target. Will be used to calculate the test error
of all models. This allows to evaluate the performance of
Auto-sklearn over time.
Expand All @@ -799,7 +799,6 @@ def fit(self, X, y,
target_type = type_of_target(y)
if target_type in ['multiclass-multioutput',
'multilabel-indicator',
'continuous-multioutput',
'unknown',
]:
raise ValueError("regression with data of type %s is not"
Expand Down
20 changes: 14 additions & 6 deletions autosklearn/evaluation/abstract_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
REGRESSION_TASKS,
MULTILABEL_CLASSIFICATION,
MULTICLASS_CLASSIFICATION,
MULTIOUTPUT_REGRESSION
)
from autosklearn.pipeline.implementations.util import (
convert_multioutput_multiclass_to_multilabel
Expand Down Expand Up @@ -204,12 +205,19 @@ def _get_model(self):
random_state=self.seed,
init_params=self._init_params)
else:
dataset_properties = {
'task': self.task_type,
'sparse': self.datamanager.info['is_sparse'] == 1,
'multilabel': self.task_type == MULTILABEL_CLASSIFICATION,
'multiclass': self.task_type == MULTICLASS_CLASSIFICATION,
}
if self.task_type in REGRESSION_TASKS:
dataset_properties = {
'task': self.task_type,
'sparse': self.datamanager.info['is_sparse'] == 1,
'multioutput': self.task_type == MULTIOUTPUT_REGRESSION,
}
else:
dataset_properties = {
'task': self.task_type,
'sparse': self.datamanager.info['is_sparse'] == 1,
'multilabel': self.task_type == MULTILABEL_CLASSIFICATION,
'multiclass': self.task_type == MULTICLASS_CLASSIFICATION,
}
model = self.model_class(config=self.configuration,
dataset_properties=dataset_properties,
random_state=self.seed,
Expand Down
Empty file.
Loading