diff --git a/src/innvestigate/analyzer/__init__.py b/src/innvestigate/analyzer/__init__.py index fe349229..cb8e3fdb 100644 --- a/src/innvestigate/analyzer/__init__.py +++ b/src/innvestigate/analyzer/__init__.py @@ -1,9 +1,10 @@ -# Get Python six functionality: -from __future__ import absolute_import, division, print_function, unicode_literals +from __future__ import annotations -from .base import NotAnalyzeableModelException -from .deeptaylor import BoundedDeepTaylor, DeepTaylor -from .gradient_based import ( +from typing import Dict, Type + +from innvestigate.analyzer.base import AnalyzerBase, NotAnalyzeableModelException +from innvestigate.analyzer.deeptaylor import BoundedDeepTaylor, DeepTaylor +from innvestigate.analyzer.gradient_based import ( BaselineGradient, Deconvnet, Gradient, @@ -12,9 +13,9 @@ IntegratedGradients, SmoothGrad, ) -from .misc import Input, Random -from .pattern_based import PatternAttribution, PatternNet -from .relevance_based.relevance_analyzer import ( +from innvestigate.analyzer.misc import Input, Random +from innvestigate.analyzer.pattern_based import PatternAttribution, PatternNet +from innvestigate.analyzer.relevance_based.relevance_analyzer import ( LRP, LRPZ, BaselineLRPZ, @@ -36,12 +37,13 @@ LRPZPlus, LRPZPlusFast, ) -from .wrapper import AugmentReduceBase, GaussianSmoother, PathIntegrator, WrapperBase - -############################################################################### -############################################################################### -############################################################################### - +from innvestigate.analyzer.wrapper import ( + AugmentReduceBase, + GaussianSmoother, + PathIntegrator, + WrapperBase, +) +from innvestigate.utils.types import Model # Disable pyflaks warnings: assert NotAnalyzeableModelException @@ -52,12 +54,7 @@ assert PathIntegrator -############################################################################### -############################################################################### -############################################################################### - - -analyzers = { +analyzers: Dict[str, Type[AnalyzerBase]] = { # Utility. "input": Input, "random": Random, @@ -98,7 +95,7 @@ } -def create_analyzer(name, model, **kwargs): +def create_analyzer(name: str, model: Model, **kwargs) -> AnalyzerBase: """Instantiates the analyzer with the name 'name' This convenience function takes an analyzer name diff --git a/src/innvestigate/analyzer/base.py b/src/innvestigate/analyzer/base.py index 41c8f7a0..baa5afed 100644 --- a/src/innvestigate/analyzer/base.py +++ b/src/innvestigate/analyzer/base.py @@ -1,47 +1,33 @@ -# Get Python six functionality: -from __future__ import absolute_import, division, print_function, unicode_literals +from __future__ import annotations import warnings +from abc import ABCMeta, abstractmethod from builtins import zip +from typing import Any, Dict, List, Optional, Tuple -import keras.backend as K +import keras import keras.layers import keras.models import numpy as np -import six - -from .. import layers as ilayers -from .. 
import utils as iutils -from ..utils.keras import checks as kchecks -from ..utils.keras import graph as kgraph - -############################################################################### -############################################################################### -############################################################################### +import innvestigate.analyzer +import innvestigate.utils as iutils +import innvestigate.utils.keras.graph as kgraph +from innvestigate.utils.types import LayerCheck, Model, ModelCheckDict, OptionalList __all__ = [ "NotAnalyzeableModelException", "AnalyzerBase", "TrainerMixin", "OneEpochTrainerMixin", - "AnalyzerNetworkBase", - "ReverseAnalyzerBase", ] -############################################################################### -############################################################################### -############################################################################### - - class NotAnalyzeableModelException(Exception): """Indicates that the model cannot be analyzed by an analyzer.""" - pass - -class AnalyzerBase(object): +class AnalyzerBase(metaclass=ABCMeta): """The basic interface of an iNNvestigate analyzer. This class defines the basic interface for analyzers: @@ -58,47 +44,81 @@ class AnalyzerBase(object): :param model: A Keras model. :param disable_model_checks: Do not execute model checks that enforce compatibility of analyzer and model. + :param neuron_selection_mode: How to select the neuron to analyze. + Possible values are 'max_activation', 'index' for the neuron + (expects indices at :func:`analyze` calls), 'all' take all neurons. .. note:: To develop a new analyzer derive from :class:`AnalyzerNetworkBase`. """ - def __init__(self, model, disable_model_checks=False): + def __init__( + self, + model: Model, + neuron_selection_mode: str = "max_activation", + disable_model_checks: bool = False, + _model_check_done: bool = False, + _model_checks: List[ModelCheckDict] = None, + ) -> None: + self._model = model self._disable_model_checks = disable_model_checks + self._model_check_done = _model_check_done - self._do_model_checks() - - def _add_model_check(self, check, message, check_type="exception"): - if getattr(self, "_model_check_done", False): + # There are three possible neuron selection modes + # that return an explanation w.r.t.: + # * "max_activation": maximum activated neuron + # * "index": neuron at index given on call to `analyze` + # * "all": all output neurons + if neuron_selection_mode not in ["max_activation", "index", "all"]: + raise ValueError("neuron_selection_mode parameter is not valid.") + self._neuron_selection_mode: str = neuron_selection_mode + + # If no model checks are given, initialize an empty list of checks + # that child analyzers can append to. + if _model_checks is None: + _model_checks = [] + self._model_checks: List[ModelCheckDict] = _model_checks + + def _add_model_check( + self, check: LayerCheck, message: str, check_type: str = "exception" + ) -> None: + """Add model check to list of checks `self._model_checks`. + + :param check: Callable that performs a boolean check on a Keras layers. + :type check: LayerCheck + :param message: Error message if check fails. + :type message: str + :param check_type: Either "exception" or "warning". Defaults to "exception" + :type check_type: str, optional + :raises Exception: [description] + """ + + if self._model_check_done: raise Exception( - "Cannot add model check anymore." " Check was already performed." + "Cannot add model check anymore. 
Check was already performed." ) - if not hasattr(self, "_model_checks"): - self._model_checks = [] - - check_instance = { + check_instance: ModelCheckDict = { "check": check, "message": message, - "type": check_type, + "check_type": check_type, } self._model_checks.append(check_instance) - def _do_model_checks(self): - model_checks = getattr(self, "_model_checks", []) - - if not self._disable_model_checks and len(model_checks) > 0: - check = [x["check"] for x in model_checks] - types = [x["type"] for x in model_checks] - messages = [x["message"] for x in model_checks] + def _do_model_checks(self) -> None: + if not self._disable_model_checks and len(self._model_checks) > 0: + check = [x["check"] for x in self._model_checks] + types = [x["check_type"] for x in self._model_checks] + messages = [x["message"] for x in self._model_checks] checked = kgraph.model_contains(self._model, check) - tmp = zip(iutils.to_list(checked), messages, types) + + tmp = zip(checked, messages, types) for checked_layers, message, check_type in tmp: if len(checked_layers) > 0: - tmp_message = "%s\nCheck triggerd by layers: %s" % ( + tmp_message = "%s\nCheck triggered by layers: %s" % ( message, checked_layers, ) @@ -108,12 +128,11 @@ def _do_model_checks(self): elif check_type == "warning": # TODO(albermax) only the first warning will be shown warnings.warn(tmp_message) - else: - raise NotImplementedError() + raise NotImplementedError() self._model_check_done = True - def fit(self, *args, **kwargs): + def fit(self, *_args, disable_no_training_warning: bool = False, **_kwargs): """ Stub that eats arguments. If an analyzer needs training include :class:`TrainerMixin`. @@ -121,16 +140,16 @@ def fit(self, *args, **kwargs): :param disable_no_training_warning: Do not warn if this function is called despite no training is needed. """ - disable_no_training_warning = kwargs.pop("disable_no_training_warning", False) if not disable_no_training_warning: - # issue warning if not training is foreseen, - # but is fit is still called. + # issue warning if no training is foreseen, but fit() is still called. warnings.warn( "This analyzer does not need to be trained." " Still fit() is called.", RuntimeWarning, ) - def fit_generator(self, *args, **kwargs): + def fit_generator( + self, *_args, disable_no_training_warning: bool = False, **_kwargs + ): """ Stub that eats arguments. If an analyzer needs training include :class:`TrainerMixin`. @@ -138,33 +157,34 @@ def fit_generator(self, *args, **kwargs): :param disable_no_training_warning: Do not warn if this function is called despite no training is needed. """ - disable_no_training_warning = kwargs.pop("disable_no_training_warning", False) if not disable_no_training_warning: - # issue warning if not training is foreseen, - # but is fit is still called. + # issue warning if no training is foreseen, but fit() is still called. warnings.warn( "This analyzer does not need to be trained." " Still fit_generator() is called.", RuntimeWarning, ) - def analyze(self, X): + @abstractmethod + def analyze( + self, X: OptionalList[np.ndarray], *args: Any, **kwargs: Any + ) -> OptionalList[np.ndarray]: """ Analyze the behavior of model on input `X`. :param X: Input as expected by model. 
""" - raise NotImplementedError() - def _get_state(self): + def _get_state(self) -> dict: state = { "model_json": self._model.to_json(), "model_weights": self._model.get_weights(), "disable_model_checks": self._disable_model_checks, + "neuron_selection_mode": self._neuron_selection_mode, } return state - def save(self): + def save(self) -> Tuple[str, dict]: """ Save state of analyzer, can be passed to :func:`Analyzer.load` to resemble the analyzer. @@ -175,7 +195,7 @@ def save(self): class_name = self.__class__.__name__ return class_name, state - def save_npz(self, fname): + def save_npz(self, fname: str) -> None: """ Save state of analyzer, can be passed to :func:`Analyzer.load_npz` to resemble the analyzer. @@ -186,18 +206,26 @@ def save_npz(self, fname): np.savez(fname, **{"class_name": class_name, "state": state}) @classmethod - def _state_to_kwargs(clazz, state): + def _state_to_kwargs(cls, state: dict) -> dict: + disable_model_checks = state.pop("disable_model_checks") model_json = state.pop("model_json") model_weights = state.pop("model_weights") - disable_model_checks = state.pop("disable_model_checks") + neuron_selection_mode = state.pop("neuron_selection_mode") + + # since `super()._state_to_kwargs(state)` should be called last + # in every child class, the dict `state` should be empty at this point. assert len(state) == 0 model = keras.models.model_from_json(model_json) model.set_weights(model_weights) - return {"model": model, "disable_model_checks": disable_model_checks} + return { + "model": model, + "disable_model_checks": disable_model_checks, + "neuron_selection_mode": neuron_selection_mode, + } @staticmethod - def load(class_name, state): + def load(class_name: str, state: Dict[str, Any]) -> AnalyzerBase: """ Resembles an analyzer from the state created by :func:`analyzer.save()`. @@ -205,13 +233,11 @@ def load(class_name, state): :param class_name: The analyzer's class name. :param state: The analyzer's state. """ - # Todo:do in a smarter way! - import innvestigate.analyzer + # TODO: do in a smarter way! + cls = getattr(innvestigate.analyzer, class_name) - clazz = getattr(innvestigate.analyzer, class_name) - - kwargs = clazz._state_to_kwargs(state) - return clazz(**kwargs) + kwargs = cls._state_to_kwargs(state) + return cls(**kwargs) # type: ignore @staticmethod def load_npz(fname): @@ -221,16 +247,14 @@ def load_npz(fname): :param fname: The file's name. """ - f = np.load(fname) + npz_file = np.load(fname) - class_name = f["class_name"].item() - state = f["state"].item() + class_name = npz_file["class_name"].item() + state = npz_file["state"].item() return AnalyzerBase.load(class_name, state) ############################################################################### -############################################################################### -############################################################################### class TrainerMixin(object): @@ -240,13 +264,15 @@ class TrainerMixin(object): to the user. """ - # todo: extend with Y - def fit(self, X=None, batch_size=32, **kwargs): + # TODO: extend with Y + def fit( + self, X: Optional[np.ndarray] = None, batch_size: int = 32, **kwargs + ) -> None: """ Takes the same parameters as Keras's :func:`model.fit` function. 
""" generator = iutils.BatchSequence(X, batch_size) - return self._fit_generator(generator, **kwargs) + return self._fit_generator(generator, **kwargs) # type: ignore def fit_generator(self, *args, **kwargs): """ @@ -257,14 +283,14 @@ def fit_generator(self, *args, **kwargs): def _fit_generator( self, - generator, - steps_per_epoch=None, - epochs=1, - max_queue_size=10, - workers=1, - use_multiprocessing=False, + generator: iutils.BatchSequence, + steps_per_epoch: int = None, + epochs: int = 1, + max_queue_size: int = 10, + workers: int = 1, + use_multiprocessing: bool = False, verbose=0, - disable_no_training_warning=None, + disable_no_training_warning: bool = None, ): raise NotImplementedError() @@ -274,590 +300,16 @@ class OneEpochTrainerMixin(TrainerMixin): except that the training is limited to one epoch. """ - def fit(self, *args, **kwargs): + def fit(self, *args, **kwargs) -> None: """ Same interface as :func:`fit` of :class:`TrainerMixin` except that the parameter epoch is fixed to 1. """ - return super(OneEpochTrainerMixin, self).fit(*args, epochs=1, **kwargs) + return super().fit(*args, epochs=1, **kwargs) - def fit_generator(self, *args, **kwargs): + def fit_generator(self, *args, steps: int = None, **kwargs): """ Same interface as :func:`fit_generator` of :class:`TrainerMixin` except that the parameter epoch is fixed to 1. """ - steps = kwargs.pop("steps", None) - return super(OneEpochTrainerMixin, self).fit_generator( - *args, steps_per_epoch=steps, epochs=1, **kwargs - ) - - -############################################################################### -############################################################################### -############################################################################### - - -class AnalyzerNetworkBase(AnalyzerBase): - """Convenience interface for analyzers. - - This class provides helpful functionality to create analyzer's. - Basically it: - - * takes the input model and adds a layer that selects - the desired output neuron to analyze. - * passes the new model to :func:`_create_analysis` which should - return the analysis as Keras tensors. - * compiles the function and serves the output to :func:`analyze` calls. - * allows :func:`_create_analysis` to return tensors - that are intercept for debugging purposes. - - :param neuron_selection_mode: How to select the neuron to analyze. - Possible values are 'max_activation', 'index' for the neuron - (expects indices at :func:`analyze` calls), 'all' take all neurons. - :param allow_lambda_layers: Allow the model to contain lambda layers. - """ - - def __init__( - self, - model, - neuron_selection_mode="max_activation", - allow_lambda_layers=False, - **kwargs - ): - if neuron_selection_mode not in ["max_activation", "index", "all"]: - raise ValueError("neuron_selection parameter is not valid.") - self._neuron_selection_mode = neuron_selection_mode - - self._allow_lambda_layers = allow_lambda_layers - self._add_model_check( - lambda layer: ( - not self._allow_lambda_layers - and isinstance(layer, keras.layers.core.Lambda) - ), - ( - "Lamda layers are not allowed. " - "To force use set allow_lambda_layers parameter." - ), - check_type="exception", - ) - - self._special_helper_layers = [] - - super(AnalyzerNetworkBase, self).__init__(model, **kwargs) - - def _add_model_softmax_check(self): - """ - Adds check that prevents models from containing a softmax. 
- """ - self._add_model_check( - lambda layer: kchecks.contains_activation(layer, activation="softmax"), - "This analysis method does not support softmax layers.", - check_type="exception", - ) - - def _prepare_model(self, model): - """ - Prepares the model to analyze before it gets actually analyzed. - - This class adds the code to select a specific output neuron. - """ - neuron_selection_mode = self._neuron_selection_mode - model_inputs = model.inputs - - model_output = model.outputs - if len(model_output) > 1: - raise ValueError("Only models with one output tensor are allowed.") - analysis_inputs = [] - stop_analysis_at_tensors = [] - - # Flatten to form (batch_size, other_dimensions): - if K.ndim(model_output[0]) > 2: - model_output = keras.layers.Flatten()(model_output) - - if neuron_selection_mode == "max_activation": - l = ilayers.Max(name="iNNvestigate_max") - model_output = l(model_output) - self._special_helper_layers.append(l) - elif neuron_selection_mode == "index": - neuron_indexing = keras.layers.Input( - batch_shape=[None, None], - dtype=np.int32, - name="iNNvestigate_neuron_indexing", - ) - self._special_helper_layers.append(neuron_indexing._keras_history[0]) - analysis_inputs.append(neuron_indexing) - # The indexing tensor should not be analyzed. - stop_analysis_at_tensors.append(neuron_indexing) - - l = ilayers.GatherND(name="iNNvestigate_gather_nd") - model_output = l(model_output + [neuron_indexing]) - self._special_helper_layers.append(l) - elif neuron_selection_mode == "all": - pass - else: - raise NotImplementedError() - - model = keras.models.Model( - inputs=model_inputs + analysis_inputs, outputs=model_output - ) - return model, analysis_inputs, stop_analysis_at_tensors - - def create_analyzer_model(self): - """ - Creates the analyze functionality. If not called beforehand - it will be called by :func:`analyze`. - """ - model_inputs = self._model.inputs - tmp = self._prepare_model(self._model) - model, analysis_inputs, stop_analysis_at_tensors = tmp - self._analysis_inputs = analysis_inputs - self._prepared_model = model - - tmp = self._create_analysis( - model, stop_analysis_at_tensors=stop_analysis_at_tensors - ) - if isinstance(tmp, tuple): - if len(tmp) == 3: - analysis_outputs, debug_outputs, constant_inputs = tmp - elif len(tmp) == 2: - analysis_outputs, debug_outputs = tmp - constant_inputs = list() - elif len(tmp) == 1: - analysis_outputs = iutils.to_list(tmp[0]) - constant_inputs, debug_outputs = list(), list() - else: - raise Exception("Unexpected output from _create_analysis.") - else: - analysis_outputs = tmp - constant_inputs, debug_outputs = list(), list() - - analysis_outputs = iutils.to_list(analysis_outputs) - debug_outputs = iutils.to_list(debug_outputs) - constant_inputs = iutils.to_list(constant_inputs) - - self._n_data_input = len(model_inputs) - self._n_constant_input = len(constant_inputs) - self._n_data_output = len(analysis_outputs) - self._n_debug_output = len(debug_outputs) - self._analyzer_model = keras.models.Model( - inputs=model_inputs + analysis_inputs + constant_inputs, - outputs=analysis_outputs + debug_outputs, - ) - - def _create_analysis(self, model, stop_analysis_at_tensors=[]): - """ - Interface that needs to be implemented by a derived class. - - This function is expected to create a Keras graph that creates - a custom analysis for the model inputs given the model outputs. - - :param model: Target of analysis. - :param stop_analysis_at_tensors: A list of tensors where to stop the - analysis. 
Similar to stop_gradient arguments when computing the - gradient of a graph. - :return: Either one-, two- or three-tuple of lists of tensors. - * The first list of tensors represents the analysis for each - model input tensor. Tensors present in stop_analysis_at_tensors - should be omitted. - * The second list, if present, is a list of debug tensors that will - be passed to :func:`_handle_debug_output` after the analysis - is executed. - * The third list, if present, is a list of constant input tensors - added to the analysis model. - """ - raise NotImplementedError() - - def _handle_debug_output(self, debug_values): - raise NotImplementedError() - - def analyze(self, X, neuron_selection=None): - """ - Same interface as :class:`Analyzer` besides - - :param neuron_selection: If neuron_selection_mode is 'index' this - should be an integer with the index for the chosen neuron. - """ - if not hasattr(self, "_analyzer_model"): - self.create_analyzer_model() - - X = iutils.to_list(X) - - if neuron_selection is not None and self._neuron_selection_mode != "index": - raise ValueError( - "Only neuron_selection_mode 'index' expects " - "the neuron_selection parameter." - ) - if neuron_selection is None and self._neuron_selection_mode == "index": - raise ValueError( - "neuron_selection_mode 'index' expects " - "the neuron_selection parameter." - ) - - if self._neuron_selection_mode == "index": - neuron_selection = np.asarray(neuron_selection).flatten() - if neuron_selection.size == 1: - neuron_selection = np.repeat(neuron_selection, len(X[0])) - - # Add first axis indices for gather_nd - neuron_selection = np.hstack( - ( - np.arange(len(neuron_selection)).reshape((-1, 1)), - neuron_selection.reshape((-1, 1)), - ) - ) - ret = self._analyzer_model.predict_on_batch(X + [neuron_selection]) - else: - ret = self._analyzer_model.predict_on_batch(X) - - if self._n_debug_output > 0: - self._handle_debug_output(ret[-self._n_debug_output :]) - ret = ret[: -self._n_debug_output] - - if isinstance(ret, list) and len(ret) == 1: - ret = ret[0] - return ret - - def _get_state(self): - state = super(AnalyzerNetworkBase, self)._get_state() - state.update({"neuron_selection_mode": self._neuron_selection_mode}) - state.update({"allow_lambda_layers": self._allow_lambda_layers}) - return state - - @classmethod - def _state_to_kwargs(clazz, state): - neuron_selection_mode = state.pop("neuron_selection_mode") - allow_lambda_layers = state.pop("allow_lambda_layers") - kwargs = super(AnalyzerNetworkBase, clazz)._state_to_kwargs(state) - kwargs.update( - { - "neuron_selection_mode": neuron_selection_mode, - "allow_lambda_layers": allow_lambda_layers, - } - ) - return kwargs - - -class ReverseAnalyzerBase(AnalyzerNetworkBase): - """Convenience class for analyzers that revert the model's structure. - - This class contains many helper functions around the graph - reverse function :func:`innvestigate.utils.keras.graph.reverse_model`. - - The deriving classes should specify how the graph should be reverted - by implementing the following functions: - - * :func:`_reverse_mapping(layer)` given a layer this function - returns a reverse mapping for the layer as specified in - :func:`innvestigate.utils.keras.graph.reverse_model` or None. - - This function can be implemented, but it is encouraged to - implement a default mapping and add additional changes with - the function :func:`_add_conditional_reverse_mapping` (see below). 
- - The default behavior is finding a conditional mapping (see below), - if none is found, :func:`_default_reverse_mapping` is applied. - * :func:`_default_reverse_mapping` defines the default - reverse mapping. - * :func:`_head_mapping` defines how the outputs of the model - should be instantiated before the are passed to the reversed - network. - - Furthermore other parameters of the function - :func:`innvestigate.utils.keras.graph.reverse_model` can - be changed by setting the according parameters of the - init function: - - :param reverse_verbose: Print information on the reverse process. - :param reverse_clip_values: Clip the values that are passed along - the reverted network. Expects tuple (min, max). - :param reverse_project_bottleneck_layers: Project the value range - of bottleneck tensors in the reverse network into another range. - :param reverse_check_min_max_values: Print the min/max values - observed in each tensor along the reverse network whenever - :func:`analyze` is called. - :param reverse_check_finite: Check if values passed along the - reverse network are finite. - :param reverse_keep_tensors: Keeps the tensors created in the - backward pass and stores them in the attribute - :attr:`_reversed_tensors`. - :param reverse_reapply_on_copied_layers: See - :func:`innvestigate.utils.keras.graph.reverse_model`. - """ - - def __init__( - self, - model, - reverse_verbose=False, - reverse_clip_values=False, - reverse_project_bottleneck_layers=False, - reverse_check_min_max_values=False, - reverse_check_finite=False, - reverse_keep_tensors=False, - reverse_reapply_on_copied_layers=False, - **kwargs - ): - self._reverse_verbose = reverse_verbose - self._reverse_clip_values = reverse_clip_values - self._reverse_project_bottleneck_layers = reverse_project_bottleneck_layers - self._reverse_check_min_max_values = reverse_check_min_max_values - self._reverse_check_finite = reverse_check_finite - self._reverse_keep_tensors = reverse_keep_tensors - self._reverse_reapply_on_copied_layers = reverse_reapply_on_copied_layers - super(ReverseAnalyzerBase, self).__init__(model, **kwargs) - - def _gradient_reverse_mapping(self, Xs, Ys, reversed_Ys, reverse_state): - mask = [x not in reverse_state["stop_mapping_at_tensors"] for x in Xs] - return ilayers.GradientWRT(len(Xs), mask=mask)(Xs + Ys + reversed_Ys) - - def _reverse_mapping(self, layer): - """ - This function should return a reverse mapping for the passed layer. - - If this function returns None, :func:`_default_reverse_mapping` - is applied. - - :param layer: The layer for which a mapping should be returned. - :return: The mapping can be of the following forms: - * A function of form (A) f(Xs, Ys, reversed_Ys, reverse_state) - that maps reversed_Ys to reversed_Xs (which should contain - tensors of the same shape and type). - * A function of form f(B) f(layer, reverse_state) that returns - a function of form (A). - * A :class:`ReverseMappingBase` subclass. - """ - if layer in self._special_helper_layers: - # Special layers added by AnalyzerNetworkBase - # that should not be exposed to user. - return self._gradient_reverse_mapping - - return self._apply_conditional_reverse_mappings(layer) - - def _add_conditional_reverse_mapping( - self, condition, mapping, priority=-1, name=None - ): - """ - This function should return a reverse mapping for the passed layer. - - If this function returns None, :func:`_default_reverse_mapping` - is applied. - - :param condition: Condition when this mapping should be applied. 
- Form: f(layer) -> bool - :param mapping: The mapping can be of the following forms: - * A function of form (A) f(Xs, Ys, reversed_Ys, reverse_state) - that maps reversed_Ys to reversed_Xs (which should contain - tensors of the same shape and type). - * A function of form f(B) f(layer, reverse_state) that returns - a function of form (A). - * A :class:`ReverseMappingBase` subclass. - :param priority: The higher the earlier the condition gets - evaluated. - :param name: An identifying name. - """ - if getattr(self, "_reverse_mapping_applied", False): - raise Exception( - "Cannot add conditional mapping " "after first application." - ) - - if not hasattr(self, "_conditional_reverse_mappings"): - self._conditional_reverse_mappings = {} - - if priority not in self._conditional_reverse_mappings: - self._conditional_reverse_mappings[priority] = [] - - tmp = {"condition": condition, "mapping": mapping, "name": name} - self._conditional_reverse_mappings[priority].append(tmp) - - def _apply_conditional_reverse_mappings(self, layer): - mappings = getattr(self, "_conditional_reverse_mappings", {}) - self._reverse_mapping_applied = True - - # Search for mapping. First consider ones with highest priority, - # inside priority in order of adding. - sorted_keys = sorted(mappings.keys())[::-1] - for key in sorted_keys: - for mapping in mappings[key]: - if mapping["condition"](layer): - return mapping["mapping"] - - return None - - def _default_reverse_mapping(self, Xs, Ys, reversed_Ys, reverse_state): - """ - Fallback function to map reversed_Ys to reversed_Xs - (which should contain tensors of the same shape and type). - """ - return self._gradient_reverse_mapping(Xs, Ys, reversed_Ys, reverse_state) - - def _head_mapping(self, X): - """ - Map output tensors to new values before passing - them into the reverted network. 
- """ - return X - - def _postprocess_analysis(self, X): - return X - - def _reverse_model( - self, model, stop_analysis_at_tensors=[], return_all_reversed_tensors=False - ): - return kgraph.reverse_model( - model, - reverse_mappings=self._reverse_mapping, - default_reverse_mapping=self._default_reverse_mapping, - head_mapping=self._head_mapping, - stop_mapping_at_tensors=stop_analysis_at_tensors, - verbose=self._reverse_verbose, - clip_all_reversed_tensors=self._reverse_clip_values, - project_bottleneck_tensors=self._reverse_project_bottleneck_layers, - return_all_reversed_tensors=return_all_reversed_tensors, - ) - - def _create_analysis(self, model, stop_analysis_at_tensors=[]): - return_all_reversed_tensors = ( - self._reverse_check_min_max_values - or self._reverse_check_finite - or self._reverse_keep_tensors - ) - ret = self._reverse_model( - model, - stop_analysis_at_tensors=stop_analysis_at_tensors, - return_all_reversed_tensors=return_all_reversed_tensors, - ) - - if return_all_reversed_tensors: - ret = (self._postprocess_analysis(ret[0]), ret[1]) - else: - ret = self._postprocess_analysis(ret) - - if return_all_reversed_tensors: - debug_tensors = [] - self._debug_tensors_indices = {} - - values = list(six.itervalues(ret[1])) - mapping = {i: v["id"] for i, v in enumerate(values)} - tensors = [v["final_tensor"] for v in values] - self._reverse_tensors_mapping = mapping - - if self._reverse_check_min_max_values: - tmp = [ilayers.Min(None)(x) for x in tensors] - self._debug_tensors_indices["min"] = ( - len(debug_tensors), - len(debug_tensors) + len(tmp), - ) - debug_tensors += tmp - - tmp = [ilayers.Max(None)(x) for x in tensors] - self._debug_tensors_indices["max"] = ( - len(debug_tensors), - len(debug_tensors) + len(tmp), - ) - debug_tensors += tmp - - if self._reverse_check_finite: - tmp = iutils.to_list(ilayers.FiniteCheck()(tensors)) - self._debug_tensors_indices["finite"] = ( - len(debug_tensors), - len(debug_tensors) + len(tmp), - ) - debug_tensors += tmp - - if self._reverse_keep_tensors: - self._debug_tensors_indices["keep"] = ( - len(debug_tensors), - len(debug_tensors) + len(tensors), - ) - debug_tensors += tensors - - ret = (ret[0], debug_tensors) - return ret - - def _handle_debug_output(self, debug_values): - - if self._reverse_check_min_max_values: - indices = self._debug_tensors_indices["min"] - tmp = debug_values[indices[0] : indices[1]] - tmp = sorted( - [(self._reverse_tensors_mapping[i], v) for i, v in enumerate(tmp)] - ) - print( - "Minimum values in tensors: " - "((NodeID, TensorID), Value) - {}".format(tmp) - ) - - indices = self._debug_tensors_indices["max"] - tmp = debug_values[indices[0] : indices[1]] - tmp = sorted( - [(self._reverse_tensors_mapping[i], v) for i, v in enumerate(tmp)] - ) - print( - "Maximum values in tensors: " - "((NodeID, TensorID), Value) - {}".format(tmp) - ) - - if self._reverse_check_finite: - indices = self._debug_tensors_indices["finite"] - tmp = debug_values[indices[0] : indices[1]] - nfinite_tensors = np.flatnonzero(np.asarray(tmp) > 0) - - if len(nfinite_tensors) > 0: - nfinite_tensors = sorted( - [self._reverse_tensors_mapping[i] for i in nfinite_tensors] - ) - print( - "Not finite values found in following nodes: " - "(NodeID, TensorID) - {}".format(nfinite_tensors) - ) - - if self._reverse_keep_tensors: - indices = self._debug_tensors_indices["keep"] - tmp = debug_values[indices[0] : indices[1]] - tmp = sorted( - [(self._reverse_tensors_mapping[i], v) for i, v in enumerate(tmp)] - ) - self._reversed_tensors = tmp - - def 
_get_state(self): - state = super(ReverseAnalyzerBase, self)._get_state() - state.update({"reverse_verbose": self._reverse_verbose}) - state.update({"reverse_clip_values": self._reverse_clip_values}) - state.update( - { - "reverse_project_bottleneck_layers": self._reverse_project_bottleneck_layers - } - ) - state.update( - {"reverse_check_min_max_values": self._reverse_check_min_max_values} - ) - state.update({"reverse_check_finite": self._reverse_check_finite}) - state.update({"reverse_keep_tensors": self._reverse_keep_tensors}) - state.update( - {"reverse_reapply_on_copied_layers": self._reverse_reapply_on_copied_layers} - ) - return state - - @classmethod - def _state_to_kwargs(clazz, state): - reverse_verbose = state.pop("reverse_verbose") - reverse_clip_values = state.pop("reverse_clip_values") - reverse_project_bottleneck_layers = state.pop( - "reverse_project_bottleneck_layers" - ) - reverse_check_min_max_values = state.pop("reverse_check_min_max_values") - reverse_check_finite = state.pop("reverse_check_finite") - reverse_keep_tensors = state.pop("reverse_keep_tensors") - reverse_reapply_on_copied_layers = state.pop("reverse_reapply_on_copied_layers") - kwargs = super(ReverseAnalyzerBase, clazz)._state_to_kwargs(state) - kwargs.update( - { - "reverse_verbose": reverse_verbose, - "reverse_clip_values": reverse_clip_values, - "reverse_project_bottleneck_layers": reverse_project_bottleneck_layers, - "reverse_check_min_max_values": reverse_check_min_max_values, - "reverse_check_finite": reverse_check_finite, - "reverse_keep_tensors": reverse_keep_tensors, - "reverse_reapply_on_copied_layers": reverse_reapply_on_copied_layers, - } - ) - return kwargs + return super().fit_generator(*args, steps_per_epoch=steps, epochs=1, **kwargs) diff --git a/src/innvestigate/analyzer/deeptaylor.py b/src/innvestigate/analyzer/deeptaylor.py index 7201b678..78bb0fb4 100644 --- a/src/innvestigate/analyzer/deeptaylor.py +++ b/src/innvestigate/analyzer/deeptaylor.py @@ -1,13 +1,15 @@ -# Get Python six functionality: -from __future__ import absolute_import, division, print_function, unicode_literals +from __future__ import annotations + +from typing import Any import keras.layers import keras.models -from ..utils.keras import checks as kchecks -from ..utils.keras import graph as kgraph -from . import base -from .relevance_based import relevance_rule as lrp_rules +import innvestigate.analyzer.relevance_based.relevance_rule as lrp_rules +import innvestigate.utils.keras.checks as kchecks +import innvestigate.utils.keras.graph as kgraph +from innvestigate.analyzer.reverse_base import ReverseAnalyzerBase +from innvestigate.utils.types import Model __all__ = [ "DeepTaylor", @@ -15,7 +17,7 @@ ] -class DeepTaylor(base.ReverseAnalyzerBase): +class DeepTaylor(ReverseAnalyzerBase): """DeepTaylor for ReLU-networks with unbounded input This class implements the DeepTaylor algorithm for neural networks with @@ -24,19 +26,19 @@ class DeepTaylor(base.ReverseAnalyzerBase): :param model: A Keras model. 
""" - def __init__(self, model, *args, **kwargs): + def __init__(self, model: Model, *args, **kwargs) -> None: + super().__init__(model, *args, **kwargs) + # Add and run model checks self._add_model_softmax_check() self._add_model_check( lambda layer: not kchecks.only_relu_activation(layer), "This DeepTaylor implementation only supports ReLU activations.", check_type="exception", ) - super(DeepTaylor, self).__init__(model, *args, **kwargs) + self._do_model_checks() - def _create_analysis(self, *args, **kwargs): - def do_nothing(Xs, Ys, As, reverse_state): - return As + def _create_analysis(self, *args: Any, **kwargs: Any): # Kernel layers. self._add_conditional_reverse_mapping( @@ -129,9 +131,9 @@ def do_nothing(Xs, Ys, As, reverse_state): name="deep_taylor_no_transform", ) - return super(DeepTaylor, self)._create_analysis(*args, **kwargs) + return super()._create_analysis(*args, **kwargs) - def _default_reverse_mapping(self, Xs, Ys, reversed_Ys, reverse_state): + def _default_reverse_mapping(self, _Xs, _Ys, _reversed_Ys, reverse_state): """ Block all default mappings. """ @@ -147,7 +149,7 @@ def _prepare_model(self, model): inputs=model.inputs, outputs=positive_outputs ) - return super(DeepTaylor, self)._prepare_model(model_with_positive_output) + return super()._prepare_model(model_with_positive_output) class BoundedDeepTaylor(DeepTaylor): @@ -162,27 +164,24 @@ class BoundedDeepTaylor(DeepTaylor): """ def __init__(self, model, low=None, high=None, **kwargs): + super().__init__(model, **kwargs) if low is None or high is None: raise ValueError( - "The low or high parameter is missing." - " Z-B (bounded rule) require both values." + "The low or high parameter is missing. " + "Z-B (bounded rule) require both values." ) self._bounds_low = low self._bounds_high = high - super(BoundedDeepTaylor, self).__init__(model, **kwargs) - def _create_analysis(self, *args, **kwargs): low, high = self._bounds_low, self._bounds_high class BoundedProxyRule(lrp_rules.BoundedRule): def __init__(self, *args, **kwargs): - super(BoundedProxyRule, self).__init__( - *args, low=low, high=high, **kwargs - ) + super().__init__(*args, low=low, high=high, **kwargs) self._add_conditional_reverse_mapping( lambda l: kchecks.is_input_layer(l) and kchecks.contains_kernel(l), @@ -191,4 +190,20 @@ def __init__(self, *args, **kwargs): priority=10, # do first ) - return super(BoundedDeepTaylor, self)._create_analysis(*args, **kwargs) + return super()._create_analysis(*args, **kwargs) + + def _get_state(self): + state = super()._get_state() + state.update({"low": self._bounds_low, "high": self._bounds_high}) + return state + + @classmethod + def _state_to_kwargs(cls, state): + low = state.pop("low") + high = state.pop("high") + + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + + kwargs.update({"low": low, "high": high}) + return kwargs diff --git a/src/innvestigate/analyzer/gradient_based.py b/src/innvestigate/analyzer/gradient_based.py index 3764ec76..de4d319d 100644 --- a/src/innvestigate/analyzer/gradient_based.py +++ b/src/innvestigate/analyzer/gradient_based.py @@ -1,20 +1,18 @@ -# Get Python six functionality: -from __future__ import absolute_import, division, print_function, unicode_literals +from __future__ import annotations + +from typing import Dict, Optional import keras import keras.models -from .. import layers as ilayers -from .. 
import utils as iutils -from ..utils import keras as kutils -from ..utils.keras import checks as kchecks -from ..utils.keras import graph as kgraph -from . import base, wrapper - -############################################################################### -############################################################################### -############################################################################### - +import innvestigate.layers as ilayers +import innvestigate.utils as iutils +import innvestigate.utils.keras as kutils +import innvestigate.utils.keras.checks as kchecks +import innvestigate.utils.keras.graph as kgraph +from innvestigate.analyzer.network_base import AnalyzerNetworkBase +from innvestigate.analyzer.reverse_base import ReverseAnalyzerBase +from innvestigate.analyzer.wrapper import GaussianSmoother, PathIntegrator __all__ = [ "BaselineGradient", @@ -27,12 +25,7 @@ ] -############################################################################### -############################################################################### -############################################################################### - - -class BaselineGradient(base.AnalyzerNetworkBase): +class BaselineGradient(AnalyzerNetworkBase): """Gradient analyzer based on build-in gradient. Returns as analysis the function value with respect to the input. @@ -43,6 +36,7 @@ class BaselineGradient(base.AnalyzerNetworkBase): """ def __init__(self, model, postprocess=None, **kwargs): + super().__init__(model, **kwargs) if postprocess not in [None, "abs", "square"]: raise ValueError( @@ -51,10 +45,12 @@ def __init__(self, model, postprocess=None, **kwargs): self._postprocess = postprocess self._add_model_softmax_check() + self._do_model_checks() - super(BaselineGradient, self).__init__(model, **kwargs) + def _create_analysis(self, model, stop_analysis_at_tensors=None): + if stop_analysis_at_tensors is None: + stop_analysis_at_tensors = [] - def _create_analysis(self, model, stop_analysis_at_tensors=[]): tensors_to_analyze = [ x for x in iutils.to_list(model.inputs) if x not in stop_analysis_at_tensors ] @@ -70,14 +66,16 @@ def _create_analysis(self, model, stop_analysis_at_tensors=[]): return iutils.to_list(ret) def _get_state(self): - state = super(BaselineGradient, self)._get_state() + state = super()._get_state() state.update({"postprocess": self._postprocess}) return state @classmethod - def _state_to_kwargs(clazz, state): + def _state_to_kwargs(cls, state): postprocess = state.pop("postprocess") - kwargs = super(BaselineGradient, clazz)._state_to_kwargs(state) + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + kwargs.update( { "postprocess": postprocess, @@ -86,7 +84,7 @@ def _state_to_kwargs(clazz, state): return kwargs -class Gradient(base.ReverseAnalyzerBase): +class Gradient(ReverseAnalyzerBase): """Gradient analyzer. Returns as analysis the function value with respect to the input. @@ -95,23 +93,24 @@ class Gradient(base.ReverseAnalyzerBase): :param model: A Keras model. """ - def __init__(self, model, postprocess=None, **kwargs): + def __init__(self, model, postprocess: Optional[str] = None, **kwargs): + super(Gradient, self).__init__(model, **kwargs) if postprocess not in [None, "abs", "square"]: raise ValueError( - "Parameter 'postprocess' must be either " "None, 'abs', or 'square'." 
+ """Parameter 'postprocess' must be either None, "abs", or "square".""" ) self._postprocess = postprocess + # Add and run model checks self._add_model_softmax_check() - - super(Gradient, self).__init__(model, **kwargs) + self._do_model_checks() def _head_mapping(self, X): return ilayers.OnesLike()(X) def _postprocess_analysis(self, X): - ret = super(Gradient, self)._postprocess_analysis(X) + ret = super()._postprocess_analysis(X) if self._postprocess == "abs": ret = ilayers.Abs()(ret) @@ -121,14 +120,16 @@ def _postprocess_analysis(self, X): return iutils.to_list(ret) def _get_state(self): - state = super(Gradient, self)._get_state() + state = super()._get_state() state.update({"postprocess": self._postprocess}) return state @classmethod - def _state_to_kwargs(clazz, state): + def _state_to_kwargs(cls, state): postprocess = state.pop("postprocess") - kwargs = super(Gradient, clazz)._state_to_kwargs(state) + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + kwargs.update( { "postprocess": postprocess, @@ -138,8 +139,6 @@ def _state_to_kwargs(clazz, state): ############################################################################### -############################################################################### -############################################################################### class InputTimesGradient(Gradient): @@ -150,15 +149,16 @@ class InputTimesGradient(Gradient): def __init__(self, model, **kwargs): - self._add_model_softmax_check() - super(InputTimesGradient, self).__init__(model, **kwargs) - def _create_analysis(self, model, stop_analysis_at_tensors=[]): + def _create_analysis(self, model, stop_analysis_at_tensors=None): + if stop_analysis_at_tensors is None: + stop_analysis_at_tensors = [] + tensors_to_analyze = [ x for x in iutils.to_list(model.inputs) if x not in stop_analysis_at_tensors ] - gradients = super(InputTimesGradient, self)._create_analysis( + gradients = super()._create_analysis( model, stop_analysis_at_tensors=stop_analysis_at_tensors ) return [ @@ -168,8 +168,6 @@ def _create_analysis(self, model, stop_analysis_at_tensors=[]): ############################################################################### -############################################################################### -############################################################################### class DeconvnetReverseReLULayer(kgraph.ReverseMappingBase): @@ -180,7 +178,7 @@ def __init__(self, layer, state): name_template="reversed_%s", ) - def apply(self, Xs, Ys, reversed_Ys, reverse_state): + def apply(self, Xs, Ys, reversed_Ys, reverse_state: Dict): # Apply relus conditioned on backpropagated values. reversed_Ys = kutils.apply(self._activation, reversed_Ys) @@ -189,7 +187,7 @@ def apply(self, Xs, Ys, reversed_Ys, reverse_state): return ilayers.GradientWRT(len(Xs))(Xs + Ys_wo_relu + reversed_Ys) -class Deconvnet(base.ReverseAnalyzerBase): +class Deconvnet(ReverseAnalyzerBase): """Deconvnet analyzer. Applies the "deconvnet" algorithm to analyze the model. 
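Note on the pattern used throughout the hunks above and below: model checks are no longer executed implicitly inside AnalyzerBase.__init__. Each concrete analyzer now calls super().__init__() first, registers its layer checks, and then runs them explicitly via _do_model_checks(). A minimal sketch of a custom analyzer following that convention (the class name and error message are illustrative, not part of this patch):

import innvestigate.utils.keras.checks as kchecks
from innvestigate.analyzer.reverse_base import ReverseAnalyzerBase


class ReluOnlyAnalyzer(ReverseAnalyzerBase):
    """Hypothetical analyzer using the explicit check-registration pattern."""

    def __init__(self, model, **kwargs):
        # 1. Initialize the base classes first so that self._model_checks exists.
        super().__init__(model, **kwargs)

        # 2. Register checks that are evaluated per Keras layer.
        self._add_model_softmax_check()
        self._add_model_check(
            lambda layer: not kchecks.only_relu_activation(layer),
            "ReluOnlyAnalyzer only supports ReLU activations.",
            check_type="exception",
        )

        # 3. Run the registered checks; previously this happened implicitly
        #    in AnalyzerBase.__init__ and is now an explicit last step.
        self._do_model_checks()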
@@ -198,15 +196,16 @@ class Deconvnet(base.ReverseAnalyzerBase): """ def __init__(self, model, **kwargs): + super().__init__(model, **kwargs) + # Add and run model checks self._add_model_softmax_check() self._add_model_check( lambda layer: not kchecks.only_relu_activation(layer), "Deconvnet is only specified for networks with ReLU activations.", check_type="exception", ) - - super(Deconvnet, self).__init__(model, **kwargs) + self._do_model_checks() def _create_analysis(self, *args, **kwargs): @@ -216,10 +215,10 @@ def _create_analysis(self, *args, **kwargs): name="deconvnet_reverse_relu_layer", ) - return super(Deconvnet, self)._create_analysis(*args, **kwargs) + return super()._create_analysis(*args, **kwargs) -def GuidedBackpropReverseReLULayer(Xs, Ys, reversed_Ys, reverse_state): +def GuidedBackpropReverseReLULayer(Xs, Ys, reversed_Ys, reverse_state: Dict): activation = keras.layers.Activation("relu") # Apply relus conditioned on backpropagated values. reversed_Ys = kutils.apply(activation, reversed_Ys) @@ -228,7 +227,7 @@ def GuidedBackpropReverseReLULayer(Xs, Ys, reversed_Ys, reverse_state): return ilayers.GradientWRT(len(Xs))(Xs + Ys + reversed_Ys) -class GuidedBackprop(base.ReverseAnalyzerBase): +class GuidedBackprop(ReverseAnalyzerBase): """Guided backprop analyzer. Applies the "guided backprop" algorithm to analyze the model. @@ -237,15 +236,16 @@ class GuidedBackprop(base.ReverseAnalyzerBase): """ def __init__(self, model, **kwargs): + super().__init__(model, **kwargs) + # Add and run model checks self._add_model_softmax_check() self._add_model_check( lambda layer: not kchecks.only_relu_activation(layer), "GuidedBackprop is only specified for " "networks with ReLU activations.", check_type="exception", ) - - super(GuidedBackprop, self).__init__(model, **kwargs) + self._do_model_checks() def _create_analysis(self, *args, **kwargs): @@ -255,15 +255,13 @@ def _create_analysis(self, *args, **kwargs): name="guided_backprop_reverse_relu_layer", ) - return super(GuidedBackprop, self)._create_analysis(*args, **kwargs) + return super()._create_analysis(*args, **kwargs) ############################################################################### -############################################################################### -############################################################################### -class IntegratedGradients(wrapper.PathIntegrator): +class IntegratedGradients(PathIntegrator): """Integrated gradient analyzer. Applies the "integrated gradient" algorithm to analyze the model. @@ -272,23 +270,38 @@ class IntegratedGradients(wrapper.PathIntegrator): :param steps: Number of steps to use average along integration path. 
""" - def __init__(self, model, steps=64, **kwargs): - subanalyzer_kwargs = {} - kwargs_keys = ["neuron_selection_mode", "postprocess"] - for key in kwargs_keys: - if key in kwargs: - subanalyzer_kwargs[key] = kwargs.pop(key) - subanalyzer = Gradient(model, **subanalyzer_kwargs) + def __init__( + self, + model, + steps=64, + neuron_selection_mode="max_activation", + postprocess=None, + **kwargs + ): + # If initialized through serialization: + if "subanalyzer" in kwargs: + subanalyzer = kwargs.pop("subanalyzer") + # If initialized normally: + else: + + subanalyzer = Gradient( + model, + neuron_selection_mode=neuron_selection_mode, + postprocess=postprocess, + ) - super(IntegratedGradients, self).__init__(subanalyzer, steps=steps, **kwargs) + super().__init__( + subanalyzer, + steps=steps, + neuron_selection_mode=neuron_selection_mode, + **kwargs + ) ############################################################################### -############################################################################### -############################################################################### -class SmoothGrad(wrapper.GaussianSmoother): +class SmoothGrad(GaussianSmoother): """Smooth grad analyzer. Applies the "smooth grad" algorithm to analyze the model. @@ -297,14 +310,29 @@ class SmoothGrad(wrapper.GaussianSmoother): :param augment_by_n: Number of distortions to average for smoothing. """ - def __init__(self, model, augment_by_n=64, **kwargs): - subanalyzer_kwargs = {} - kwargs_keys = ["neuron_selection_mode", "postprocess"] - for key in kwargs_keys: - if key in kwargs: - subanalyzer_kwargs[key] = kwargs.pop(key) - subanalyzer = Gradient(model, **subanalyzer_kwargs) + def __init__( + self, + model, + augment_by_n=64, + neuron_selection_mode="max_activation", + postprocess=None, + **kwargs + ): + # If initialized through serialization: + if "subanalyzer" in kwargs: + subanalyzer = kwargs.pop("subanalyzer") + # If initialized normally: + else: + + subanalyzer = Gradient( + model, + neuron_selection_mode=neuron_selection_mode, + postprocess=postprocess, + ) - super(SmoothGrad, self).__init__( - subanalyzer, augment_by_n=augment_by_n, **kwargs + super().__init__( + subanalyzer, + augment_by_n=augment_by_n, + neuron_selection_mode=neuron_selection_mode, + **kwargs ) diff --git a/src/innvestigate/analyzer/misc.py b/src/innvestigate/analyzer/misc.py index ddaf8afe..472fc8cf 100644 --- a/src/innvestigate/analyzer/misc.py +++ b/src/innvestigate/analyzer/misc.py @@ -1,23 +1,12 @@ -# Get Python six functionality: -from __future__ import absolute_import, division, print_function, unicode_literals - -from .. import layers as ilayers -from .. import utils as iutils -from .base import AnalyzerNetworkBase - -############################################################################### -############################################################################### -############################################################################### +from __future__ import annotations +import innvestigate.layers as ilayers +import innvestigate.utils as iutils +from innvestigate.analyzer.network_base import AnalyzerNetworkBase __all__ = ["Random", "Input"] -############################################################################### -############################################################################### -############################################################################### - - class Input(AnalyzerNetworkBase): """Returns the input. 
@@ -26,7 +15,13 @@ class Input(AnalyzerNetworkBase): :param model: A Keras model. """ - def _create_analysis(self, model, stop_analysis_at_tensors=[]): + def __init__(self, model, **kwargs) -> None: + super().__init__(model, **kwargs) + + def _create_analysis(self, model, stop_analysis_at_tensors=None): + if stop_analysis_at_tensors is None: + stop_analysis_at_tensors = [] + tensors_to_analyze = [ x for x in iutils.to_list(model.inputs) if x not in stop_analysis_at_tensors ] @@ -45,9 +40,12 @@ class Random(AnalyzerNetworkBase): def __init__(self, model, stddev=1, **kwargs): self._stddev = stddev - super(Random, self).__init__(model, **kwargs) + super().__init__(model, **kwargs) + + def _create_analysis(self, model, stop_analysis_at_tensors=None): + if stop_analysis_at_tensors is None: + stop_analysis_at_tensors = [] - def _create_analysis(self, model, stop_analysis_at_tensors=[]): noise = ilayers.TestPhaseGaussianNoise(stddev=self._stddev) tensors_to_analyze = [ x for x in iutils.to_list(model.inputs) if x not in stop_analysis_at_tensors @@ -55,13 +53,15 @@ def _create_analysis(self, model, stop_analysis_at_tensors=[]): return [noise(x) for x in tensors_to_analyze] def _get_state(self): - state = super(Random, self)._get_state() + state = super()._get_state() state.update({"stddev": self._stddev}) return state @classmethod - def _state_to_kwargs(clazz, state): + def _state_to_kwargs(cls, state): stddev = state.pop("stddev") - kwargs = super(Random, clazz)._state_to_kwargs(state) + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + kwargs.update({"stddev": stddev}) return kwargs diff --git a/src/innvestigate/analyzer/network_base.py b/src/innvestigate/analyzer/network_base.py new file mode 100644 index 00000000..278a4c93 --- /dev/null +++ b/src/innvestigate/analyzer/network_base.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple, Union + +import keras +import keras.backend as K +import keras.layers +import keras.models +import numpy as np + +import innvestigate.layers as ilayers +import innvestigate.utils as iutils +import innvestigate.utils.keras.checks as kchecks +from innvestigate.analyzer.base import AnalyzerBase +from innvestigate.utils.types import Layer, LayerCheck, Model, OptionalList, Tensor + +__all__ = ["AnalyzerNetworkBase"] + + +class AnalyzerNetworkBase(AnalyzerBase): + """Convenience interface for analyzers. + + This class provides helpful functionality to create analyzer's. + Basically it: + + * takes the input model and adds a layer that selects + the desired output neuron to analyze. + * passes the new model to :func:`_create_analysis` which should + return the analysis as Keras tensors. + * compiles the function and serves the output to :func:`analyze` calls. + * allows :func:`_create_analysis` to return tensors + that are intercept for debugging purposes. + + :param allow_lambda_layers: Allow the model to contain lambda layers. 
+ """ + + def __init__( + self, + model: Model, + allow_lambda_layers: bool = False, + **kwargs, + ) -> None: + """ + From AnalyzerBase super init: + * Initializes empty list of _model_checks + * set _neuron_selection_mode + + Here: + * add check for lambda layers through 'allow_lambda_layers' + * define attributes for '_prepare_model', which is later called + through 'create_analyzer_model' + """ + # Call super init to initialize self._model_checks + super().__init__(model, **kwargs) + + # Add model check for lambda layers + self._allow_lambda_layers: bool = allow_lambda_layers + self._add_lambda_layers_check() + + # Attributes of prepared model created by '_prepare_model' + self._analyzer_model_done: bool = False + self._analyzer_model: Model = None + self._special_helper_layers: List[Layer] = [] # added for _reverse_mapping + self._analysis_inputs: Optional[List[Tensor]] = None + self._n_data_input: int = 0 + self._n_constant_input: int = 0 + self._n_data_output: int = 0 + self._n_debug_output: int = 0 + + def _add_model_softmax_check(self) -> None: + """ + Adds check that prevents models from containing a softmax. + """ + contains_softmax: LayerCheck = lambda layer: kchecks.contains_activation( + layer, activation="softmax" + ) + self._add_model_check( + check=contains_softmax, + message="This analysis method does not support softmax layers.", + check_type="exception", + ) + + def _add_lambda_layers_check(self) -> None: + check_lambda_layers: LayerCheck = lambda layer: ( + not self._allow_lambda_layers + and isinstance(layer, keras.layers.core.Lambda) + ) + self._add_model_check( + check=check_lambda_layers, + message=( + "Lamda layers are not allowed. " + "To force use set 'allow_lambda_layers' parameter." + ), + check_type="exception", + ) + + def _prepare_model(self, model: Model) -> Tuple[Model, List[Tensor], List[Tensor]]: + """ + Prepares the model to analyze before it gets actually analyzed. + + This class adds the code to select a specific output neuron. + """ + neuron_selection_mode: str + model_inputs: OptionalList[Tensor] + model_output: OptionalList[Tensor] + analysis_inputs: List[Tensor] + stop_analysis_at_tensors: List[Tensor] + + neuron_selection_mode = self._neuron_selection_mode + model_inputs = model.inputs + model_output = model.outputs + + if len(model_output) > 1: + raise ValueError("Only models with one output tensor are allowed.") + + analysis_inputs = [] + stop_analysis_at_tensors = [] + + # Flatten to form (batch_size, other_dimensions): + if K.ndim(model_output[0]) > 2: + model_output = keras.layers.Flatten()(model_output) + + if neuron_selection_mode == "max_activation": + inn_max = ilayers.Max(name="iNNvestigate_max") + model_output = inn_max(model_output) + self._special_helper_layers.append(inn_max) + + elif neuron_selection_mode == "index": + # Creates a placeholder tensor when `dtype` is passed. + neuron_indexing: Layer = keras.layers.Input( + batch_shape=[None, None], + dtype=np.int32, + name="iNNvestigate_neuron_indexing", + ) + # TODO: what does _keras_history[0] do? + self._special_helper_layers.append(neuron_indexing._keras_history[0]) + analysis_inputs.append(neuron_indexing) + # The indexing tensor should not be analyzed. 
+ stop_analysis_at_tensors.append(neuron_indexing) + + inn_gather = ilayers.GatherND(name="iNNvestigate_gather_nd") + model_output = inn_gather(model_output + [neuron_indexing]) + self._special_helper_layers.append(inn_gather) + elif neuron_selection_mode == "all": + pass + else: + raise NotImplementedError() + + model = keras.models.Model( + inputs=model_inputs + analysis_inputs, outputs=model_output + ) + return model, analysis_inputs, stop_analysis_at_tensors + + def create_analyzer_model(self) -> None: + """ + Creates the analyze functionality. If not called beforehand + it will be called by :func:`analyze`. + """ + model_inputs = self._model.inputs + model, analysis_inputs, stop_analysis_at_tensors = self._prepare_model( + self._model + ) + self._analysis_inputs = analysis_inputs + self._prepared_model = model + + tmp = self._create_analysis( + model, stop_analysis_at_tensors=stop_analysis_at_tensors + ) + if isinstance(tmp, tuple): + if len(tmp) == 3: + analysis_outputs, debug_outputs, constant_inputs = tmp # type: ignore + elif len(tmp) == 2: + analysis_outputs, debug_outputs = tmp # type: ignore + constant_inputs = [] + elif len(tmp) == 1: + analysis_outputs = tmp[0] + constant_inputs = [] + debug_outputs = [] + else: + raise Exception("Unexpected output from _create_analysis.") + else: + analysis_outputs = tmp + constant_inputs = [] + debug_outputs = [] + + analysis_outputs = iutils.to_list(analysis_outputs) + debug_outputs = iutils.to_list(debug_outputs) + constant_inputs = iutils.to_list(constant_inputs) + + self._n_data_input = len(model_inputs) + self._n_constant_input = len(constant_inputs) + self._n_data_output = len(analysis_outputs) + self._n_debug_output = len(debug_outputs) + self._analyzer_model = keras.models.Model( + inputs=model_inputs + analysis_inputs + constant_inputs, + outputs=analysis_outputs + debug_outputs, + ) + + self._analyzer_model_done = True + + def _create_analysis( + self, model: Model, stop_analysis_at_tensors: List[Tensor] = None + ) -> Union[ + Tuple[List[Tensor]], + Tuple[List[Tensor], List[Tensor]], + Tuple[List[Tensor], List[Tensor], List[Tensor]], + ]: + """ + Interface that needs to be implemented by a derived class. + + This function is expected to create a Keras graph that creates + a custom analysis for the model inputs given the model outputs. + + :param model: Target of analysis. + :param stop_analysis_at_tensors: A list of tensors where to stop the + analysis. Similar to stop_gradient arguments when computing the + gradient of a graph. + :return: Either one-, two- or three-tuple of lists of tensors. + * The first list of tensors represents the analysis for each + model input tensor. Tensors present in stop_analysis_at_tensors + should be omitted. + * The second list, if present, is a list of debug tensors that will + be passed to :func:`_handle_debug_output` after the analysis + is executed. + * The third list, if present, is a list of constant input tensors + added to the analysis model. + """ + raise NotImplementedError() + + def _handle_debug_output(self, debug_values): + raise NotImplementedError() + + def analyze( + self, + X: OptionalList[np.ndarray], + neuron_selection: Optional[int] = None, + ) -> OptionalList[np.ndarray]: + """ + Same interface as :class:`Analyzer` besides + + :param neuron_selection: If neuron_selection_mode is 'index' this + should be an integer with the index for the chosen neuron. + """ + # TODO: what does should mean in docstring? 
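+        # Sketch of the expected call patterns (`x` denotes an input batch):
+        #   neuron_selection_mode "max_activation" or "all":
+        #       analyzer.analyze(x)
+        #   neuron_selection_mode "index":
+        #       analyzer.analyze(x, neuron_selection=7)  # some output index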
+ + if self._analyzer_model_done is False: + self.create_analyzer_model() + + if neuron_selection is not None and self._neuron_selection_mode != "index": + raise ValueError( + f"neuron_selection_mode {self._neuron_selection_mode} doesn't support ", + "'neuron_selection' parameter.", + ) + + if neuron_selection is None and self._neuron_selection_mode == "index": + raise ValueError( + "neuron_selection_mode 'index' expects 'neuron_selection' parameter." + ) + + X = iutils.to_list(X) + + ret: OptionalList[np.ndarray] + if self._neuron_selection_mode == "index": + if neuron_selection is not None: + # TODO: document how this works + selection = self._get_neuron_selection_array(X, neuron_selection) + ret = self._analyzer_model.predict_on_batch(X + [selection]) + else: + raise RuntimeError( + 'neuron_selection_mode "index" requires neuron_selection.' + ) + else: + ret = self._analyzer_model.predict_on_batch(X) + + if self._n_debug_output > 0: + self._handle_debug_output(ret[-self._n_debug_output :]) + ret = ret[: -self._n_debug_output] + + return iutils.unpack_singleton(ret) + + def _get_neuron_selection_array( + self, X: List[np.ndarray], neuron_selection: int + ) -> np.ndarray: + """Get neuron selection array for neuron_selection_mode "index".""" + # TODO: detailed documentation on how this selects neurons + + nsa = np.asarray(neuron_selection).flatten() # singleton ndarray + + # is 'nsa' is singleton, repeat it so that it matches number of rows of X + if nsa.size == 1: + nsa = np.repeat(nsa, len(X[0])) + + # Add first axis indices for gather_nd + nsa = np.hstack((np.arange(len(nsa)).reshape((-1, 1)), nsa.reshape((-1, 1)))) + return nsa + + def _get_state(self) -> Dict[str, Any]: + state = super()._get_state() + state.update({"allow_lambda_layers": self._allow_lambda_layers}) + return state + + @classmethod + def _state_to_kwargs(cls, state: Dict[str, Any]) -> Dict[str, Any]: + allow_lambda_layers = state.pop("allow_lambda_layers") + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + + kwargs.update({"allow_lambda_layers": allow_lambda_layers}) + return kwargs diff --git a/src/innvestigate/analyzer/pattern_based.py b/src/innvestigate/analyzer/pattern_based.py index b9607a7f..2e9d717e 100644 --- a/src/innvestigate/analyzer/pattern_based.py +++ b/src/innvestigate/analyzer/pattern_based.py @@ -1,7 +1,7 @@ -# Get Python six functionality: -from __future__ import absolute_import, division, print_function, unicode_literals +from __future__ import annotations import warnings +from typing import Dict import keras import keras.activations @@ -12,18 +12,14 @@ import keras.models import numpy as np -from .. import layers as ilayers -from .. import tools as itools -from .. import utils -from ..utils import keras as kutils -from ..utils.keras import checks as kchecks -from ..utils.keras import graph as kgraph -from . 
import base - -############################################################################### -############################################################################### -############################################################################### - +import innvestigate.layers as ilayers +import innvestigate.tools as itools +import innvestigate.utils as iutils +import innvestigate.utils.keras as kutils +import innvestigate.utils.keras.checks as kchecks +import innvestigate.utils.keras.graph as kgraph +from innvestigate.analyzer.base import OneEpochTrainerMixin +from innvestigate.analyzer.reverse_base import ReverseAnalyzerBase __all__ = [ "PatternNet", @@ -31,11 +27,6 @@ ] -############################################################################### -############################################################################### -############################################################################### - - SUPPORTED_LAYER_PATTERNNET = ( keras.engine.topology.InputLayer, keras.layers.convolutional.Conv2D, @@ -64,7 +55,7 @@ class PatternNetReverseKernelLayer(kgraph.ReverseMappingBase): where the filter weights are replaced with the patterns. """ - def __init__(self, layer, state, pattern): + def __init__(self, layer, _state, pattern): config = layer.get_config() # Layer can contain a kernel and an activation. @@ -92,7 +83,7 @@ def __init__(self, layer, state, pattern): layer, name_template="reversed_pattern_%s", weights=filter_weights ) - def apply(self, Xs, Ys, reversed_Ys, reverse_state): + def apply(self, Xs, _Ys, reversed_Ys, _reverse_state: Dict): # Reapply the prepared layers. act_Xs = kutils.apply(self._filter_layer, Xs) act_Ys = kutils.apply(self._act_layer, act_Xs) @@ -109,13 +100,13 @@ def apply(self, Xs, Ys, reversed_Ys, reverse_state): tmp = reversed_Ys else: # if linear activation this behaves strange - tmp = utils.to_list(grad_act(act_Xs + act_Ys + reversed_Ys)) + tmp = iutils.to_list(grad_act(act_Xs + act_Ys + reversed_Ys)) # Second step: propagate through the pattern layer. return grad_pattern(Xs + pattern_Ys + tmp) -class PatternNet(base.OneEpochTrainerMixin, base.ReverseAnalyzerBase): +class PatternNet(OneEpochTrainerMixin, ReverseAnalyzerBase): """PatternNet analyzer. Applies the "PatternNet" algorithm to analyze the model's predictions. @@ -130,26 +121,26 @@ class PatternNet(base.OneEpochTrainerMixin, base.ReverseAnalyzerBase): """ def __init__(self, model, patterns=None, pattern_type=None, **kwargs): + super().__init__(model, **kwargs) + # Add and run model checks self._add_model_softmax_check() self._add_model_check( lambda layer: not kchecks.only_relu_activation(layer), - ( - "PatternNet is not well defined for " - "networks with non-ReLU activations." 
- ), + ("PatternNet is not well defined for networks with non-ReLU activations."), check_type="warning", ) self._add_model_check( lambda layer: not kchecks.is_convnet_layer(layer), - ("PatternNet is only well defined for " "convolutional neural networks."), + ("PatternNet is only well defined for convolutional neural networks."), check_type="warning", ) self._add_model_check( lambda layer: not isinstance(layer, SUPPORTED_LAYER_PATTERNNET), - ("PatternNet is only well defined for " "conv2d/max-pooling/dense layers."), + ("PatternNet is only well defined for conv2d/max-pooling/dense layers."), check_type="exception", ) + self._do_model_checks() self._patterns = patterns if self._patterns is not None: @@ -162,16 +153,13 @@ def __init__(self, model, patterns=None, pattern_type=None, **kwargs): # Prevent this by projecting the values in bottleneck layers to +-1. if not kwargs.get("reverse_project_bottleneck_layers", True): warnings.warn( - "The standard setting for " - "'reverse_project_bottleneck_layers' " + "The standard setting for 'reverse_project_bottleneck_layers'" "is overwritten." ) else: kwargs["reverse_project_bottleneck_layers"] = True - super(PatternNet, self).__init__(model, **kwargs) - - def _get_pattern_for_layer(self, layer, state): + def _get_pattern_for_layer(self, layer, _state): layers = [ l for l in kgraph.get_model_layers(self._model) @@ -180,8 +168,8 @@ def _get_pattern_for_layer(self, layer, state): return self._patterns[layers.index(layer)] - def _prepare_pattern(self, layer, state, pattern): - """Prepares a pattern before it is set in the back-ward pass.""" + def _prepare_pattern(self, _layer, _state, pattern): + """ ""Prepares a pattern before it is set in the back-ward pass.""" return pattern def _create_analysis(self, *args, **kwargs): @@ -199,7 +187,7 @@ def create_kernel_layer_mapping(layer, state): name="patternnet_kernel_layer_mapping", ) - return super(PatternNet, self)._create_analysis(*args, **kwargs) + return super()._create_analysis(*args, **kwargs) def _fit_generator( self, @@ -213,13 +201,14 @@ def _fit_generator( disable_no_training_warning=None, **kwargs ): + # TODO: implement epochs pattern_type = self._pattern_type if pattern_type is None: pattern_type = "relu" if isinstance(pattern_type, (list, tuple)): - raise ValueError("Only one pattern type allowed. " "Please pass a string.") + raise ValueError("Only one pattern type allowed. 
Please pass a string.") computer = itools.PatternComputer( self._model, pattern_type=pattern_type, **kwargs @@ -235,15 +224,17 @@ def _fit_generator( ) def _get_state(self): - state = super(PatternNet, self)._get_state() + state = super()._get_state() state.update({"patterns": self._patterns, "pattern_type": self._pattern_type}) return state @classmethod - def _state_to_kwargs(clazz, state): + def _state_to_kwargs(cls, state): patterns = state.pop("patterns") pattern_type = state.pop("pattern_type") - kwargs = super(PatternNet, clazz)._state_to_kwargs(state) + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + kwargs.update({"patterns": patterns, "pattern_type": pattern_type}) return kwargs diff --git a/src/innvestigate/analyzer/relevance_based/relevance_analyzer.py b/src/innvestigate/analyzer/relevance_based/relevance_analyzer.py index a0013daf..1515b4e7 100644 --- a/src/innvestigate/analyzer/relevance_based/relevance_analyzer.py +++ b/src/innvestigate/analyzer/relevance_based/relevance_analyzer.py @@ -1,8 +1,8 @@ -# Get Python six functionality: -from __future__ import absolute_import, division, print_function, unicode_literals +from __future__ import annotations import inspect from builtins import zip +from typing import Dict, List import keras import keras.backend as K @@ -17,20 +17,16 @@ import keras.models import six +import innvestigate.analyzer.relevance_based.relevance_rule as rrule +import innvestigate.analyzer.relevance_based.utils as rutils +import innvestigate.layers as ilayers +import innvestigate.utils as iutils import innvestigate.utils.keras as kutils -from innvestigate import layers as ilayers -from innvestigate import utils as iutils -from innvestigate.utils.keras import checks as kchecks -from innvestigate.utils.keras import graph as kgraph - -from .. import base -from . import relevance_rule as rrule -from . 
import utils as rutils - -############################################################################### -############################################################################### -############################################################################### - +import innvestigate.utils.keras.checks as kchecks +import innvestigate.utils.keras.graph as kgraph +from innvestigate.analyzer.network_base import AnalyzerNetworkBase +from innvestigate.analyzer.reverse_base import ReverseAnalyzerBase +from innvestigate.utils.types import Layer, LayerCheck, Model, OptionalList, Tensor __all__ = [ "BaselineLRPZ", @@ -58,11 +54,59 @@ ############################################################################### -############################################################################### -############################################################################### - - -class BaselineLRPZ(base.AnalyzerNetworkBase): +BASELINE_LRPZ_LAYERS = ( + keras.engine.topology.InputLayer, + keras.layers.convolutional.Conv1D, + keras.layers.convolutional.Conv2D, + keras.layers.convolutional.Conv2DTranspose, + keras.layers.convolutional.Conv3D, + keras.layers.convolutional.Conv3DTranspose, + keras.layers.convolutional.Cropping1D, + keras.layers.convolutional.Cropping2D, + keras.layers.convolutional.Cropping3D, + keras.layers.convolutional.SeparableConv1D, + keras.layers.convolutional.SeparableConv2D, + keras.layers.convolutional.UpSampling1D, + keras.layers.convolutional.UpSampling2D, + keras.layers.convolutional.UpSampling3D, + keras.layers.convolutional.ZeroPadding1D, + keras.layers.convolutional.ZeroPadding2D, + keras.layers.convolutional.ZeroPadding3D, + keras.layers.core.Activation, + keras.layers.core.ActivityRegularization, + keras.layers.core.Dense, + keras.layers.core.Dropout, + keras.layers.core.Flatten, + keras.layers.core.Lambda, + keras.layers.core.Masking, + keras.layers.core.Permute, + keras.layers.core.RepeatVector, + keras.layers.core.Reshape, + keras.layers.core.SpatialDropout1D, + keras.layers.core.SpatialDropout2D, + keras.layers.core.SpatialDropout3D, + keras.layers.local.LocallyConnected1D, + keras.layers.local.LocallyConnected2D, + keras.layers.Add, + keras.layers.Concatenate, + keras.layers.Dot, + keras.layers.Maximum, + keras.layers.Minimum, + keras.layers.Subtract, + keras.layers.noise.AlphaDropout, + keras.layers.noise.GaussianDropout, + keras.layers.noise.GaussianNoise, + keras.layers.normalization.BatchNormalization, + keras.layers.pooling.GlobalMaxPooling1D, + keras.layers.pooling.GlobalMaxPooling2D, + keras.layers.pooling.GlobalMaxPooling3D, + keras.layers.pooling.MaxPooling1D, + keras.layers.pooling.MaxPooling2D, + keras.layers.pooling.MaxPooling3D, +) + + +class BaselineLRPZ(AnalyzerNetworkBase): """LRPZ analyzer - for testing purpose only. Applies the "LRP-Z" algorithm to analyze the model. @@ -73,59 +117,12 @@ class BaselineLRPZ(base.AnalyzerNetworkBase): :param model: A Keras model. """ - def __init__(self, model, **kwargs): + def __init__(self, model: Model, **kwargs): # Inside function to not break import if Keras changes. 
- BASELINELRPZ_LAYERS = ( - keras.engine.topology.InputLayer, - keras.layers.convolutional.Conv1D, - keras.layers.convolutional.Conv2D, - keras.layers.convolutional.Conv2DTranspose, - keras.layers.convolutional.Conv3D, - keras.layers.convolutional.Conv3DTranspose, - keras.layers.convolutional.Cropping1D, - keras.layers.convolutional.Cropping2D, - keras.layers.convolutional.Cropping3D, - keras.layers.convolutional.SeparableConv1D, - keras.layers.convolutional.SeparableConv2D, - keras.layers.convolutional.UpSampling1D, - keras.layers.convolutional.UpSampling2D, - keras.layers.convolutional.UpSampling3D, - keras.layers.convolutional.ZeroPadding1D, - keras.layers.convolutional.ZeroPadding2D, - keras.layers.convolutional.ZeroPadding3D, - keras.layers.core.Activation, - keras.layers.core.ActivityRegularization, - keras.layers.core.Dense, - keras.layers.core.Dropout, - keras.layers.core.Flatten, - keras.layers.core.Lambda, - keras.layers.core.Masking, - keras.layers.core.Permute, - keras.layers.core.RepeatVector, - keras.layers.core.Reshape, - keras.layers.core.SpatialDropout1D, - keras.layers.core.SpatialDropout2D, - keras.layers.core.SpatialDropout3D, - keras.layers.local.LocallyConnected1D, - keras.layers.local.LocallyConnected2D, - keras.layers.Add, - keras.layers.Concatenate, - keras.layers.Dot, - keras.layers.Maximum, - keras.layers.Minimum, - keras.layers.Subtract, - keras.layers.noise.AlphaDropout, - keras.layers.noise.GaussianDropout, - keras.layers.noise.GaussianNoise, - keras.layers.normalization.BatchNormalization, - keras.layers.pooling.GlobalMaxPooling1D, - keras.layers.pooling.GlobalMaxPooling2D, - keras.layers.pooling.GlobalMaxPooling3D, - keras.layers.pooling.MaxPooling1D, - keras.layers.pooling.MaxPooling2D, - keras.layers.pooling.MaxPooling3D, - ) + super().__init__(model, **kwargs) + + # Add and run model checks self._add_model_softmax_check() self._add_model_check( lambda layer: not kchecks.only_relu_activation(layer), @@ -133,14 +130,18 @@ def __init__(self, model, **kwargs): check_type="exception", ) self._add_model_check( - lambda layer: not isinstance(layer, BASELINELRPZ_LAYERS), + lambda layer: not isinstance(layer, BASELINE_LRPZ_LAYERS), "BaselineLRPZ only works with a predefined set of layers.", check_type="exception", ) + self._do_model_checks() - super(BaselineLRPZ, self).__init__(model, **kwargs) + def _create_analysis( + self, model: Model, stop_analysis_at_tensors: List[Tensor] = None + ): + if stop_analysis_at_tensors is None: + stop_analysis_at_tensors = [] - def _create_analysis(self, model, stop_analysis_at_tensors=[]): tensors_to_analyze = [ x for x in iutils.to_list(model.inputs) if x not in stop_analysis_at_tensors ] @@ -153,12 +154,10 @@ def _create_analysis(self, model, stop_analysis_at_tensors=[]): ] -############################################################################### -############################################################################### ############################################################################### # Utility list enabling name mappings via string -LRP_RULES = { +LRP_RULES: Dict = { "Z": rrule.ZRule, "ZIgnoreBias": rrule.ZIgnoreBiasRule, "Epsilon": rrule.EpsilonRule, @@ -182,7 +181,7 @@ def __init__(self, layer, state): # TODO: implement rule support. return - def apply(self, Xs, Ys, Rs, reverse_state): + def apply(self, _Xs, _Ys, Rs, _reverse_state: Dict): # the embedding layer outputs for an (indexed) input a vector. 
# thus, in the relevance backward pass, the embedding layer receives # relevances Rs corresponding to those vectors. @@ -192,13 +191,13 @@ def apply(self, Xs, Ys, Rs, reverse_state): # relevances are given shaped [batch_size, sequence_length, embedding_dims] pool_relevance = keras.layers.Lambda(lambda x: keras.backend.sum(x, axis=-1)) - return [pool_relevance(r) for r in Rs] + return [pool_relevance(R) for R in Rs] class BatchNormalizationReverseLayer(kgraph.ReverseMappingBase): """Special BN handler that applies the Z-Rule""" - def __init__(self, layer, state): + def __init__(self, layer, _state): config = layer.get_config() self._center = config["center"] @@ -210,19 +209,20 @@ def __init__(self, layer, state): if self._center: self._beta = layer.beta - # TODO: implement rule support. for BatchNormalization -> [BNEpsilon, BNAlphaBeta, BNIgnore] - # super(BatchNormalizationReverseLayer, self).__init__(layer, state) + # TODO: implement rule support. + # for BatchNormalization -> [BNEpsilon, BNAlphaBeta, BNIgnore] + # super().__init__(layer, state) # how to do this: # super.__init__ calls select_rule and sets a self._rule class # check if isinstance(self_rule, EpsiloneRule), then reroute # to BatchNormEpsilonRule. Not pretty, but should work. - def apply(self, Xs, Ys, Rs, reverse_state): + def apply(self, Xs, Ys, Rs, _reverse_state: Dict): input_shape = [K.int_shape(x) for x in Xs] if len(input_shape) != 1: # extend below lambda layers towards multiple parameters. raise ValueError( - "BatchNormalizationReverseLayer expects Xs with len(Xs) = 1, but was len(Xs) = {}".format( + "BatchNormalizationReverseLayer expects Xs with len(Xs) = 1, but was len(Xs) = {}".format( # noqa len(Xs) ) ) @@ -281,13 +281,17 @@ def __init__(self, layer, state): ) # TODO: implement rule support. - # super(AddReverseLayer, self).__init__(layer, state) - - def apply(self, Xs, Ys, Rs, reverse_state): - # the outputs of the pooling operation at each location is the sum of its inputs. - # the forward message must be known in this case, and are the inputs for each pooling thing. - # the gradient is 1 for each output-to-input connection, which corresponds to the "weights" - # of the layer. It should thus be sufficient to reweight the relevances and and do a gradient_wrt + # super().__init__(layer, state) + + def apply(self, Xs, _Ys, Rs, _reverse_state: Dict): + # The outputs of the pooling operation at each location + # is the sum of its inputs. + # The forward message must be known in this case, + # and are the inputs for each pooling thing. + # The gradient is 1 for each output-to-input connection, + # which corresponds to the "weights" of the layer. + # It should thus be sufficient to reweight the relevances + # and do a gradient_wrt grad = ilayers.GradientWRT(len(Xs)) # Get activations. Zs = kutils.apply(self._layer_wo_act, Xs) @@ -310,14 +314,17 @@ def __init__(self, layer, state): ) # TODO: implement rule support. - # super(AveragePoolingRerseLayer, self).__init__(layer, state) - - def apply(self, Xs, Ys, Rs, reverse_state): - # the outputs of the pooling operation at each location is the sum of its inputs. - # the forward message must be known in this case, and are the inputs for each pooling thing. - # the gradient is 1 for each output-to-input connection, which corresponds to the "weights" - # of the layer. 
It should thus be sufficient to reweight the relevances and and do a gradient_wrt - + # super().__init__(layer, state) + + def apply(self, Xs, _Ys, Rs, reverse_state: Dict): + # The outputs of the pooling operation at each location + # is the sum of its inputs. + # The forward message must be known in this case, + # and are the inputs for each pooling thing. + # The gradient is 1 for each output-to-input connection, + # which corresponds to the "weights" of the layer. + # It should thus be sufficient to reweight the relevances + # and do a gradient_wrt grad = ilayers.GradientWRT(len(Xs)) # Get activations. Zs = kutils.apply(self._layer_wo_act, Xs) @@ -331,31 +338,44 @@ def apply(self, Xs, Ys, Rs, reverse_state): return [keras.layers.Multiply()([a, b]) for a, b in zip(Xs, tmp)] -class LRP(base.ReverseAnalyzerBase): +class LRP(ReverseAnalyzerBase): """ Base class for LRP-based model analyzers :param model: A Keras model. - :param rule: A rule can be a string or a Rule object, lists thereof or a list of conditions [(Condition, Rule), ... ] + :param rule: A rule can be a string or a Rule object, lists thereof or + a list of conditions [(Condition, Rule), ... ] gradient. - :param input_layer_rule: either a Rule object, atuple of (low, high) the min/max pixel values of the inputs + :param input_layer_rule: either a Rule object, atuple of (low, high) + the min/max pixel values of the inputs :param bn_layer_rule: either a Rule object or None. None means dedicated BN rule will be applied. """ - def __init__(self, model, *args, **kwargs): - rule = kwargs.pop("rule", None) - input_layer_rule = kwargs.pop("input_layer_rule", None) - until_layer_idx = kwargs.pop("until_layer_idx", None) - until_layer_rule = kwargs.pop("until_layer_rule", None) + def __init__( + self, + model, + *args, + rule=None, + input_layer_rule=None, + until_layer_idx=None, + until_layer_rule=None, + bn_layer_rule=None, + bn_layer_fuse_mode: str = "one_linear", + **kwargs, + ): + super().__init__(model, *args, **kwargs) - bn_layer_rule = kwargs.pop("bn_layer_rule", None) - bn_layer_fuse_mode = kwargs.pop("bn_layer_fuse_mode", "one_linear") - assert bn_layer_fuse_mode in ["one_linear", "two_linear"] + self._input_layer_rule = input_layer_rule + self._until_layer_rule = until_layer_rule + self._until_layer_idx = until_layer_idx + self._bn_layer_rule = bn_layer_rule + self._bn_layer_fuse_mode = bn_layer_fuse_mode + # Add self._add_model_softmax_check() self._add_model_check( lambda layer: not kchecks.is_convnet_layer(layer), @@ -363,29 +383,26 @@ def __init__(self, model, *args, **kwargs): check_type="warning", ) + assert bn_layer_fuse_mode in ["one_linear", "two_linear"] + + # TODO: refactor rule type checking into separate function # check if rule was given explicitly. - # rule can be a string, a list (of strings) or a list of conditions [(Condition, Rule), ... ] for each layer. + # rule can be a string, a list (of strings) or + # a list of conditions [(Condition, Rule), ... ] for each layer. 
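+        # For example (sketch), all of the following are accepted:
+        #   rule="Z"
+        #   rule=rrule.EpsilonRule
+        #   rule=[(kchecks.is_dense_layer, "Epsilon"),
+        #         (kchecks.is_conv_layer, rrule.Alpha1Beta0Rule)]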
         if rule is None:
             raise ValueError("Need LRP rule(s).")
 
         if isinstance(rule, list):
-            # copy refrences
             self._rule = list(rule)
         else:
             self._rule = rule
 
-        self._input_layer_rule = input_layer_rule
-        self._until_layer_rule = until_layer_rule
-        self._until_layer_idx = until_layer_idx
-
-        self._bn_layer_rule = bn_layer_rule
-        self._bn_layer_fuse_mode = bn_layer_fuse_mode
-
-        if isinstance(rule, six.string_types) or (
+        if isinstance(rule, str) or (
             inspect.isclass(rule) and issubclass(rule, kgraph.ReverseMappingBase)
         ):
             # NOTE: All LRP rules inherit from kgraph.ReverseMappingBase
             # the given rule is a single string or single rule implementing class
             use_conditions = True
-            rules = [(lambda a, b: True, rule)]
+            rules = [(lambda _: True, rule)]
 
         elif not isinstance(rule[0], tuple):
             # rule list of rule strings or classes
@@ -399,15 +416,8 @@ def __init__(self, model, *args, **kwargs):
         # apply rule to first self._until_layer_idx layers
         if self._until_layer_rule is not None and self._until_layer_idx is not None:
             for i in range(self._until_layer_idx + 1):
-                rules.insert(
-                    0,
-                    (
-                        lambda layer, foo, bound_i=i: kchecks.is_layer_at_idx(
-                            layer, bound_i
-                        ),
-                        self._until_layer_rule,
-                    ),
-                )
+                # Bind the loop variable via a default argument; a plain
+                # `lambda layer: kchecks.is_layer_at_idx(layer, i)` would
+                # capture `i` late and only ever check the last index.
+                is_at_idx: LayerCheck = lambda layer, idx=i: kchecks.is_layer_at_idx(
+                    layer, idx
+                )
+                rules.insert(0, (is_at_idx, self._until_layer_rule))
 
         # create a BoundedRule for input layer handling from given tuple
         if self._input_layer_rule is not None:
@@ -417,53 +427,41 @@ def __init__(self, model, *args, **kwargs):
 
                 class BoundedProxyRule(rrule.BoundedRule):
                     def __init__(self, *args, **kwargs):
-                        super(BoundedProxyRule, self).__init__(
-                            *args, low=low, high=high, **kwargs
-                        )
+                        super().__init__(*args, low=low, high=high, **kwargs)
 
                 input_layer_rule = BoundedProxyRule
 
             if use_conditions is True:
-                rules.insert(
-                    0,
-                    (
-                        lambda layer, foo: kchecks.is_input_layer(layer),
-                        input_layer_rule,
-                    ),
-                )
-
+                is_input: LayerCheck = lambda layer: kchecks.is_input_layer(layer)
+                rules.insert(0, (is_input, input_layer_rule))
             else:
                 rules.insert(0, input_layer_rule)
 
         self._rules_use_conditions = use_conditions
         self._rules = rules
 
-        # FINALIZED constructor.
-        super(LRP, self).__init__(model, *args, **kwargs)
-
-    def create_rule_mapping(self, layer, reverse_state):
-        rule_class = None
+    def create_rule_mapping(self, layer: Layer, reverse_state: Dict):
+        # Initialize to None so that a layer matching no condition raises the
+        # exception below instead of an UnboundLocalError.
+        rule_class = None
         if self._rules_use_conditions is True:
             for condition, rule in self._rules:
-                if condition(layer, reverse_state):
+                if condition(layer):
                     rule_class = rule
                     break
         else:
             rule_class = self._rules.pop()
 
         if rule_class is None:
-            raise Exception("No rule applies to layer: %s" % layer)
+            raise Exception(f"No rule applies to layer {layer}")
 
-        if isinstance(rule_class, six.string_types):
+        if isinstance(rule_class, str):
             rule_class = LRP_RULES[rule_class]
 
         rule = rule_class(layer, reverse_state)
 
         return rule.apply
 
     def _create_analysis(self, *args, **kwargs):
-        ####################################################################
-        ### Functionality responible for backwards rule selection below ####
-        ####################################################################
+        ###################################################################
+        # Functionality responsible for backwards rule selection below ###
+        ###################################################################
 
         # default backward hook
         self._add_conditional_reverse_mapping(
@@ -472,11 +470,12 @@ def _create_analysis(self, *args, **kwargs):
             name="lrp_layer_with_kernel_mapping",
         )
 
-        # specialized backward hooks. 
TODO: add ReverseLayer class handling layers Without kernel: Add and AvgPool + # specialized backward hooks. + # TODO: add ReverseLayer class handling layers without kernel: Add and AvgPool bn_layer_rule = self._bn_layer_rule if bn_layer_rule is None: - # todo(alber): get rid of this option! + # TODO (alber): get rid of this option! # alternatively a default rule should be applied. bn_mapping = BatchNormalizationReverseLayer else: @@ -509,9 +508,15 @@ def _create_analysis(self, *args, **kwargs): ) # FINALIZED constructor. - return super(LRP, self)._create_analysis(*args, **kwargs) - - def _default_reverse_mapping(self, Xs, Ys, reversed_Ys, reverse_state): + return super()._create_analysis(*args, **kwargs) + + def _default_reverse_mapping( + self, + Xs: OptionalList[Tensor], + Ys: OptionalList[Tensor], + reversed_Ys: OptionalList[Tensor], + reverse_state: Dict, + ): # default_return_layers = [keras.layers.Activation]# TODO extend if ( len(Xs) == len(Ys) @@ -532,12 +537,12 @@ def _default_reverse_mapping(self, Xs, Ys, reversed_Ys, reverse_state): # Cropping return self._gradient_reverse_mapping(Xs, Ys, reversed_Ys, reverse_state) - ######################################## - ### End of Rule Selection Business. #### - ######################################## + ###################################### + # End of Rule Selection Business. #### + ###################################### def _get_state(self): - state = super(LRP, self)._get_state() + state = super()._get_state() state.update({"rule": self._rule}) state.update({"input_layer_rule": self._input_layer_rule}) state.update({"bn_layer_rule": self._bn_layer_rule}) @@ -545,12 +550,14 @@ def _get_state(self): return state @classmethod - def _state_to_kwargs(clazz, state): + def _state_to_kwargs(cls, state): rule = state.pop("rule") input_layer_rule = state.pop("input_layer_rule") bn_layer_rule = state.pop("bn_layer_rule") bn_layer_fuse_mode = state.pop("bn_layer_fuse_mode") - kwargs = super(LRP, clazz)._state_to_kwargs(state) + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + kwargs.update( { "rule": rule, @@ -569,8 +576,10 @@ def _state_to_kwargs(clazz, state): class _LRPFixedParams(LRP): @classmethod - def _state_to_kwargs(clazz, state): - kwargs = super(_LRPFixedParams, clazz)._state_to_kwargs(state) + def _state_to_kwargs(cls, state): + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + del kwargs["rule"] del kwargs["bn_layer_rule"] return kwargs @@ -580,16 +589,18 @@ class LRPZ(_LRPFixedParams): """LRP-analyzer that uses the LRP-Z rule""" def __init__(self, model, *args, **kwargs): - super(LRPZ, self).__init__(model, *args, rule="Z", bn_layer_rule="Z", **kwargs) + super().__init__(model, *args, rule="Z", bn_layer_rule="Z", **kwargs) + self._do_model_checks() class LRPZIgnoreBias(_LRPFixedParams): """LRP-analyzer that uses the LRP-Z-ignore-bias rule""" def __init__(self, model, *args, **kwargs): - super(LRPZIgnoreBias, self).__init__( + super().__init__( model, *args, rule="ZIgnoreBias", bn_layer_rule="ZIgnoreBias", **kwargs ) + self._do_model_checks() class LRPEpsilon(_LRPFixedParams): @@ -607,44 +618,43 @@ class EpsilonProxyRule(rrule.EpsilonRule): """ def __init__(self, *args, **kwargs): - super(EpsilonProxyRule, self).__init__( - *args, epsilon=epsilon, bias=bias, **kwargs - ) + super().__init__(*args, epsilon=epsilon, bias=bias, **kwargs) - super(LRPEpsilon, self).__init__( + super().__init__( model, *args, rule=EpsilonProxyRule, 
bn_layer_rule=EpsilonProxyRule, - **kwargs + **kwargs, ) + self._do_model_checks() + class LRPEpsilonIgnoreBias(LRPEpsilon): """LRP-analyzer that uses the LRP-Epsilon-ignore-bias rule""" def __init__(self, model, epsilon=1e-7, *args, **kwargs): - super(LRPEpsilonIgnoreBias, self).__init__( - model, *args, epsilon=epsilon, bias=False, **kwargs - ) + super().__init__(model, *args, epsilon=epsilon, bias=False, **kwargs) + self._do_model_checks() class LRPWSquare(_LRPFixedParams): """LRP-analyzer that uses the DeepTaylor W**2 rule""" def __init__(self, model, *args, **kwargs): - super(LRPWSquare, self).__init__( + super().__init__( model, *args, rule="WSquare", bn_layer_rule="WSquare", **kwargs ) + self._do_model_checks() class LRPFlat(_LRPFixedParams): """LRP-analyzer that uses the LRP-Flat rule""" def __init__(self, model, *args, **kwargs): - super(LRPFlat, self).__init__( - model, *args, rule="Flat", bn_layer_rule="Flat", **kwargs - ) + super().__init__(model, *args, rule="Flat", bn_layer_rule="Flat", **kwargs) + self._do_model_checks() class LRPAlphaBeta(LRP): @@ -664,20 +674,19 @@ class AlphaBetaProxyRule(rrule.AlphaBetaRule): """ def __init__(self, *args, **kwargs): - super(AlphaBetaProxyRule, self).__init__( - *args, alpha=alpha, beta=beta, bias=bias, **kwargs - ) + super().__init__(*args, alpha=alpha, beta=beta, bias=bias, **kwargs) - super(LRPAlphaBeta, self).__init__( + super().__init__( model, *args, rule=AlphaBetaProxyRule, bn_layer_rule=AlphaBetaProxyRule, - **kwargs + **kwargs, ) + self._do_model_checks() def _get_state(self): - state = super(LRPAlphaBeta, self)._get_state() + state = super()._get_state() del state["rule"] state.update({"alpha": self._alpha}) state.update({"beta": self._beta}) @@ -685,12 +694,14 @@ def _get_state(self): return state @classmethod - def _state_to_kwargs(clazz, state): + def _state_to_kwargs(cls, state): alpha = state.pop("alpha") beta = state.pop("beta") bias = state.pop("bias") state["rule"] = None - kwargs = super(LRPAlphaBeta, clazz)._state_to_kwargs(state) + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + del kwargs["rule"] del kwargs["bn_layer_rule"] kwargs.update({"alpha": alpha, "beta": beta, "bias": bias}) @@ -699,8 +710,10 @@ def _state_to_kwargs(clazz, state): class _LRPAlphaBetaFixedParams(LRPAlphaBeta): @classmethod - def _state_to_kwargs(clazz, state): - kwargs = super(_LRPAlphaBetaFixedParams, clazz)._state_to_kwargs(state) + def _state_to_kwargs(cls, state): + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + del kwargs["alpha"] del kwargs["beta"] del kwargs["bias"] @@ -711,36 +724,32 @@ class LRPAlpha2Beta1(_LRPAlphaBetaFixedParams): """LRP-analyzer that uses the LRP-alpha-beta rule with a=2,b=1""" def __init__(self, model, *args, **kwargs): - super(LRPAlpha2Beta1, self).__init__( - model, *args, alpha=2, beta=1, bias=True, **kwargs - ) + super().__init__(model, *args, alpha=2, beta=1, bias=True, **kwargs) + self._do_model_checks() class LRPAlpha2Beta1IgnoreBias(_LRPAlphaBetaFixedParams): """LRP-analyzer that uses the LRP-alpha-beta-ignbias rule with a=2,b=1""" def __init__(self, model, *args, **kwargs): - super(LRPAlpha2Beta1IgnoreBias, self).__init__( - model, *args, alpha=2, beta=1, bias=False, **kwargs - ) + super().__init__(model, *args, alpha=2, beta=1, bias=False, **kwargs) + self._do_model_checks() class LRPAlpha1Beta0(_LRPAlphaBetaFixedParams): """LRP-analyzer that uses the LRP-alpha-beta rule with a=1,b=0""" def __init__(self, 
model, *args, **kwargs): - super(LRPAlpha1Beta0, self).__init__( - model, *args, alpha=1, beta=0, bias=True, **kwargs - ) + super().__init__(model, *args, alpha=1, beta=0, bias=True, **kwargs) + self._do_model_checks() class LRPAlpha1Beta0IgnoreBias(_LRPAlphaBetaFixedParams): """LRP-analyzer that uses the LRP-alpha-beta-ignbias rule with a=1,b=0""" def __init__(self, model, *args, **kwargs): - super(LRPAlpha1Beta0IgnoreBias, self).__init__( - model, *args, alpha=1, beta=0, bias=False, **kwargs - ) + super().__init__(model, *args, alpha=1, beta=0, bias=False, **kwargs) + self._do_model_checks() class LRPZPlus(LRPAlpha1Beta0IgnoreBias): @@ -748,7 +757,8 @@ class LRPZPlus(LRPAlpha1Beta0IgnoreBias): # TODO: assert that layer inputs are always >= 0 def __init__(self, model, *args, **kwargs): - super(LRPZPlus, self).__init__(model, *args, **kwargs) + super().__init__(model, *args, **kwargs) + self._do_model_checks() class LRPZPlusFast(_LRPFixedParams): @@ -759,82 +769,99 @@ class LRPZPlusFast(_LRPFixedParams): # TODO: assert that layer inputs are always >= 0 def __init__(self, model, *args, **kwargs): - super(LRPZPlusFast, self).__init__( + super().__init__( model, *args, rule="ZPlusFast", bn_layer_rule="ZPlusFast", **kwargs ) + self._do_model_checks() class LRPSequentialPresetA(_LRPFixedParams): # for the lack of a better name """Special LRP-configuration for ConvNets""" - def __init__(self, model, epsilon=1e-1, *args, **kwargs): - - self._add_model_check( - lambda layer: not kchecks.only_relu_activation(layer), - # TODO: fix. specify. extend. - ( - "LRPSequentialPresetA is not advised " - "for networks with non-ReLU activations." - ), - check_type="warning", - ) - + def __init__( + self, + model, + epsilon=1e-1, + *args, + bn_layer_rule=rrule.AlphaBetaX2m100Rule, + **kwargs, + ): class EpsilonProxyRule(rrule.EpsilonRule): def __init__(self, *args, **kwargs): - super(EpsilonProxyRule, self).__init__( - *args, epsilon=epsilon, bias=True, **kwargs - ) + super().__init__(*args, epsilon=epsilon, bias=True, **kwargs) conditional_rules = [ (kchecks.is_dense_layer, EpsilonProxyRule), (kchecks.is_conv_layer, rrule.Alpha1Beta0Rule), ] - bn_layer_rule = kwargs.pop("bn_layer_rule", rrule.AlphaBetaX2m100Rule) - super(LRPSequentialPresetA, self).__init__( + super().__init__( model, *args, rule=conditional_rules, bn_layer_rule=bn_layer_rule, **kwargs ) - -class LRPSequentialPresetB(_LRPFixedParams): - """Special LRP-configuration for ConvNets""" - - def __init__(self, model, epsilon=1e-1, *args, **kwargs): self._add_model_check( lambda layer: not kchecks.only_relu_activation(layer), # TODO: fix. specify. extend. ( - "LRPSequentialPresetB is not advised " + "LRPSequentialPresetA is not advised " "for networks with non-ReLU activations." 
             ),
             check_type="warning",
         )
+        self._do_model_checks()
+
+
+class LRPSequentialPresetB(_LRPFixedParams):
+    """Special LRP-configuration for ConvNets"""
+
+    def __init__(
+        self,
+        model: Model,
+        epsilon: float = 1e-1,
+        *args,
+        bn_layer_rule=rrule.AlphaBetaX2m100Rule,
+        **kwargs,
+    ):
         class EpsilonProxyRule(rrule.EpsilonRule):
             def __init__(self, *args, **kwargs):
-                super(EpsilonProxyRule, self).__init__(
-                    *args, epsilon=epsilon, bias=True, **kwargs
-                )
+                super().__init__(*args, epsilon=epsilon, bias=True, **kwargs)
 
         conditional_rules = [
             (kchecks.is_dense_layer, EpsilonProxyRule),
             (kchecks.is_conv_layer, rrule.Alpha2Beta1Rule),
         ]
-        bn_layer_rule = kwargs.pop("bn_layer_rule", rrule.AlphaBetaX2m100Rule)
 
-        super(LRPSequentialPresetB, self).__init__(
+        super().__init__(
             model, *args, rule=conditional_rules, bn_layer_rule=bn_layer_rule, **kwargs
         )
+        # Add and run model checks
+        self._add_model_check(
+            lambda layer: not kchecks.only_relu_activation(layer),
+            # TODO: fix. specify. extend.
+            (
+                "LRPSequentialPresetB is not advised "
+                "for networks with non-ReLU activations."
+            ),
+            check_type="warning",
+        )
+        self._do_model_checks()
+
 
 # TODO: allow to pass input layer identification by index or id.
 class LRPSequentialPresetAFlat(LRPSequentialPresetA):
     """Special LRP-configuration for ConvNets"""
 
     def __init__(self, model, *args, **kwargs):
-        super(LRPSequentialPresetAFlat, self).__init__(
-            model, *args, input_layer_rule="Flat", **kwargs
-        )
+        # Drop a stored 'input_layer_rule' kwarg (e.g. when restored via
+        # `analyzer.load()`) so it is not passed twice to the super constructor:
+        if "input_layer_rule" in kwargs:
+            if kwargs["input_layer_rule"] != "Flat":
+                raise RuntimeError(
+                    "Unexpected input_layer_rule when loading LRPSequentialPresetAFlat."
+                )
+            kwargs.pop("input_layer_rule")
+        super().__init__(model, *args, input_layer_rule="Flat", **kwargs)
 
 
 # TODO: allow to pass input layer identification by index or id.
@@ -842,23 +869,28 @@ class LRPSequentialPresetBFlat(LRPSequentialPresetB):
     """Special LRP-configuration for ConvNets"""
 
     def __init__(self, model, *args, **kwargs):
-        super(LRPSequentialPresetBFlat, self).__init__(
-            model, *args, input_layer_rule="Flat", **kwargs
-        )
+        # Drop a stored 'input_layer_rule' kwarg (e.g. when restored via
+        # `analyzer.load()`) so it is not passed twice to the super constructor:
+        if "input_layer_rule" in kwargs:
+            if kwargs["input_layer_rule"] != "Flat":
+                raise RuntimeError(
+                    "Unexpected input_layer_rule when loading LRPSequentialPresetBFlat."
+                )
+            kwargs.pop("input_layer_rule")
+        super().__init__(model, *args, input_layer_rule="Flat", **kwargs)
 
 
 class LRPSequentialPresetBFlatUntilIdx(LRPSequentialPresetBFlat):
     """
-    Special LRP-configuration for ConvNets. Allows to perform LRP_flat from (including) layer until_layer_idx down until
+    Special LRP-configuration for ConvNets.
+    Allows applying the Flat rule from layer until_layer_idx (inclusive) down to
     the input layer. Weightless layers are ignored when counting the index for now. 
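+
+    For example (sketch), applying the Flat rule to the first four
+    weighted layers of a model::
+
+        analyzer = LRPSequentialPresetBFlatUntilIdx(model, until_layer_idx=3)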
""" - def __init__(self, model, *args, **kwargs): - layer_flat_idx = kwargs.pop("until_layer_idx", None) - super(LRPSequentialPresetBFlatUntilIdx, self).__init__( + def __init__(self, model, *args, until_layer_idx=None, **kwargs): + super().__init__( model, *args, - until_layer_idx=layer_flat_idx, + until_layer_idx=until_layer_idx, until_layer_rule=rrule.FlatRule, - **kwargs + **kwargs, ) diff --git a/src/innvestigate/analyzer/relevance_based/relevance_rule.py b/src/innvestigate/analyzer/relevance_based/relevance_rule.py index dd2e8628..ee18b10d 100644 --- a/src/innvestigate/analyzer/relevance_based/relevance_rule.py +++ b/src/innvestigate/analyzer/relevance_based/relevance_rule.py @@ -1,7 +1,7 @@ -# Get Python six functionality: -from __future__ import absolute_import, division, print_function, unicode_literals +from __future__ import annotations from builtins import zip +from typing import Dict, List, Tuple import keras import keras.backend as K @@ -15,19 +15,15 @@ import keras.layers.pooling import keras.models import numpy as np +from keras.layers import Layer +from tensorflow import Tensor +import innvestigate.analyzer.relevance_based.utils as rutils +import innvestigate.layers as ilayers +import innvestigate.utils as iutils import innvestigate.utils.keras as kutils -from innvestigate import layers as ilayers -from innvestigate import utils as iutils -from innvestigate.utils.keras import backend as iK -from innvestigate.utils.keras import graph as kgraph - -from . import utils as rutils - -############################################################################### -############################################################################### -############################################################################### - +import innvestigate.utils.keras.backend as iK +import innvestigate.utils.keras.graph as kgraph # TODO: differentiate between LRP and DTD rules? # DTD rules are special cases of LRP rules with additional assumptions @@ -63,12 +59,19 @@ class ZRule(kgraph.ReverseMappingBase): which considers the bias a constant input neuron. """ - def __init__(self, layer, state, bias=True): + def __init__(self, layer: Layer, state, bias: bool = True) -> None: self._layer_wo_act = kgraph.copy_layer_wo_activation( layer, keep_bias=bias, name_template="reversed_kernel_%s" ) + super().__init__(layer, state) - def apply(self, Xs, Ys, Rs, reverse_state): + def apply( + self, + Xs: List[Tensor], + _Ys: List[Tensor], + Rs: List[Tensor], + _reverse_state, + ) -> List[Tensor]: grad = ilayers.GradientWRT(len(Xs)) # Get activations. 
@@ -88,7 +91,7 @@ class ZIgnoreBiasRule(ZRule): """ def __init__(self, *args, **kwargs): - super(ZIgnoreBiasRule, self).__init__(*args, bias=False, **kwargs) + super().__init__(*args, bias=False, **kwargs) class EpsilonRule(kgraph.ReverseMappingBase): @@ -100,15 +103,22 @@ class EpsilonRule(kgraph.ReverseMappingBase): 0 is considered to be positive, ie sign(0) = 1 """ - def __init__(self, layer, state, epsilon=1e-7, bias=True): + def __init__(self, layer: Layer, state, epsilon=1e-7, bias: bool = True): self._epsilon = rutils.assert_lrp_epsilon_param(epsilon, self) self._layer_wo_act = kgraph.copy_layer_wo_activation( layer, keep_bias=bias, name_template="reversed_kernel_%s" ) - def apply(self, Xs, Ys, Rs, reverse_state): + def apply( + self, + Xs: List[Tensor], + _Ys: List[Tensor], + Rs: List[Tensor], + _reverse_state: Dict, + ): grad = ilayers.GradientWRT(len(Xs)) - # The epsilon rule aligns epsilon with the (extended) sign: 0 is considered to be positive + # The epsilon rule aligns epsilon with the (extended) sign: + # 0 is considered to be positive prepare_div = keras.layers.Lambda( lambda x: x + (K.cast(K.greater_equal(x, 0), K.floatx()) * 2 - 1) * self._epsilon @@ -130,13 +140,13 @@ class EpsilonIgnoreBiasRule(EpsilonRule): """Same as EpsilonRule but ignores the bias.""" def __init__(self, *args, **kwargs): - super(EpsilonIgnoreBiasRule, self).__init__(*args, bias=False, **kwargs) + super().__init__(*args, bias=False, **kwargs) class WSquareRule(kgraph.ReverseMappingBase): """W**2 rule from Deep Taylor Decomposition""" - def __init__(self, layer, state, copy_weights=False): + def __init__(self, layer: Layer, state, copy_weights=False) -> None: # W-square rule works with squared weights and no biases. if copy_weights: weights = layer.get_weights() @@ -150,7 +160,13 @@ def __init__(self, layer, state, copy_weights=False): layer, keep_bias=False, weights=weights, name_template="reversed_kernel_%s" ) - def apply(self, Xs, Ys, Rs, reverse_state): + def apply( + self, + Xs: List[Tensor], + Ys: List[Tensor], + Rs: List[Tensor], + _reverse_state: Dict, + ) -> List[Tensor]: grad = ilayers.GradientWRT(len(Xs)) # Create dummy forward path to take the derivative below. Ys = kutils.apply(self._layer_wo_act_b, Xs) @@ -168,7 +184,7 @@ def apply(self, Xs, Ys, Rs, reverse_state): class FlatRule(WSquareRule): """Same as W**2 rule but sets all weights to ones.""" - def __init__(self, layer, state, copy_weights=False): + def __init__(self, layer: Layer, state, copy_weights: bool = False) -> None: # The flat rule works with weights equal to one and # no biases. 
if copy_weights: @@ -207,8 +223,14 @@ class AlphaBetaRule(kgraph.ReverseMappingBase): """ def __init__( - self, layer, state, alpha=None, beta=None, bias=True, copy_weights=False - ): + self, + layer: Layer, + _state, + alpha=None, + beta=None, + bias: bool = True, + copy_weights=False, + ) -> None: alpha, beta = rutils.assert_infer_lrp_alpha_beta_param(alpha, beta, self) self._alpha = alpha self._beta = beta @@ -241,7 +263,13 @@ def __init__( name_template="reversed_kernel_negative_%s", ) - def apply(self, Xs, Ys, Rs, reverse_state): + def apply( + self, + Xs: List[Tensor], + _Ys: List[Tensor], + Rs: List[Tensor], + _reverse_state: Dict, + ): # this method is correct, but wasteful grad = ilayers.GradientWRT(len(Xs)) times_alpha = keras.layers.Lambda(lambda x: x * self._alpha) @@ -295,43 +323,35 @@ class AlphaBetaIgnoreBiasRule(AlphaBetaRule): """Same as AlphaBetaRule but ignores biases.""" def __init__(self, *args, **kwargs): - super(AlphaBetaIgnoreBiasRule, self).__init__(*args, bias=False, **kwargs) + super().__init__(*args, bias=False, **kwargs) class Alpha2Beta1Rule(AlphaBetaRule): """AlphaBetaRule with alpha=2, beta=1""" def __init__(self, *args, **kwargs): - super(Alpha2Beta1Rule, self).__init__( - *args, alpha=2, beta=1, bias=True, **kwargs - ) + super().__init__(*args, alpha=2, beta=1, bias=True, **kwargs) class Alpha2Beta1IgnoreBiasRule(AlphaBetaRule): """AlphaBetaRule with alpha=2, beta=1 and ignores biases""" def __init__(self, *args, **kwargs): - super(Alpha2Beta1IgnoreBiasRule, self).__init__( - *args, alpha=2, beta=1, bias=False, **kwargs - ) + super().__init__(*args, alpha=2, beta=1, bias=False, **kwargs) class Alpha1Beta0Rule(AlphaBetaRule): """AlphaBetaRule with alpha=1, beta=0""" def __init__(self, *args, **kwargs): - super(Alpha1Beta0Rule, self).__init__( - *args, alpha=1, beta=0, bias=True, **kwargs - ) + super().__init__(*args, alpha=1, beta=0, bias=True, **kwargs) class Alpha1Beta0IgnoreBiasRule(AlphaBetaRule): """AlphaBetaRule with alpha=1, beta=0 and ignores biases""" def __init__(self, *args, **kwargs): - super(Alpha1Beta0IgnoreBiasRule, self).__init__( - *args, alpha=1, beta=0, bias=False, **kwargs - ) + super().__init__(*args, alpha=1, beta=0, bias=False, **kwargs) class AlphaBetaXRule(kgraph.ReverseMappingBase): @@ -341,13 +361,13 @@ class AlphaBetaXRule(kgraph.ReverseMappingBase): def __init__( self, - layer, - state, - alpha=(0.5, 0.5), - beta=(0.5, 0.5), - bias=True, - copy_weights=False, - ): + layer: Layer, + _state, + alpha: Tuple[float, float] = (0.5, 0.5), + beta: Tuple[float, float] = (0.5, 0.5), + bias: bool = True, + copy_weights: bool = False, + ) -> None: self._alpha = alpha self._beta = beta @@ -379,11 +399,17 @@ def __init__( name_template="reversed_kernel_negative_%s", ) - def apply(self, Xs, Ys, Rs, reverse_state): + def apply( + self, + Xs: List[Tensor], + _Ys: List[Tensor], + Rs: List[Tensor], + _reverse_state: Dict, + ): # this method is correct, but wasteful grad = ilayers.GradientWRT(len(Xs)) times_alpha0 = keras.layers.Lambda(lambda x: x * self._alpha[0]) - times_alpha1 = keras.layers.Lambda(lambda x: x * self._alpha[1]) + # times_alpha1 = keras.layers.Lambda(lambda x: x * self._alpha[1]) # unused times_beta0 = keras.layers.Lambda(lambda x: x * self._beta[0]) times_beta1 = keras.layers.Lambda(lambda x: x * self._beta[1]) keep_positives = keras.layers.Lambda( @@ -393,7 +419,7 @@ def apply(self, Xs, Ys, Rs, reverse_state): lambda x: x * K.cast(K.less(x, 0), K.floatx()) ) - def f(layer, X): + def f(layer: Layer, X): Zs = kutils.apply(layer, 
X) # Divide incoming relevance by the activations. tmp = [ilayers.SafeDivide()([a, b]) for a, b in zip(Rs, Zs)] @@ -433,30 +459,22 @@ def f(layer, X): class AlphaBetaX1000Rule(AlphaBetaXRule): def __init__(self, *args, **kwargs): - super(AlphaBetaX1000Rule, self).__init__( - *args, alpha=(1, 0), beta=(0, 0), bias=True, **kwargs - ) + super().__init__(*args, alpha=(1, 0), beta=(0, 0), bias=True, **kwargs) class AlphaBetaX1010Rule(AlphaBetaXRule): def __init__(self, *args, **kwargs): - super(AlphaBetaX1010Rule, self).__init__( - *args, alpha=(1, 0), beta=(0, -1), bias=True, **kwargs - ) + super().__init__(*args, alpha=(1, 0), beta=(0, -1), bias=True, **kwargs) class AlphaBetaX1001Rule(AlphaBetaXRule): def __init__(self, *args, **kwargs): - super(AlphaBetaX1001Rule, self).__init__( - *args, alpha=(1, 1), beta=(0, 0), bias=True, **kwargs - ) + super().__init__(*args, alpha=(1, 1), beta=(0, 0), bias=True, **kwargs) class AlphaBetaX2m100Rule(AlphaBetaXRule): def __init__(self, *args, **kwargs): - super(AlphaBetaX2m100Rule, self).__init__( - *args, alpha=(2, 0), beta=(1, 0), bias=True, **kwargs - ) + super().__init__(*args, alpha=(2, 0), beta=(1, 0), bias=True, **kwargs) class BoundedRule(kgraph.ReverseMappingBase): @@ -464,7 +482,9 @@ class BoundedRule(kgraph.ReverseMappingBase): # TODO: this only works for relu networks, needs to be extended. # TODO: check - def __init__(self, layer, state, low=-1, high=1, copy_weights=False): + def __init__( + self, layer: Layer, _state, low=-1, high=1, copy_weights: bool = False + ) -> None: self._low = low self._high = high @@ -501,7 +521,7 @@ def __init__(self, layer, state, low=-1, high=1, copy_weights=False): ) # TODO: clean up this implementation and add more documentation - def apply(self, Xs, Ys, Rs, reverse_state): + def apply(self, Xs, _Ys, Rs, reverse_state: Dict): grad = ilayers.GradientWRT(len(Xs)) to_low = keras.layers.Lambda(lambda x: x * 0 + self._low) to_high = keras.layers.Lambda(lambda x: x * 0 + self._high) @@ -521,17 +541,17 @@ def apply(self, Xs, Ys, Rs, reverse_state): # Divide relevances with the value. tmp = [ilayers.SafeDivide()([a, b]) for a, b in zip(Rs, Zs)] # Distribute along the gradient. - tmpA = iutils.to_list(grad(Xs + A + tmp)) - tmpB = iutils.to_list(grad(low + B + tmp)) - tmpC = iutils.to_list(grad(high + C + tmp)) + tmp_a = iutils.to_list(grad(Xs + A + tmp)) + tmp_b = iutils.to_list(grad(low + B + tmp)) + tmp_c = iutils.to_list(grad(high + C + tmp)) - tmpA = [keras.layers.Multiply()([a, b]) for a, b in zip(Xs, tmpA)] - tmpB = [keras.layers.Multiply()([a, b]) for a, b in zip(low, tmpB)] - tmpC = [keras.layers.Multiply()([a, b]) for a, b in zip(high, tmpC)] + tmp_a = [keras.layers.Multiply()([a, b]) for a, b in zip(Xs, tmp_a)] + tmp_b = [keras.layers.Multiply()([a, b]) for a, b in zip(low, tmp_b)] + tmp_c = [keras.layers.Multiply()([a, b]) for a, b in zip(high, tmp_c)] tmp = [ keras.layers.Subtract()([a, keras.layers.Add()([b, c])]) - for a, b, c in zip(tmpA, tmpB, tmpC) + for a, b, c in zip(tmp_a, tmp_b, tmp_c) ] return tmp @@ -548,7 +568,7 @@ class ZPlusRule(Alpha1Beta0IgnoreBiasRule): # TODO: assert that layer inputs are always >= 0 def __init__(self, *args, **kwargs): - super(ZPlusRule, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) class ZPlusFastRule(kgraph.ReverseMappingBase): @@ -557,7 +577,7 @@ class ZPlusFastRule(kgraph.ReverseMappingBase): for alpha=1, beta=0 and assumes inputs x >= 0. 
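+
+    In sketch form, for inputs x >= 0 this computes
+    R_i = sum_j x_i * w_ij^+ / (sum_i' x_i' * w_i'j^+) * R_j,
+    where w^+ keeps only the positive kernel weights.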
""" - def __init__(self, layer, state, copy_weights=False): + def __init__(self, layer: Layer, state, copy_weights=False): # The z-plus rule only works with positive weights and # no biases. # TODO: assert that layer inputs are always >= 0 @@ -579,11 +599,13 @@ def __init__(self, layer, state, copy_weights=False): name_template="reversed_kernel_positive_%s", ) - def apply(self, Xs, Ys, Rs, reverse_state): + def apply(self, Xs, _Ys, Rs, reverse_state: Dict): grad = ilayers.GradientWRT(len(Xs)) # TODO: assert all inputs are positive, instead of only keeping the positives. - # keep_positives = keras.layers.Lambda(lambda x: x * K.cast(K.greater(x,0), K.floatx())) + # keep_positives = keras.layers.Lambda( + # lambda x: x * K.cast(K.greater(x, 0), K.floatx()) + # ) # Xs = kutils.apply(keep_positives, Xs) # Get activations. diff --git a/src/innvestigate/analyzer/relevance_based/utils.py b/src/innvestigate/analyzer/relevance_based/utils.py index 05ce2e33..26a1b531 100644 --- a/src/innvestigate/analyzer/relevance_based/utils.py +++ b/src/innvestigate/analyzer/relevance_based/utils.py @@ -1,19 +1,8 @@ -# Get Python six functionality: -from __future__ import absolute_import, division, print_function, unicode_literals - -############################################################################### -############################################################################### -############################################################################### - +from __future__ import annotations __all__ = ["assert_lrp_epsilon_param", "assert_infer_lrp_alpha_beta_param"] -############################################################################### -############################################################################### -############################################################################### - - def assert_lrp_epsilon_param(epsilon, caller): """ Function for asserting epsilon parameter choice @@ -55,6 +44,8 @@ def assert_infer_lrp_alpha_beta_param(alpha, beta, caller): :param caller: the class instance calling this assertion function """ + # TODO: Rework error messages + err_head = "Constructor call to {} : ".format(caller.__class__.__name__) if alpha is None and beta is None: err_msg = err_head + "Neither alpha or beta were given" @@ -85,7 +76,7 @@ def assert_infer_lrp_alpha_beta_param(alpha, beta, caller): if alpha < 1: err_msg = ( err_head - + "Inferring alpha from given beta {} s.t. alpha - beta = 1, with condition alpha >= 1 not possible.".format( + + "Inferring alpha from given beta {} s.t. alpha - beta = 1, with condition alpha >= 1 not possible.".format( # noqa beta ) ) @@ -96,7 +87,7 @@ def assert_infer_lrp_alpha_beta_param(alpha, beta, caller): if beta < 0: err_msg = ( err_head - + "Inferring beta from given alpha {} s.t. alpha - beta = 1, with condition beta >= 0 not possible.".format( + + "Inferring beta from given alpha {} s.t. alpha - beta = 1, with condition beta >= 0 not possible.".format( # noqa alpha ) ) @@ -107,7 +98,7 @@ def assert_infer_lrp_alpha_beta_param(alpha, beta, caller): if amb != 1: err_msg = ( err_head - + "Condition alpha - beta = 1 not fulfilled. alpha={} ; beta={} -> alpha - beta = {}".format( + + "Condition alpha - beta = 1 not fulfilled. 
alpha={} ; beta={} -> alpha - beta = {}".format( # noqa alpha, beta, amb ) ) @@ -115,8 +106,3 @@ def assert_infer_lrp_alpha_beta_param(alpha, beta, caller): # return benign values for alpha and beta return alpha, beta - - -############################################################################### -############################################################################### -############################################################################### diff --git a/src/innvestigate/analyzer/reverse_base.py b/src/innvestigate/analyzer/reverse_base.py new file mode 100644 index 00000000..0e80e5ff --- /dev/null +++ b/src/innvestigate/analyzer/reverse_base.py @@ -0,0 +1,405 @@ +from __future__ import annotations + +from typing import Callable, Dict, List, Optional, Tuple + +import keras +import keras.layers +import keras.models +import numpy as np +import six + +import innvestigate.layers as ilayers +import innvestigate.utils as iutils +import innvestigate.utils.keras.graph as kgraph +from innvestigate.analyzer.network_base import AnalyzerNetworkBase +from innvestigate.utils.types import ( + CondReverseMapping, + Layer, + Model, + OptionalList, + Tensor, +) + +__all__ = ["ReverseAnalyzerBase"] + + +class ReverseAnalyzerBase(AnalyzerNetworkBase): + """Convenience class for analyzers that revert the model's structure. + + This class contains many helper functions around the graph + reverse function :func:`innvestigate.utils.keras.graph.reverse_model`. + + The deriving classes should specify how the graph should be reverted + by implementing the following functions: + + * :func:`_reverse_mapping(layer)` given a layer this function + returns a reverse mapping for the layer as specified in + :func:`innvestigate.utils.keras.graph.reverse_model` or None. + + This function can be implemented, but it is encouraged to + implement a default mapping and add additional changes with + the function :func:`_add_conditional_reverse_mapping` (see below). + + The default behavior is finding a conditional mapping (see below), + if none is found, :func:`_default_reverse_mapping` is applied. + * :func:`_default_reverse_mapping` defines the default + reverse mapping. + * :func:`_head_mapping` defines how the outputs of the model + should be instantiated before they are passed to the reversed + network. + + Furthermore, other parameters of the function + :func:`innvestigate.utils.keras.graph.reverse_model` can + be changed by setting the corresponding parameters of the + init function: + + :param reverse_verbose: Print information on the reverse process. + :param reverse_clip_values: Clip the values that are passed along + the reverted network. Expects tuple (min, max). + :param reverse_project_bottleneck_layers: Project the value range + of bottleneck tensors in the reverse network into another range. + :param reverse_check_min_max_values: Print the min/max values + observed in each tensor along the reverse network whenever + :func:`analyze` is called. + :param reverse_check_finite: Check if values passed along the + reverse network are finite. + :param reverse_keep_tensors: Keeps the tensors created in the + backward pass and stores them in the attribute + :attr:`_reversed_tensors`. + :param reverse_reapply_on_copied_layers: See + :func:`innvestigate.utils.keras.graph.reverse_model`. 
+ """ + + def __init__( + self, + model: keras.Model, + reverse_verbose: bool = False, + reverse_clip_values: bool = False, + reverse_project_bottleneck_layers: bool = False, + reverse_check_min_max_values: bool = False, + reverse_check_finite: bool = False, + reverse_keep_tensors: bool = False, + reverse_reapply_on_copied_layers: bool = False, + **kwargs + ) -> None: + """ + From AnalyzerBase super init: + * Initializes empty list of _model_checks + + From AnalyzerNetworkBase super init: + * set _neuron_selection_mode + * add check for lambda layers through 'allow_lambda_layers' + * define attributes for '_prepare_model', which is later called + through 'create_analyzer_model' + + Here: + * define attributes required for calling '_conditional_reverse_mapping' + """ + super().__init__(model, **kwargs) + + self._reverse_verbose = reverse_verbose + self._reverse_clip_values = reverse_clip_values + self._reverse_project_bottleneck_layers = reverse_project_bottleneck_layers + self._reverse_check_min_max_values = reverse_check_min_max_values + self._reverse_check_finite = reverse_check_finite + self._reverse_keep_tensors = reverse_keep_tensors + self._reverse_reapply_on_copied_layers = reverse_reapply_on_copied_layers + self._reverse_mapping_applied: bool = False + + # map priorities to lists of conditional reverse mappings + self._conditional_reverse_mappings: Dict[int, List[CondReverseMapping]] = {} + + # Maps keys "min", "max", "finite", "keep" to tuples of indices + self._debug_tensors_indices: Dict[str, Tuple[int, int]] = {} + + def _gradient_reverse_mapping( + self, + Xs: OptionalList[Tensor], + Ys: OptionalList[Tensor], + reversed_Ys: OptionalList[Tensor], + reverse_state: Dict, + ): + mask = [x not in reverse_state["stop_mapping_at_tensors"] for x in Xs] + masked_grad = ilayers.GradientWRT(len(Xs), mask=mask) + return masked_grad(Xs + Ys + reversed_Ys) + + def _reverse_mapping(self, layer: keras.layers.Layer): + """ + This function should return a reverse mapping for the passed layer. + + If this function returns None, :func:`_default_reverse_mapping` + is applied. + + :param layer: The layer for which a mapping should be returned. + :return: The mapping can be of the following forms: + * A function of form (A) f(Xs, Ys, reversed_Ys, reverse_state) + that maps reversed_Ys to reversed_Xs (which should contain + tensors of the same shape and type). + * A function of form f(B) f(layer, reverse_state) that returns + a function of form (A). + * A :class:`ReverseMappingBase` subclass. + """ + if layer in self._special_helper_layers: + # Special layers added by AnalyzerNetworkBase + # that should not be exposed to user. + return self._gradient_reverse_mapping + + return self._apply_conditional_reverse_mappings(layer) + + def _add_conditional_reverse_mapping( + self, + condition: Callable[[Layer], bool], + mapping: Callable, # TODO: specify type of Callable + priority: int = -1, + name: Optional[str] = None, + ): + """ + This function should return a reverse mapping for the passed layer. + + If this function returns None, :func:`_default_reverse_mapping` + is applied. + + :param condition: Condition when this mapping should be applied. + Form: f(layer) -> bool + :param mapping: The mapping can be of the following forms: + * A function of form (A) f(Xs, Ys, reversed_Ys, reverse_state) + that maps reversed_Ys to reversed_Xs (which should contain + tensors of the same shape and type). + * A function of form f(B) f(layer, reverse_state) that returns + a function of form (A). 
+ * A :class:`ReverseMappingBase` subclass. + :param priority: The higher the priority, the earlier the + condition is evaluated. + :param name: An identifying name. + """ + if self._reverse_mapping_applied is True: + raise Exception( + "Cannot add conditional mapping " "after first application." + ) + + # Add key `priority` to dict _conditional_reverse_mappings + # if it doesn't exist yet. + if priority not in self._conditional_reverse_mappings: + self._conditional_reverse_mappings[priority] = [] + + # Add conditional reverse mapping at the given priority + tmp: CondReverseMapping = { + "condition": condition, + "mapping": mapping, + "name": name, + } + self._conditional_reverse_mappings[priority].append(tmp) + + def _apply_conditional_reverse_mappings(self, layer): + mappings = getattr(self, "_conditional_reverse_mappings", {}) + self._reverse_mapping_applied = True + + # Search for a mapping. Consider mappings with the highest priority first, + # and within the same priority in the order they were added. + sorted_keys = sorted(mappings.keys())[::-1] + for key in sorted_keys: + for mapping in mappings[key]: + if mapping["condition"](layer): + return mapping["mapping"] + + return None + + def _default_reverse_mapping( + self, + Xs: OptionalList[Tensor], + Ys: OptionalList[Tensor], + reversed_Ys: OptionalList[Tensor], + reverse_state: Dict, + ): + """ + Fallback function to map reversed_Ys to reversed_Xs + (which should contain tensors of the same shape and type). + """ + return self._gradient_reverse_mapping(Xs, Ys, reversed_Ys, reverse_state) + + def _head_mapping(self, X): + """ + Map output tensors to new values before passing + them into the reverted network. + """ + # Here: Keep the output signal. + # Should be overridden by subclasses as needed. + # Refer to the "Introduction to development notebook". 
+ return X + + def _postprocess_analysis(self, X: OptionalList[Tensor]) -> OptionalList[Tensor]: + return X + + def _reverse_model( + self, + model: Model, + stop_analysis_at_tensors: Optional[List[Tensor]] = None, + return_all_reversed_tensors=False, + ): + if stop_analysis_at_tensors is None: + stop_analysis_at_tensors = [] + + return kgraph.reverse_model( + model, + reverse_mappings=self._reverse_mapping, + default_reverse_mapping=self._default_reverse_mapping, + head_mapping=self._head_mapping, + stop_mapping_at_tensors=stop_analysis_at_tensors, + verbose=self._reverse_verbose, + clip_all_reversed_tensors=self._reverse_clip_values, + project_bottleneck_tensors=self._reverse_project_bottleneck_layers, + return_all_reversed_tensors=return_all_reversed_tensors, + ) + + def _create_analysis( + self, model: Model, stop_analysis_at_tensors: Optional[List[Tensor]] = None + ): + + if stop_analysis_at_tensors is None: + stop_analysis_at_tensors = [] + + return_all_reversed_tensors = ( + self._reverse_check_min_max_values + or self._reverse_check_finite + or self._reverse_keep_tensors + ) + ret = self._reverse_model( + model, + stop_analysis_at_tensors=stop_analysis_at_tensors, + return_all_reversed_tensors=return_all_reversed_tensors, + ) + + if return_all_reversed_tensors: + ret = (self._postprocess_analysis(ret[0]), ret[1]) + else: + ret = self._postprocess_analysis(ret) + + if return_all_reversed_tensors: + debug_tensors: List[Tensor] + tmp: List[Tensor] + + debug_tensors = [] + values = list(six.itervalues(ret[1])) + mapping = {i: v["id"] for i, v in enumerate(values)} + tensors = [v["final_tensor"] for v in values] + self._reverse_tensors_mapping = mapping + + if self._reverse_check_min_max_values: + tmp = [ilayers.Min(None)(x) for x in tensors] + self._debug_tensors_indices["min"] = ( + len(debug_tensors), + len(debug_tensors) + len(tmp), + ) + debug_tensors += tmp + + tmp = [ilayers.Max(None)(x) for x in tensors] + self._debug_tensors_indices["max"] = ( + len(debug_tensors), + len(debug_tensors) + len(tmp), + ) + debug_tensors += tmp + + if self._reverse_check_finite: + tmp = iutils.to_list(ilayers.FiniteCheck()(tensors)) + self._debug_tensors_indices["finite"] = ( + len(debug_tensors), + len(debug_tensors) + len(tmp), + ) + debug_tensors += tmp + + if self._reverse_keep_tensors: + self._debug_tensors_indices["keep"] = ( + len(debug_tensors), + len(debug_tensors) + len(tensors), + ) + debug_tensors += tensors + + ret = (ret[0], debug_tensors) + return ret + + def _handle_debug_output(self, debug_values): + + if self._reverse_check_min_max_values: + indices = self._debug_tensors_indices["min"] + tmp = debug_values[indices[0] : indices[1]] + tmp = sorted( + [(self._reverse_tensors_mapping[i], v) for i, v in enumerate(tmp)] + ) + print( + "Minimum values in tensors: " + "((NodeID, TensorID), Value) - {}".format(tmp) + ) + + indices = self._debug_tensors_indices["max"] + tmp = debug_values[indices[0] : indices[1]] + tmp = sorted( + [(self._reverse_tensors_mapping[i], v) for i, v in enumerate(tmp)] + ) + print( + "Maximum values in tensors: " + "((NodeID, TensorID), Value) - {}".format(tmp) + ) + + if self._reverse_check_finite: + indices = self._debug_tensors_indices["finite"] + tmp = debug_values[indices[0] : indices[1]] + nfinite_tensors = np.flatnonzero(np.asarray(tmp) > 0) + + if len(nfinite_tensors) > 0: + nfinite_tensors = sorted( + [self._reverse_tensors_mapping[i] for i in nfinite_tensors] + ) + print( + "Non-finite values found in the following nodes: " + "(NodeID, TensorID) - {}".format(nfinite_tensors) + 
) + + if self._reverse_keep_tensors: + indices = self._debug_tensors_indices["keep"] + tmp = debug_values[indices[0] : indices[1]] + tmp = sorted( + [(self._reverse_tensors_mapping[i], v) for i, v in enumerate(tmp)] + ) + self._reversed_tensors = tmp + + def _get_state(self): + state = super()._get_state() + state.update( + { + "reverse_verbose": self._reverse_verbose, + "reverse_clip_values": self._reverse_clip_values, + "reverse_project_bottleneck_layers": self._reverse_project_bottleneck_layers, # noqa + "reverse_check_min_max_values": self._reverse_check_min_max_values, + "reverse_check_finite": self._reverse_check_finite, + "reverse_keep_tensors": self._reverse_keep_tensors, + "reverse_reapply_on_copied_layers": self._reverse_reapply_on_copied_layers, # noqa + } + ) + return state + + @classmethod + def _state_to_kwargs(cls, state: dict): + reverse_verbose = state.pop("reverse_verbose") + reverse_clip_values = state.pop("reverse_clip_values") + reverse_project_bottleneck_layers = state.pop( + "reverse_project_bottleneck_layers" + ) + reverse_check_min_max_values = state.pop("reverse_check_min_max_values") + reverse_check_finite = state.pop("reverse_check_finite") + reverse_keep_tensors = state.pop("reverse_keep_tensors") + reverse_reapply_on_copied_layers = state.pop("reverse_reapply_on_copied_layers") + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + + kwargs.update( + { + "reverse_verbose": reverse_verbose, + "reverse_clip_values": reverse_clip_values, + "reverse_project_bottleneck_layers": reverse_project_bottleneck_layers, + "reverse_check_min_max_values": reverse_check_min_max_values, + "reverse_check_finite": reverse_check_finite, + "reverse_keep_tensors": reverse_keep_tensors, + "reverse_reapply_on_copied_layers": reverse_reapply_on_copied_layers, + } + ) + return kwargs diff --git a/src/innvestigate/analyzer/wrapper.py b/src/innvestigate/analyzer/wrapper.py index 0045c76f..349e4d7f 100644 --- a/src/innvestigate/analyzer/wrapper.py +++ b/src/innvestigate/analyzer/wrapper.py @@ -1,21 +1,19 @@ -# Get Python six functionality: -from __future__ import absolute_import, division, print_function, unicode_literals +from __future__ import annotations +import warnings from builtins import zip +from typing import List, Optional, Union -import keras.backend as K +import keras.backend as kbackend import keras.models import numpy as np -from .. import layers as ilayers -from .. import utils as iutils -from ..utils import keras as kutils -from . 
import base - -############################################################################### -############################################################################### -############################################################################### - +import innvestigate.layers as ilayers +import innvestigate.utils as iutils +import innvestigate.utils.keras as kutils +from innvestigate.analyzer.base import AnalyzerBase +from innvestigate.analyzer.network_base import AnalyzerNetworkBase +from innvestigate.utils.types import OptionalList, Tensor __all__ = [ "WrapperBase", @@ -25,12 +23,7 @@ ] -############################################################################### -############################################################################### -############################################################################### - - -class WrapperBase(base.AnalyzerBase): +class WrapperBase(AnalyzerBase): """Interface for wrappers around analyzers This class is the basic interface for wrappers around analyzers. @@ -38,37 +31,39 @@ :param subanalyzer: The analyzer to be wrapped. """ - def __init__(self, subanalyzer, *args, **kwargs): - self._subanalyzer = subanalyzer - model = None + def __init__(self, subanalyzer: AnalyzerBase, *args, **kwargs): + # To simplify serialization, additionally passed models are popped + # and the subanalyzer model is passed to `AnalyzerBase`. + kwargs.pop("model", None) + super().__init__(subanalyzer._model, *args, **kwargs) - super(WrapperBase, self).__init__(model, *args, **kwargs) + self._subanalyzer_name = subanalyzer.__class__.__name__ + self._subanalyzer = subanalyzer def analyze(self, *args, **kwargs): return self._subanalyzer.analyze(*args, **kwargs) - def _get_state(self): + def _get_state(self) -> dict: sa_class_name, sa_state = self._subanalyzer.save() - state = {} + state = super()._get_state() state.update({"subanalyzer_class_name": sa_class_name}) state.update({"subanalyzer_state": sa_state}) return state @classmethod - def _state_to_kwargs(clazz, state): + def _state_to_kwargs(cls, state: dict): sa_class_name = state.pop("subanalyzer_class_name") sa_state = state.pop("subanalyzer_state") - assert len(state) == 0 + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) - subanalyzer = base.AnalyzerBase.load(sa_class_name, sa_state) - kwargs = {"subanalyzer": subanalyzer} + subanalyzer = AnalyzerBase.load(sa_class_name, sa_state) + kwargs.update({"subanalyzer": subanalyzer}) return kwargs ############################################################################### -############################################################################### -############################################################################### class AugmentReduceBase(WrapperBase): @@ -82,44 +77,53 @@ :param augment_by_n: Number of samples to create. """ - def __init__(self, subanalyzer, *args, **kwargs): - self._augment_by_n = kwargs.pop("augment_by_n", 2) - self._neuron_selection_mode = subanalyzer._neuron_selection_mode - - if self._neuron_selection_mode != "all": - # TODO: this is not transparent, find a better way. + def __init__( + self, + subanalyzer: AnalyzerNetworkBase, + *args, + augment_by_n: int = 2, + neuron_selection_mode="max_activation", + **kwargs, + ): + + if neuron_selection_mode == "max_activation": + # TODO: find a more transparent way. 
+ # + # Since AugmentReduceBase analyzers augment the input, + # it is possible that the neuron w/ max activation changes. + # As a workaround, the index of the maximally activated neuron + # w.r.t. the "unperturbed" input is computed and used in combination + # with neuron_selection_mode = "index" in the subanalyzer. + # + # NOTE: + # The analyzer will still have neuron_selection_mode = "max_activation"! subanalyzer._neuron_selection_mode = "index" - super(AugmentReduceBase, self).__init__(subanalyzer, *args, **kwargs) - if isinstance(self._subanalyzer, base.AnalyzerNetworkBase): - # Take the keras analyzer model and - # add augment and reduce functionality. - self._keras_based_augment_reduce = True - else: - raise NotImplementedError("Keras-based subanalyzer required.") + super().__init__( + subanalyzer, *args, neuron_selection_mode=neuron_selection_mode, **kwargs + ) - def create_analyzer_model(self): - if not self._keras_based_augment_reduce: - return + self._augment_by_n: int = augment_by_n # number of samples to create + def create_analyzer_model(self): self._subanalyzer.create_analyzer_model() if self._subanalyzer._n_debug_output > 0: - raise Exception("No debug output at subanalyzer is supported.") + raise NotImplementedError("No debug output at subanalyzer is supported.") model = self._subanalyzer._analyzer_model if None in model.input_shape[1:]: raise ValueError( "The input shape for the model needs " "to be fully specified (except the batch axis). " - "Model input shape is: %s" % (model.input_shape,) + f"Model input shape is: {model.input_shape}" ) inputs = model.inputs[: self._subanalyzer._n_data_input] extra_inputs = model.inputs[self._subanalyzer._n_data_input :] - # todo: check this, index seems not right. - # outputs = model.outputs[:self._subanalyzer._n_data_input] - extra_outputs = model.outputs[self._subanalyzer._n_data_input :] + + outputs = model.outputs[: self._subanalyzer._n_data_output] + extra_outputs = model.outputs[self._subanalyzer._n_data_output :] if len(extra_outputs) > 0: raise Exception("No extra output is allowed " "with this wrapper.") @@ -136,30 +140,40 @@ def create_analyzer_model(self): ) self._subanalyzer._analyzer_model = new_model - def analyze(self, X, *args, **kwargs): - if self._keras_based_augment_reduce is True: - if not hasattr(self._subanalyzer, "_analyzer_model"): - self.create_analyzer_model() - - ns_mode = self._neuron_selection_mode - if ns_mode in ["max_activation", "index"]: - if ns_mode == "max_activation": - tmp = self._subanalyzer._model.predict(X) - indices = np.argmax(tmp, axis=1) - else: - if len(args): - args = list(args) - indices = args.pop(0) - else: - indices = kwargs.pop("neuron_selection") - - # broadcast to match augmented samples. - indices = np.repeat(indices, self._augment_by_n) - - kwargs["neuron_selection"] = indices + def analyze( + self, X: OptionalList[np.ndarray], *args, **kwargs + ) -> OptionalList[np.ndarray]: + if self._subanalyzer._analyzer_model is None: + self.create_analyzer_model() + + ns_mode = self._neuron_selection_mode + if ns_mode == "all": return self._subanalyzer.analyze(X, *args, **kwargs) - else: - raise DeprecationWarning("Not supported anymore.") + + # As described in the AugmentReduceBase init, + # both ns_mode "max_activation" and "index" make use + # of a subanalyzer using neuron_selection_mode="index". 
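+ # For example (illustrative values): with augment_by_n=3 and per-sample + # indices [2, 0], np.repeat yields neuron_selection=[2, 2, 2, 0, 0, 0], + # matching the order of the augmented batch.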
+ elif ns_mode == "max_activation": + # Obtain the index of the maximally activated neuron for each sample. + pred = self._subanalyzer._model.predict(X) + indices = np.argmax(pred, axis=1) + elif ns_mode == "index": + # TODO: make neuron_selection arg or kwarg, not both + if len(args): + args = list(args) + indices = args.pop(0) + else: + indices = kwargs.pop("neuron_selection") + + if self._subanalyzer._neuron_selection_mode != "index": + raise AssertionError( + 'Subanalyzer neuron_selection_mode has to be "index" ' + 'when using analyzer with neuron_selection_mode != "all".' + ) + # broadcast to match augmented samples. + indices = np.repeat(indices, self._augment_by_n) + kwargs["neuron_selection"] = indices + return self._subanalyzer.analyze(X, *args, **kwargs) def _keras_get_constant_inputs(self): return list() @@ -169,7 +183,7 @@ def _augment(self, X): return [repeat(x) for x in iutils.to_list(X)] def _reduce(self, X): - X_shape = [K.int_shape(x) for x in iutils.to_list(X)] + X_shape = [kbackend.int_shape(x) for x in iutils.to_list(X)] reshape = [ ilayers.Reshape((-1, self._augment_by_n) + shape[1:]) for shape in X_shape ] @@ -178,26 +192,21 @@ return [mean(reshape_x(x)) for x, reshape_x in zip(X, reshape)] def _get_state(self): - if self._neuron_selection_mode != "all": - # TODO: this is not transparent, find a better way. - # revert the tempering in __init__ - tmp = self._neuron_selection_mode - self._subanalyzer._neuron_selection_mode = tmp - state = super(AugmentReduceBase, self)._get_state() + state = super()._get_state() state.update({"augment_by_n": self._augment_by_n}) return state @classmethod - def _state_to_kwargs(clazz, state): + def _state_to_kwargs(cls, state): augment_by_n = state.pop("augment_by_n") - kwargs = super(AugmentReduceBase, clazz)._state_to_kwargs(state) + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + kwargs.update({"augment_by_n": augment_by_n}) return kwargs ############################################################################### -############################################################################### -############################################################################### class GaussianSmoother(AugmentReduceBase): @@ -211,31 +220,31 @@ class GaussianSmoother(AugmentReduceBase): :param augment_by_n: Number of samples to create. 
""" - def __init__(self, subanalyzer, *args, **kwargs): - self._noise_scale = kwargs.pop("noise_scale", 1) - super(GaussianSmoother, self).__init__(subanalyzer, *args, **kwargs) + def __init__(self, subanalyzer, *args, noise_scale: float = 1, **kwargs): + super().__init__(subanalyzer, *args, **kwargs) + self._noise_scale = noise_scale def _augment(self, X): - tmp = super(GaussianSmoother, self)._augment(X) + tmp = super()._augment(X) noise = ilayers.TestPhaseGaussianNoise(stddev=self._noise_scale) return [noise(x) for x in tmp] def _get_state(self): - state = super(GaussianSmoother, self)._get_state() + state = super()._get_state() state.update({"noise_scale": self._noise_scale}) return state @classmethod - def _state_to_kwargs(clazz, state): + def _state_to_kwargs(cls, state): noise_scale = state.pop("noise_scale") - kwargs = super(GaussianSmoother, clazz)._state_to_kwargs(state) + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + kwargs.update({"noise_scale": noise_scale}) return kwargs ############################################################################### -############################################################################### -############################################################################### class PathIntegrator(AugmentReduceBase): @@ -243,8 +252,8 @@ class PathIntegrator(AugmentReduceBase): This analyzer: * creates a path from input to reference image. - * creates steps number of intermediate inputs and - crests an analysis for them. + * creates `steps` number of intermediate inputs and + creates an analysis for them. * sums the analyses and multiplies them with the input-reference_input. This wrapper is used to implement Integrated Gradients. @@ -255,37 +264,39 @@ class PathIntegrator(AugmentReduceBase): :param reference_inputs: The reference input. """ - def __init__(self, subanalyzer, *args, **kwargs): - steps = kwargs.pop("steps", 16) - self._reference_inputs = kwargs.pop("reference_inputs", 0) - self._keras_constant_inputs = None - super(PathIntegrator, self).__init__( - subanalyzer, *args, augment_by_n=steps, **kwargs - ) + def __init__( + self, subanalyzer, *args, steps: int = 16, reference_inputs=0, **kwargs + ): + super().__init__(subanalyzer, *args, augment_by_n=steps, **kwargs) + + self._reference_inputs = reference_inputs + self._keras_constant_inputs: Optional[List[Tensor]] = None - def _keras_set_constant_inputs(self, inputs): - tmp = [K.variable(x) for x in inputs] + def _keras_set_constant_inputs(self, inputs: List[Tensor]) -> None: + tmp = [kbackend.variable(x) for x in inputs] self._keras_constant_inputs = [ - keras.layers.Input(tensor=x, shape=x.shape[1:]) for x in tmp + keras.layers.Input(tensor=X, shape=X.shape[1:]) for X in tmp ] - def _keras_get_constant_inputs(self): + def _keras_get_constant_inputs(self) -> Optional[List[Tensor]]: return self._keras_constant_inputs - def _compute_difference(self, X): + def _compute_difference(self, X: List[Tensor]) -> List[Tensor]: if self._keras_constant_inputs is None: tmp = kutils.broadcast_np_tensors_to_keras_tensors( X, self._reference_inputs ) self._keras_set_constant_inputs(tmp) - reference_inputs = self._keras_get_constant_inputs() + # Type not Optional anymore as as `_keras_set_constant_inputs` has been called. 
+ reference_inputs: List[Tensor] + reference_inputs = self._keras_get_constant_inputs() # type: ignore return [keras.layers.Subtract()([x, ri]) for x, ri in zip(X, reference_inputs)] def _augment(self, X): - tmp = super(PathIntegrator, self)._augment(X) + tmp = super()._augment(X) tmp = [ - ilayers.Reshape((-1, self._augment_by_n) + K.int_shape(x)[1:])(x) + ilayers.Reshape((-1, self._augment_by_n) + kbackend.int_shape(x)[1:])(x) for x in tmp ] @@ -293,7 +304,7 @@ def _augment(self, X): self._keras_difference = difference # Make broadcastable. difference = [ - ilayers.Reshape((-1, 1) + K.int_shape(x)[1:])(x) for x in difference + ilayers.Reshape((-1, 1) + kbackend.int_shape(x)[1:])(x) for x in difference ] # Compute path steps. @@ -304,27 +315,30 @@ def _augment(self, X): reference_inputs = self._keras_get_constant_inputs() ret = [keras.layers.Add()([x, p]) for x, p in zip(reference_inputs, path_steps)] - ret = [ilayers.Reshape((-1,) + K.int_shape(x)[2:])(x) for x in ret] + ret = [ilayers.Reshape((-1,) + kbackend.int_shape(x)[2:])(x) for x in ret] return ret def _reduce(self, X): - tmp = super(PathIntegrator, self)._reduce(X) + tmp = super()._reduce(X) difference = self._keras_difference del self._keras_difference return [keras.layers.Multiply()([x, d]) for x, d in zip(tmp, difference)] def _get_state(self): - state = super(PathIntegrator, self)._get_state() + state = super()._get_state() state.update({"reference_inputs": self._reference_inputs}) return state @classmethod - def _state_to_kwargs(clazz, state): + def _state_to_kwargs(cls, state): reference_inputs = state.pop("reference_inputs") - kwargs = super(PathIntegrator, clazz)._state_to_kwargs(state) - kwargs.update({"reference_inputs": reference_inputs}) + # call super after popping class-specific states: + kwargs = super()._state_to_kwargs(state) + # We use steps instead. - kwargs.update({"steps": kwargs["augment_by_n"]}) + kwargs.update( + {"reference_inputs": reference_inputs, "steps": kwargs["augment_by_n"]} + ) del kwargs["augment_by_n"] return kwargs diff --git a/src/innvestigate/layers.py b/src/innvestigate/layers.py index cf598158..f1d010e1 100644 --- a/src/innvestigate/layers.py +++ b/src/innvestigate/layers.py @@ -404,7 +404,7 @@ def compute_output_shape( class Reshape(keras.layers.Layer): - def __init__(self, shape: Tuple[int, ...], *args, **kwargs): + def __init__(self, shape: Iterable[int], *args, **kwargs): super().__init__(*args, **kwargs) self._shape = shape diff --git a/src/innvestigate/utils/keras/graph.py b/src/innvestigate/utils/keras/graph.py index 36da163a..accaaf13 100644 --- a/src/innvestigate/utils/keras/graph.py +++ b/src/innvestigate/utils/keras/graph.py @@ -29,7 +29,6 @@ __all__ = [ "get_kernel", "get_layer_inbound_count", - "get_layer_outbound_count", "get_layer_neuronwise_io", "copy_layer_wo_activation", "copy_layer", @@ -756,8 +755,11 @@ def get_model_execution_trace( nid_to_nodes: Dict[Layer, Tuple[Optional[int], Layer, List[Tensor], List[Tensor]]] model_execution_trace: List[NodeDict] + # TODO: fix invariance of type hints Xs_nids: List[Optional[int]] Ys_nids: List[Union[List[int], List[None]]] + # Xs_layers = List[Layer] + # Ys_layers = List[List[Layer]] nid_to_nodes = {t[0]: t for t in id_execution_trace}
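For orientation (not part of the patch): a minimal usage sketch of the refactored wrapper analyzers. It assumes standalone Keras and a toy model defined here purely for illustration; the wrapper class names and keyword arguments (noise_scale, augment_by_n, steps, reference_inputs) are taken from the diff above, while the Gradient subanalyzer is assumed to be importable from innvestigate.analyzer.gradient_based.

import numpy as np
import keras
import keras.layers

from innvestigate.analyzer.gradient_based import Gradient
from innvestigate.analyzer.wrapper import GaussianSmoother, PathIntegrator

# Toy model with a fully specified input shape, as required by
# AugmentReduceBase.create_analyzer_model.
model = keras.models.Sequential(
    [
        keras.layers.Dense(16, activation="relu", input_shape=(8,)),
        keras.layers.Dense(3),
    ]
)
x = np.random.rand(4, 8).astype(np.float32)

# SmoothGrad-style analysis: each sample is augmented with `augment_by_n`
# noisy copies (stddev = noise_scale) and the resulting explanations are
# averaged back to the original batch size.
smoothed = GaussianSmoother(Gradient(model), noise_scale=0.1, augment_by_n=8)
a_smooth = smoothed.analyze(x)  # explains the maximally activated output neuron

# Integrated-Gradients-style analysis: `steps` inputs are interpolated between
# `reference_inputs` and x, each is analyzed, and the analyses are summed along
# the path and multiplied with (x - reference_inputs).
integrated = PathIntegrator(Gradient(model), steps=16, reference_inputs=0)
a_ig = integrated.analyze(x)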