z_plus_fast (which must have some bug)
sebastian-lapuschkin committed Mar 26, 2018
1 parent 1bf2dba commit 6517058
Showing 2 changed files with 64 additions and 35 deletions.
1 change: 1 addition & 0 deletions innvestigate/analyzer/__init__.py
@@ -73,6 +73,7 @@ def create_analyzer(name, model, **kwargs):
"lrp.alpha_1_beta_0": LRPAlpha1Beta0,
"lrp.alpha_1_beta_0_IB": LRPAlpha1Beta0IgnoreBias,
"lrp.z_plus": LRPZPlus,
"lrp.z_plus_fast": LRPZPlusFast,

# Pattern based
"pattern.net": PatternNet,
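The registry entry above exposes the new analyzer through create_analyzer under the name "lrp.z_plus_fast". Below is a minimal usage sketch, not part of the commit; the toy model and the random input batch are illustrative assumptions, and the model is assumed to have no final softmax, as is usual for iNNvestigate analyzers.

import numpy as np
import keras
import innvestigate

# Toy two-layer model, for illustration only.
model = keras.models.Sequential([
    keras.layers.Dense(8, activation="relu", input_shape=(4,)),
    keras.layers.Dense(2),
])

# Non-negative inputs, matching the z+ rule's assumption x >= 0.
x = np.random.rand(1, 4)

analyzer = innvestigate.create_analyzer("lrp.z_plus_fast", model)
relevance = analyzer.analyze(x)  # one relevance map per sample, same shape as x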
98 changes: 63 additions & 35 deletions innvestigate/analyzer/relevance_based.py
@@ -61,6 +61,7 @@
"LRPAlpha1Beta0",
"LRPAlpha1Beta0IgnoreBias",
"LRPZPlus",
"LRPZPlusFast",
]


@@ -309,35 +310,6 @@ def __init__(self, *args, **kwargs):
**kwargs)


#TODO: make subclass of AlphaBetaRule, after fix for alphabeta
#TODO: fix computation of z+ to not depend on positive inputs, but positive preactivations
class ZPlusRule(kgraph.ReverseMappingBase):

def __init__(self, layer, state):
# The z-plus rule only works with positive weights and
# no biases.
self._layer_wo_act_b_positive = kgraph.copy_layer_wo_activation(
layer, keep_bias=False,
name_template="reversed_kernel_positive_%s")
tmp = [x * (x > 0)
for x in self._layer_wo_act_b_positive.get_weights()]
self._layer_wo_act_b_positive.set_weights(tmp)

def apply(self, Xs, Ys, Rs, reverse_state):
grad = ilayers.GradientWRT(len(Xs))

# Get activations.
Zs = kutils.apply(self._layer_wo_act_b_positive, Xs)
# Divide incoming relevance by the activations.
tmp = [ilayers.SafeDivide()([a, b])
for a, b in zip(Rs, Zs)]
# Propagate the relevance to input neurons
# using the gradient.
tmp = iutils.to_list(grad(Xs+Zs+tmp))
# Re-weight relevance with the input values.
return [keras.layers.Multiply()([a, b])
for a, b in zip(Xs, tmp)]


class EpsilonRule(kgraph.ReverseMappingBase):
"""
@@ -619,6 +591,52 @@ def f(Xs):
return tmp


class ZPlusRule(Alpha1Beta0IgnoreBiasRule):
"""
The ZPlus rule is a special case of the AlphaBetaRule
for alpha=1, beta=0, which assumes inputs x >= 0
and ignores the bias.
"""
#TODO: assert that layer inputs are always >= 0
def __init__(self, *args, **kwargs):
super(ZPlusRule, self).__init__(*args, **kwargs)


class ZPlusFastRule(kgraph.ReverseMappingBase):
"""
The ZPlus rule is a special case of the AlphaBetaRule
for alpha=1, beta=0 and assumes inputs x >= 0.
"""

def __init__(self, layer, state):
# The z-plus rule uses only the positive parts of the weights
# and ignores the biases.
#TODO: assert that layer inputs are always >= 0
self._layer_wo_act_b_positive = kgraph.copy_layer_wo_activation(
layer, keep_bias=False,
name_template="reversed_kernel_positive_%s")
tmp = [x * (x > 0)
for x in self._layer_wo_act_b_positive.get_weights()]
self._layer_wo_act_b_positive.set_weights(tmp)

def apply(self, Xs, Ys, Rs, reverse_state):
grad = ilayers.GradientWRT(len(Xs))

# Get activations.
Zs = kutils.apply(self._layer_wo_act_b_positive, Xs)
# Divide incoming relevance by the activations.
tmp = [ilayers.SafeDivide()([a, b])
for a, b in zip(Rs, Zs)]
# Propagate the relevance to input neurons
# using the gradient.
tmp = iutils.to_list(grad(Xs+Zs+tmp))
# Re-weight relevance with the input values.
return [keras.layers.Multiply()([a, b])
for a, b in zip(Xs, tmp)]
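
For a single dense layer, the four steps in apply amount to the familiar z+ redistribution R_i = x_i * sum_j (w_ij^+ / sum_k x_k w_kj^+) * R_j. The following numpy sketch, illustrative only and not part of the commit, spells out the same computation; the helper name z_plus_dense and the toy numbers are assumptions.

import numpy as np

def z_plus_dense(x, W, R, eps=1e-12):
    W_plus = np.maximum(W, 0)   # keep only the positive weights, no bias
    z = x @ W_plus              # Zs: positive pre-activations
    s = R / (z + eps)           # SafeDivide: incoming relevance / activations
    c = s @ W_plus.T            # gradient step: propagate back to the inputs
    return x * c                # re-weight with the input values

# Toy example: 3 inputs, 2 outputs; relevance is conserved (sums stay equal).
x = np.array([1.0, 2.0, 0.5])
W = np.array([[ 0.5, -1.0],
              [ 1.0,  0.2],
              [-0.3,  0.4]])
R_out = np.array([1.0, 1.0])
R_in = z_plus_dense(x, W, R_out)   # -> [0.2, 1.4667, 0.3333], sums to 2.0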



# alpha-beta all networks
# bias+- for some other rules
LRP_RULES = {
@@ -640,10 +658,13 @@ def f(Xs):
"Alpha1Beta0IgnoreBias": Alpha1Beta0IgnoreBiasRule,

"ZPlus": ZPlusRule,
"ZPlusFast": ZPlusFastRule,
"Bounded": BoundedRule,
}




###############################################################################
###############################################################################
###############################################################################
@@ -838,12 +859,6 @@ def __init__(self, model, *args, **kwargs):
rule="ZIgnoreBias", **kwargs)


class LRPZPlus(_LRPFixedParams):

def __init__(self, model, *args, **kwargs):
super(LRPZPlus, self).__init__(model, *args,
rule="ZPlus", **kwargs)


class LRPEpsilon(_LRPFixedParams):

@@ -995,3 +1010,16 @@ def __init__(self, model, *args, **kwargs):
beta=0,
bias=False,
**kwargs)

class LRPZPlus(LRPAlpha1Beta0IgnoreBias):
#TODO: assert that layer inputs are always >= 0
def __init__(self, model, *args, **kwargs):
super(LRPZPlus, self).__init__(model, *args, **kwargs)


#TODO: decide whether to remove or reinstate
class LRPZPlusFast(_LRPFixedParams):
#TODO: assert that layer inputs are always >= 0
def __init__(self, model, *args, **kwargs):
super(LRPZPlusFast, self).__init__(model, *args,
rule="ZPlusFast", **kwargs)
