Skip to content

Commit

Permalink
Adding successful reversed gradient model.
Browse files Browse the repository at this point in the history
  • Loading branch information
albermax committed Dec 21, 2017
1 parent 74d0cf7 commit c048180
Show file tree
Hide file tree
Showing 15 changed files with 483 additions and 25 deletions.
20 changes: 20 additions & 0 deletions innvestigate/analyzer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Begin: Python 2/3 compatibility header small
# Get Python 3 functionality:
from __future__ import\
absolute_import, print_function, division, unicode_literals
from future.utils import raise_with_traceback, raise_from
# catch exception with: except Exception as e
from builtins import range, map, zip, filter
from io import open
import six
# End: Python 2/3 compatability header small


###############################################################################
###############################################################################
###############################################################################

from .base import *

Expand All @@ -7,6 +22,11 @@
#from .relevance_based import *


###############################################################################
###############################################################################
###############################################################################


def create_analyzer(name, moderl, **kwargs):
return {
# Utility.
Expand Down
41 changes: 33 additions & 8 deletions innvestigate/analyzer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,28 @@
# End: Python 2/3 compatability header small


###############################################################################
###############################################################################
###############################################################################


from .. import layers as ilayers
from ..utils.keras import graph

import keras.layers
import keras.models


__all__ = ["BaseAnalyzer", "BaseNetworkAnalyzer"]
__all__ = [
"BaseAnalyzer",
"BaseNetworkAnalyzer",
"BaseReverseNetworkAnalyzer"
]


###############################################################################
###############################################################################
###############################################################################


class BaseAnalyzer(object):
Expand All @@ -34,15 +49,13 @@ def explain(self, X):
raise NotImplementedError("Has to be implemented by the subclass")


###############################################################################
###############################################################################
###############################################################################


class BaseNetworkAnalyzer(BaseAnalyzer):

properties = {
"name": "undefined",
"show_as": "undefined",
}

def __init__(self, model, neuron_selection_mode="max_activation"):
super(BaseNetworkAnalyzer, self).__init__(model)

Expand All @@ -61,13 +74,13 @@ def __init__(self, model, neuron_selection_mode="max_activation"):
raise NotImplementedError("Only a stub present so far.")
neuron_indexing = keras.layers.Input(shape=[None, None])
neuron_selection_inputs += neuron_indexing

model_output = keras.layers.Index()([model_output, neuron_indexing])

model = keras.models.Model(inputs=model_inputs+neuron_selection_inputs,
outputs=model_output)
analysis_output = self._create_analysis(model)

self._analyzer_model = keras.models.Model(
inputs=model_inputs+neuron_selection_inputs,
outputs=analysis_output)
Expand All @@ -91,3 +104,15 @@ def analyze(self, X, neuron_selection=None):
return self._analyzer_model.predict_on_batch(X, neuron_selection)
else:
return self._analyzer_model.predict_on_batch(X)


class BaseReverseNetworkAnalyzer(BaseNetworkAnalyzer):

# Should be specified by the base class.
reverse_mappings = {}
default_reverse = None

def _create_analysis(self, model):
return graph.reverse_model(model,
reverse_mapping=self.reverse_mappings,
default_reverse=self.default_reverse)
45 changes: 39 additions & 6 deletions innvestigate/analyzer/gradient_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,32 @@
# End: Python 2/3 compatability header small


__all__ = ["BaselineGradientAnalyzer"]
###############################################################################
###############################################################################
###############################################################################


from . import base
from .. import layers as ilayers
from .. import utils

import keras.backend as K
import keras.models


__all__ = [
"BaselineGradientAnalyzer",
"GradientAnalyzer",
]


###############################################################################
###############################################################################
###############################################################################


class BaselineGradientAnalyzer(base.BaseNetworkAnalyzer):

properties = {
"name": "BaselineGradient",
"show_as": "rgb",
Expand All @@ -29,7 +44,25 @@ class BaselineGradientAnalyzer(base.BaseNetworkAnalyzer):

def _create_analysis(self, model):
return ilayers.Gradient()(model.inputs+[model.outputs[0],])
import keras.layers
ret = keras.layers.Lambda(lambda x: K.gradients(x[1].sum(), x[0]))([model.inputs[0], model.outputs[0]])
print(type(ret[0]))
return ret[0]


class GradientAnalyzer(base.BaseReverseNetworkAnalyzer):

properties = {
"name": "Gradient",
"show_as": "rgb",
}

def __init__(self, *args, **kwargs):
# we assume there is only one head!
gradient_head_processed = [False]
def gradient_reverse(Xs, Ys, reversed_Ys, reverse_state):
if gradient_head_processed[0] is not True:
# replace function value with ones as the last element
# chain rule is a one.
gradient_head_processed[0] = True
reversed_Ys = utils.listify(ilayers.OnesLike()(reversed_Ys))
return ilayers.GradientWRT(len(Xs))(Xs+Ys+reversed_Ys)

self.default_reverse = gradient_reverse
return super(GradientAnalyzer, self).__init__(*args, **kwargs)
10 changes: 10 additions & 0 deletions innvestigate/analyzer/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
# End: Python 2/3 compatability header small


###############################################################################
###############################################################################
###############################################################################


import numpy as np

from .base import BaseAnalyzer
Expand All @@ -18,6 +23,11 @@
__all__ = ["RandomAnalyzer", "InputAnalyzer"]


###############################################################################
###############################################################################
###############################################################################


class InputAnalyzer(BaseAnalyzer):

properties = {
Expand Down
57 changes: 56 additions & 1 deletion innvestigate/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,43 @@
# End: Python 2/3 compatability header small


###############################################################################
###############################################################################
###############################################################################


from . import utils
from .utils.keras import backend as iK

import keras
import keras.backend as K
from keras.engine.topology import Layer


__all__ = ["Gradient", "Max", "Sum"]
__all__ = [
"OnesLike",

"Gradient",
"GradientWRT",

"Max",
"Sum",
]


###############################################################################
###############################################################################
###############################################################################


class OnesLike(keras.layers.Layer):
def call(self, x):
return [K.ones_like(tmp) for tmp in utils.listify(x)]


###############################################################################
###############################################################################
###############################################################################


class Gradient(keras.layers.Layer):
Expand All @@ -29,6 +60,30 @@ def compute_output_shape(self, input_shapes):
return input_shapes[:-1]


class GradientWRT(keras.layers.Layer):
"Returns gradient wrt to another layer and given gradient,"
" expects inputs+[output,]."

def __init__(self, n_inputs, *args, **kwargs):
self.n_inputs = n_inputs
super(GradientWRT, self).__init__(*args, **kwargs)

def call(self, x):
Xs, tmp_Ys = x[:self.n_inputs], x[self.n_inputs:]
assert len(tmp_Ys) % 2 == 0
len_Ys = len(tmp_Ys) // 2
Ys, known_Ys = tmp_Ys[:len_Ys], tmp_Ys[len_Ys:]
return iK.gradients(Xs, Ys, known_Ys)

def compute_output_shape(self, input_shapes):
return input_shapes[:self.n_inputs]


###############################################################################
###############################################################################
###############################################################################


class Max(keras.layers.Layer):
"Returns maximum along the last dimension."

Expand Down
32 changes: 32 additions & 0 deletions innvestigate/tests/analyzer/gradient_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
import six
# End: Python 2/3 compatability header small


###############################################################################
###############################################################################
###############################################################################


# todo:fix relative imports:
#from ...utils.tests import dryrun

Expand All @@ -18,9 +24,35 @@
from innvestigate.utils.tests import dryrun

from innvestigate.analyzer import BaselineGradientAnalyzer
from innvestigate.analyzer import GradientAnalyzer


###############################################################################
###############################################################################
###############################################################################


class TestBaselineGradientAnalyzer(dryrun.AnalyzerTestCase):

def _method(self, model):
return BaselineGradientAnalyzer(model)


class TestGradientAnalyzer(dryrun.AnalyzerTestCase):

def _method(self, model):
return GradientAnalyzer(model)


###############################################################################
###############################################################################
###############################################################################


class TestBasicGraphReversalAnalyzer(dryrun.EqualAnalyzerTestCase):

def _method1(self, model):
return BaselineGradientAnalyzer(model)

def _method2(self, model):
return GradientAnalyzer(model)
10 changes: 10 additions & 0 deletions innvestigate/tests/analyzer/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
import six
# End: Python 2/3 compatability header small


###############################################################################
###############################################################################
###############################################################################

# todo:fix relative imports:
#from ...utils.tests import dryrun

Expand All @@ -21,6 +26,11 @@
from innvestigate.analyzer import RandomAnalyzer


###############################################################################
###############################################################################
###############################################################################


class TestInputAnalyzer(dryrun.AnalyzerTestCase):

def _method(self, model):
Expand Down
10 changes: 10 additions & 0 deletions innvestigate/tests/utils/tests/dryrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,21 @@
import six
# End: Python 2/3 compatability header small


###############################################################################
###############################################################################
###############################################################################

# Todo: fix:
#from ...utils.tests import dryrun
from innvestigate.utils.tests import dryrun


###############################################################################
###############################################################################
###############################################################################


class TestDryRunAnalyzerTestCase(dryrun.AnalyzerTestCase):
"""
Sanity test for the TestCase.
Expand Down

0 comments on commit c048180

Please sign in to comment.