Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove art.classifiers and art.wrappers #1256

Merged
merged 8 commits into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 0 additions & 30 deletions art/classifiers/__init__.py

This file was deleted.

17 changes: 0 additions & 17 deletions art/classifiers/scikitlearn/__init__.py

This file was deleted.

1 change: 1 addition & 0 deletions art/estimators/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from art.estimators.classification.lightgbm import LightGBMClassifier
from art.estimators.classification.mxnet import MXClassifier
from art.estimators.classification.pytorch import PyTorchClassifier
from art.estimators.classification.query_efficient_bb import QueryEfficientGradientEstimationClassifier
from art.estimators.classification.scikitlearn import SklearnClassifier
from art.estimators.classification.tensorflow import (
TFClassifier,
Expand Down
4 changes: 2 additions & 2 deletions art/estimators/classification/blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

class BlackBoxClassifier(ClassifierMixin, BaseEstimator):
"""
Wrapper class for black-box classifiers.
Class for black-box classifiers.
"""

estimator_params = Classifier.estimator_params + ["nb_classes", "input_shape", "predict_fn"]
Expand Down Expand Up @@ -163,7 +163,7 @@ def save(self, filename: str, path: Optional[str] = None) -> None:

class BlackBoxClassifierNeuralNetwork(NeuralNetworkMixin, ClassifierMixin, BaseEstimator):
"""
Wrapper class for black-box neural network classifiers.
Class for black-box neural network classifiers.
"""

estimator_params = (
Expand Down
2 changes: 1 addition & 1 deletion art/estimators/classification/catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

class CatBoostARTClassifier(ClassifierDecisionTree):
"""
Wrapper class for importing CatBoost models.
Class for importing CatBoost models.
"""

estimator_params = ClassifierDecisionTree.estimator_params + ["nb_features"]
Expand Down
2 changes: 1 addition & 1 deletion art/estimators/classification/detector_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def loss_gradient( # pylint: disable=W0221
def layer_names(self) -> List[str]:
"""
Return the hidden layers in the model, if applicable. This function is not supported for the
Classifier and Detector wrapper.
Classifier and Detector classes.

:return: The hidden layers in the model, input and output layers excluded.
:raises `NotImplementedException`: This method is not supported for detector-classifiers.
Expand Down
2 changes: 1 addition & 1 deletion art/estimators/classification/lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

class LightGBMClassifier(ClassifierDecisionTree):
"""
Wrapper class for importing LightGBM models.
Class for importing LightGBM models.
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion art/estimators/classification/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

class MXClassifier(ClassGradientsMixin, ClassifierMixin, MXEstimator): # lgtm [py/missing-call-to-init]
"""
Wrapper class for importing MXNet Gluon models.
Class for importing MXNet Gluon models.
"""

estimator_params = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,31 @@
"""
Provides black-box gradient estimation using NES.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from typing import List, Optional, Tuple, Union, TYPE_CHECKING

import numpy as np
from scipy.stats import entropy

from art.estimators.classification.classifier import ClassifierClassLossGradients
from art.utils import clip_and_round, deprecated
from art.wrappers.wrapper import ClassifierWrapper
from art.estimators.estimator import BaseEstimator
from art.estimators.classification.classifier import ClassifierMixin, ClassifierLossGradients
from art.utils import clip_and_round

if TYPE_CHECKING:
from art.utils import CLASSIFIER_CLASS_LOSS_GRADIENTS_TYPE

logger = logging.getLogger(__name__)


@deprecated(
end_version="1.8.0",
reason="Expectation over transformation has been replaced with " "art.estimators",
replaced_by="art.preprocessing.expectation_over_transformation",
)
class QueryEfficientBBGradientEstimation(ClassifierWrapper, ClassifierClassLossGradients):
class QueryEfficientGradientEstimationClassifier(ClassifierLossGradients, ClassifierMixin, BaseEstimator):
"""
Implementation of Query-Efficient Black-box Adversarial Examples. The attack approximates the gradient by
maximizing the loss function over samples drawn from random Gaussian noise around the input.

| Paper link: https://arxiv.org/abs/1712.07113
"""

attack_params = ["num_basis", "sigma", "round_samples"]
estimator_params = ["num_basis", "sigma", "round_samples"]

def __init__(
self,
Expand All @@ -59,20 +52,19 @@ def __init__(
round_samples: float = 0.0,
) -> None:
"""
:param classifier: An instance of a `Classifier` whose loss_gradient is being approximated.
:param classifier: An instance of a classification estimator whose loss_gradient is being approximated.
:param num_basis: The number of samples to draw to approximate the gradient.
:param sigma: Scaling on the Gaussian noise N(0,1).
:param round_samples: The resolution of the input domain to round the data to, e.g., 1.0, or 1/255. Set to 0 to
disable.
"""
super().__init__(classifier)
# self.predict refers to predict of classifier
super().__init__(model=classifier.model, clip_values=classifier.clip_values)
# pylint: disable=E0203
self._predict = self.classifier.predict
self._classifier = classifier
self.num_basis = num_basis
self.sigma = sigma
self.round_samples = round_samples
self._nb_classes = self.classifier.nb_classes
self._nb_classes = self._classifier.nb_classes

@property
def input_shape(self) -> Tuple[int, ...]:
Expand All @@ -81,18 +73,18 @@ def input_shape(self) -> Tuple[int, ...]:

:return: Shape of one input sample.
"""
return self._input_shape # type: ignore
return self._classifier.input_shape # type: ignore

def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray: # pylint: disable=W0221
"""
Perform prediction of the classifier for input `x`.
Perform prediction of the classifier for input `x`. Rounds results first.

:param x: Features in array of shape (nb_samples, nb_features) or (nb_samples, nb_pixels_1, nb_pixels_2,
nb_channels) or (nb_samples, nb_channels, nb_pixels_1, nb_pixels_2).
:param batch_size: Size of batches.
:return: Array of predictions of shape `(nb_inputs, nb_classes)`.
"""
return self._wrap_predict(x, batch_size=batch_size)
return self._classifier.predict(clip_and_round(x, self.clip_values, self.round_samples), batch_size=batch_size)

def fit(self, x: np.ndarray, y: np.ndarray, **kwargs) -> None:
"""
Expand Down Expand Up @@ -172,16 +164,6 @@ def loss_gradient(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
grads = self._apply_preprocessing_gradient(x, np.array(grads))
return grads

def _wrap_predict(self, x: np.ndarray, batch_size: int = 128) -> np.ndarray:
"""
Perform prediction for a batch of inputs. Rounds results first.

:param x: Input samples.
:param batch_size: Size of batches.
:return: Array of predictions of shape `(nb_inputs, nb_classes)`.
"""
return self._predict(clip_and_round(x, self.clip_values, self.round_samples), **{"batch_size": batch_size})

def get_activations(self, x: np.ndarray, layer: Union[int, str], batch_size: int) -> np.ndarray:
"""
Return the output of the specified layer for input `x`. `layer` is specified by layer index (between 0 and
Expand Down
26 changes: 13 additions & 13 deletions art/estimators/classification/scikitlearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def SklearnClassifier(
) -> "ScikitlearnClassifier":
"""
Create a `Classifier` instance from a scikit-learn Classifier model. This is a convenience function that
instantiates the correct wrapper class for the given scikit-learn model.
instantiates the correct class for the given scikit-learn model.

:param model: scikit-learn Classifier model.
:param clip_values: Tuple of the form `(min, max)` representing the minimum and maximum values allowed
Expand Down Expand Up @@ -100,7 +100,7 @@ def SklearnClassifier(

class ScikitlearnClassifier(ClassifierMixin, ScikitlearnEstimator): # lgtm [py/missing-call-to-init]
"""
Wrapper class for scikit-learn classifier models.
Class for scikit-learn classifier models.
"""

estimator_params = ClassifierMixin.estimator_params + ScikitlearnEstimator.estimator_params + ["use_logits"]
Expand Down Expand Up @@ -284,7 +284,7 @@ def _get_nb_classes(self) -> int:

class ScikitlearnDecisionTreeClassifier(ScikitlearnClassifier):
"""
Wrapper class for scikit-learn Decision Tree Classifier models.
Class for scikit-learn Decision Tree Classifier models.
"""

def __init__(
Expand Down Expand Up @@ -430,7 +430,7 @@ def _get_leaf_nodes(self, node_id, i_tree, class_label, box) -> List["LeafNode"]

class ScikitlearnDecisionTreeRegressor(ScikitlearnDecisionTreeClassifier):
"""
Wrapper class for scikit-learn Decision Tree Regressor models.
Class for scikit-learn Decision Tree Regressor models.
"""

def __init__(
Expand Down Expand Up @@ -520,7 +520,7 @@ def _get_leaf_nodes(self, node_id, i_tree, class_label, box) -> List["LeafNode"]

class ScikitlearnExtraTreeClassifier(ScikitlearnDecisionTreeClassifier):
"""
Wrapper class for scikit-learn Extra TreeClassifier Classifier models.
Class for scikit-learn Extra TreeClassifier Classifier models.
"""

def __init__(
Expand Down Expand Up @@ -559,7 +559,7 @@ def __init__(

class ScikitlearnAdaBoostClassifier(ScikitlearnClassifier):
"""
Wrapper class for scikit-learn AdaBoost Classifier models.
Class for scikit-learn AdaBoost Classifier models.
"""

def __init__(
Expand Down Expand Up @@ -598,7 +598,7 @@ def __init__(

class ScikitlearnBaggingClassifier(ScikitlearnClassifier):
"""
Wrapper class for scikit-learn Bagging Classifier models.
Class for scikit-learn Bagging Classifier models.
"""

def __init__(
Expand Down Expand Up @@ -638,7 +638,7 @@ def __init__(

class ScikitlearnExtraTreesClassifier(ScikitlearnClassifier, DecisionTreeMixin):
"""
Wrapper class for scikit-learn Extra Trees Classifier models.
Class for scikit-learn Extra Trees Classifier models.
"""

def __init__(
Expand Down Expand Up @@ -711,7 +711,7 @@ def get_trees(self) -> List["Tree"]: # lgtm [py/similar-function]

class ScikitlearnGradientBoostingClassifier(ScikitlearnClassifier, DecisionTreeMixin):
"""
Wrapper class for scikit-learn Gradient Boosting Classifier models.
Class for scikit-learn Gradient Boosting Classifier models.
"""

def __init__(
Expand Down Expand Up @@ -785,7 +785,7 @@ def get_trees(self) -> List["Tree"]:

class ScikitlearnRandomForestClassifier(ScikitlearnClassifier):
"""
Wrapper class for scikit-learn Random Forest Classifier models.
Class for scikit-learn Random Forest Classifier models.
"""

def __init__(
Expand Down Expand Up @@ -858,7 +858,7 @@ def get_trees(self) -> List["Tree"]: # lgtm [py/similar-function]

class ScikitlearnLogisticRegression(ClassGradientsMixin, LossGradientsMixin, ScikitlearnClassifier):
"""
Wrapper class for scikit-learn Logistic Regression models.
Class for scikit-learn Logistic Regression models.
"""

def __init__(
Expand Down Expand Up @@ -1045,7 +1045,7 @@ def get_trainable_attribute_names() -> Tuple[str, str]:

class ScikitlearnGaussianNB(ScikitlearnClassifier):
"""
Wrapper class for scikit-learn Gaussian Naive Bayes models.
Class for scikit-learn Gaussian Naive Bayes models.
"""

def __init__(
Expand Down Expand Up @@ -1094,7 +1094,7 @@ def get_trainable_attribute_names() -> Tuple[str, str]:

class ScikitlearnSVC(ClassGradientsMixin, LossGradientsMixin, ScikitlearnClassifier):
"""
Wrapper class for scikit-learn C-Support Vector Classification models.
Class for scikit-learn C-Support Vector Classification models.
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion art/estimators/classification/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

class XGBoostClassifier(ClassifierDecisionTree):
"""
Wrapper class for importing XGBoost models.
Class for importing XGBoost models.
"""

estimator_params = ClassifierDecisionTree.estimator_params + [
Expand Down
2 changes: 2 additions & 0 deletions art/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from art.estimators.classification.lightgbm import LightGBMClassifier
from art.estimators.classification.mxnet import MXClassifier
from art.estimators.classification.pytorch import PyTorchClassifier
from art.estimators.classification.query_efficient_bb import QueryEfficientGradientEstimationClassifier
from art.estimators.classification.scikitlearn import (
ScikitlearnAdaBoostClassifier,
ScikitlearnBaggingClassifier,
Expand Down Expand Up @@ -114,6 +115,7 @@
ScikitlearnSVC,
TensorFlowClassifier,
TensorFlowV2Classifier,
QueryEfficientGradientEstimationClassifier,
]

CLASSIFIER_CLASS_LOSS_GRADIENTS_TYPE = Union[ # pylint: disable=C0103
Expand Down
18 changes: 0 additions & 18 deletions art/wrappers/__init__.py

This file was deleted.

Loading