# Setting up

In [1]:
# %pip install tensorflow==2.14.0
# %pip install innvestigate==2.1.2

In [2]:
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

# Reconstruct Model

## Define the architecture that was made before

In [3]:
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D, Input
from tensorflow.keras.applications import EfficientNetB1
from tensorflow.keras.models import Model

input_shape = (256, 256, 3)
input_layer = Input(shape=input_shape)

base_model = EfficientNetB1(weights='imagenet', include_top=False, input_tensor=input_layer)
base_model.trainable=False
x = GlobalAveragePooling2D()(base_model.output)
# x = Flatten()(x)
x = Dense(64, activation='relu')(x)
x = Dense(32, activation='relu')(x)

output_layer = Dense(9, activation='softmax')(x)

model = Model(inputs=input_layer, outputs=output_layer)

Instructions for updating:
Colocations handled automatically by placer.


## Load the saved weight

In [4]:
model.load_weights('../src/models/efficient_net_model_weight.h5')

## Remove softmax

In [5]:
model = Model(inputs=model.input, outputs=model.layers[-2].output)
model.output.shape

TensorShape([None, 32])

# Define LRP rules

In [6]:
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 256, 256, 3)]        0         []                            
                                                                                                  
 rescaling (Rescaling)       (None, 256, 256, 3)          0         ['input_1[0][0]']             
                                                                                                  
 normalization (Normalizati  (None, 256, 256, 3)          7         ['rescaling[0][0]']           
 on)                                                                                              
                                                                                                  
 rescaling_1 (Rescaling)     (None, 256, 256, 3)          0         ['normalization[0][0]'] 

In [7]:
import tensorflow as tf

lrp_rules = [
    # Rule shape = (condition, rule)
    (lambda layer: isinstance(layer, tf.keras.layers.Conv2D) and ('block1' in layer.name or 'stem' in layer.name), "Flat"),
    (lambda layer: isinstance(layer, tf.keras.layers.Conv2D) and not ('block1' in layer.name or 'stem' in layer.name), "Alpha2Beta1"),
    (lambda layer: isinstance(layer, tf.keras.layers.Dense), "Epsilon"),
    (lambda layer: True, "Epsilon") # default
]

# Create LRP analyzer

In [16]:
# Create an LRP analyzer with layer-specific rules

# import innvestigate
# analyzer = innvestigate.create_analyzer("lrp", model, rule=lrp_rules)

from innvestigate.analyzer.relevance_based.relevance_analyzer import LRP
analyzer = LRP(model, rule=lrp_rules, reverse_verbose=True)

# Preprocess input image

In [17]:
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.efficientnet import preprocess_input
import numpy as np

# Load an image
img_path = "../data/test/Bacterial Leaf Blight/aug_0_14.jpg"
img = image.load_img(img_path, target_size=(256, 256))
img_array = image.img_to_array(img)
img_array = preprocess_input(img_array)
img_array = np.expand_dims(img_array, axis=0)

In [18]:
img_array.shape

(1, 256, 256, 3)

# Apply LRP

In [None]:
# Get relevance heatmap
relevance = analyzer.analyze(img_array)

# Post-process for visualization
relevance = relevance.squeeze()
relevance = relevance.sum(axis=-1)  # Convert to grayscale heatmap
relevance = np.maximum(relevance, 0)  # Remove negative relevance

Reverse model: <keras.src.engine.functional.Functional object at 0x0000020AE2E5FEB0>
[NID: 342] Reverse layer-node <innvestigate.layers.MaxNeuronSelection object at 0x0000020AE2E5FE80>
[NID: 341] Reverse layer-node <keras.src.layers.core.dense.Dense object at 0x0000020ACC10AAA0>
[NID: 340] Reverse layer-node <keras.src.layers.core.dense.Dense object at 0x0000020ACC10A500>
[NID: 339] Reverse layer-node <keras.src.layers.pooling.global_average_pooling2d.GlobalAveragePooling2D object at 0x0000020A8D93E260>


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Assuming 'relevance' is the result from analyzer.analyze()
# Already squeezed, summed, and ReLU-ed like you did:

# Normalize relevance to [0, 1]
relevance -= relevance.min()
if relevance.max() > 0:
    relevance /= relevance.max()

# Optional: Resize if needed (only if your model input was resized too)
# import cv2
# relevance = cv2.resize(relevance, (256, 256))

# Plot
plt.figure(figsize=(6, 6))
plt.imshow(relevance, cmap='jet')
plt.title("LRP Relevance Heatmap")
plt.axis('off')
plt.colorbar()
plt.show()