In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import cv2

from preprocessing import preprocess_for_eval

from IPython.display import Image, display
from typing import Union
from glob import glob

%matplotlib inline

In [None]:
# Path of the serialized GraphDef (.pb) for the ResNet-v2-50 model parsed below.
model_path = "ResnetV2_50.pb"

In [None]:
# Load every JPEG under data/ as an RGB uint8 array, then preprocess it into one batch.
# cv2.imread's second argument is an *imread flag*; the old call passed the colour-conversion
# code COLOR_BGR2RGB (numerically IMREAD_ANYCOLOR), so images silently stayed BGR.
image_paths = sorted(glob("data/*.jpeg"))  # sorted: glob order is filesystem-dependent
raw_images = []
for path in image_paths:
    bgr_image = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr_image is None:  # imread signals failure by returning None, not by raising
        raise FileNotFoundError(f"cv2 could not read {path}")
    raw_images.append(cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB))
assert raw_images, "no images matched data/*.jpeg"

# The session only has to exist so that .eval() has a default session to run in.
with tf.compat.v1.Session() as sess:
    # NOTE(review): the / 255 rescale assumes preprocess_for_eval (local `preprocessing`
    # module) returns pixel values in [0, 255] — confirm against that module.
    images = np.stack([preprocess_for_eval(image, 224, 224, 256).eval() / 255 for image in raw_images])
images.shape

In [None]:
# Load the graph from the .pb file
# Parse the serialized network definition stored at model_path into a GraphDef proto.
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile(model_path, 'rb') as model_file:
    serialized_graph = model_file.read()
graph_def.ParseFromString(serialized_graph)

In [None]:
# Import the parsed network into its own tf.Graph. Importing into the process-wide default
# graph (the previous approach) appended another full copy of the network every time this
# cell was re-run. All later cells look ops up by name in `graph`.
graph = tf.Graph()
with graph.as_default():
    tf.compat.v1.import_graph_def(graph_def, name='')

In [None]:
def get_tensor(tensor_name: str, input_data: Union[np.ndarray, None] = None) -> np.ndarray:
    """Evaluate one tensor of the imported network on a batch of images.

    Args:
        tensor_name: Full tensor name in `graph`, e.g. "resnet_v2_50/logits/BiasAdd:0".
        input_data: Batch fed to the "input:0" tensor. Defaults to the notebook-level
            `images` array.

    Returns:
        The value of `tensor_name` computed by `sess.run`.
    """
    if input_data is None:
        input_data = images
    # Run against the graph imported once above. Calling import_graph_def here (as before)
    # added another copy of the whole network to the default graph on every call.
    with tf.compat.v1.Session(graph=graph) as sess:
        input_tensor = graph.get_tensor_by_name("input:0")
        output_tensor = graph.get_tensor_by_name(tensor_name)
        return sess.run(output_tensor, feed_dict={input_tensor: input_data})

In [None]:
def get_images(tensor_name: str, feature_id: int) -> Image:
    """Build an IPython Image pointing at OpenAI Microscope's dataset-example grid.

    Args:
        tensor_name: Op name of the layer; '/' is percent-encoded as '%252F' for the URL.
        feature_id: Channel index within that layer.

    Returns:
        An IPython.display.Image referencing the remote PNG by URL.
    """
    base_url = "https://openaipublic.blob.core.windows.net/microscopeprod/2020-07-25/2020-07-25/resnetv2_50_slim/lucid.dataset_examples/_dataset_examples"
    encoded_op = tensor_name.replace('/', '%252F')
    url = f"{base_url}/dataset%3Dimagenet%26op%3D{encoded_op}%253A0/channel_{feature_id}_40.png"
    return Image(url=url)

In [None]:
def to_op(tensor: Union[tf.Tensor, tf.Operation]) -> tf.Operation:
    """Normalize a graph element to an operation.

    A tf.Tensor is mapped to the op that produces it (`.op`); anything else is returned as-is.
    """
    return tensor.op if isinstance(tensor, tf.Tensor) else tensor

In [None]:
def go_backwards(layer: Union[tf.Operation, tf.Tensor, str], num_layers: int = 1, verbose: bool = True) -> tf.Operation:
    """Walk upstream from `layer`, following the first input of each op.

    Args:
        layer: Starting point: an op, a tensor (its producing op is used), or an op name
            looked up in the notebook-level `graph`.
        num_layers: Number of first-input hops to take; 0 returns the starting op.
        verbose: When True (default), print every tensor visited along the way.

    Returns:
        The op reached after `num_layers` hops (always an Operation, via to_op).

    Raises:
        IndexError: If an op along the path has no inputs to step back through.
    """
    if isinstance(layer, str):
        layer = graph.get_operation_by_name(layer)
    current = layer
    for step in range(num_layers):
        op = to_op(current)
        if len(op.inputs) == 0:
            raise IndexError(f"{op.name} has no inputs; cannot take step {step + 1} of {num_layers}")
        current = op.inputs[0]
        if verbose:
            print(current)
    return to_op(current)


In [None]:
# Logit index of the class under study.
# NOTE(review): presumably a label map where index 0 is "background", so goldfish is 2 — confirm.
class_id = 2 # goldfish

In [None]:
# Show the last ten ops of the imported graph to locate the output (logits) layer.
graph.get_operations()[-10:]

In [None]:
# Op producing the class logits; list its input tensors to start walking back through the net.
logit_layer = "resnet_v2_50/logits/BiasAdd"
list(graph.get_operation_by_name(logit_layer).inputs) 

In [None]:
# Logits for every loaded image. The next cell indexes this as [:, 0, 0, :], so it is
# presumably (num_images, 1, 1, num_classes).
predictions = get_tensor(logit_layer + ":0")
predictions.shape

In [None]:
# Average the logits over the image batch, then rank classes by that average (highest first)
# and plot how the averaged logits are distributed.
average_predictions = predictions[:, 0, 0, :].mean(axis=0)
top_5 = (-average_predictions).argsort()[:5]
plt.hist(average_predictions)
top_5

In [None]:
# Averaged logit values of the five top-ranked classes.
average_predictions[top_5]

In [None]:
# Softmax probabilities of the top-5 classes. Subtracting the max logit first keeps np.exp
# from overflowing and cancels out in the ratio, so the result is unchanged.
# Note: this is softmax(mean logits), not the mean of per-image softmaxes.
shifted_logits = average_predictions - average_predictions.max()
(np.exp(shifted_logits) / np.exp(shifted_logits).sum())[top_5]

In [None]:
# Split the logits op's two inputs. The later cells treat the first as the output of the
# weight (conv) op and the second as the bias tensor.
weights_output, bias_layer = graph.get_operation_by_name(logit_layer).inputs

In [None]:
# NOTE(review): despite its name, `weights` holds the values of `bias_layer`, the bias input
# of the logits op. Plot their distribution and show the target class's bias.
weights = get_tensor(bias_layer.name)
plt.hist(weights, bins=100)
weights[class_id]

In [None]:
# Kernel tensor of the op producing `weights_output` (its second input). weights_output
# already lives in `graph`, so its op can be used directly without a name lookup.
weights_op = to_op(weights_output)
weights_layer = weights_op.inputs[1]
weights_layer

In [None]:
# Evaluate the logits kernel. It is indexed as [0, 0, channel, class] below, so presumably
# (1, 1, num_channels, num_classes).
weight_matrix = get_tensor(weights_layer.name)
weight_matrix.shape

In [None]:
# Weights from every pooled channel into class_id, the channels ranked from most positive
# to most negative weight, and the weight range.
relevant_weights = weight_matrix[0, 0, :, class_id]
ordering = np.argsort(-relevant_weights)
(relevant_weights.min(), relevant_weights.max())

In [None]:
# Distribution of the class_id weights; the trailing semicolon hides the hist() return value.
plt.hist(relevant_weights, bins=100);

In [None]:
# The five channels with the largest weight into class_id.
important_features = ordering[:5]
important_features

In [None]:
# Weight values of those five channels.
relevant_weights[important_features]

In [None]:
# Highest-weighted channel for class_id; the batch-norm analysis below focuses on it.
main_feature = important_features[0]
main_feature

In [None]:
# Activations entering the logits kernel op (its first input). The op is taken directly
# from weights_output instead of being looked up by name.
pooling_layer = to_op(weights_output).inputs[0]
pooling_layer

In [None]:
# Evaluate the pooled activations on the loaded images. They are indexed as [image, 0, 0,
# channel] below, so presumably (num_images, 1, 1, num_channels).
pooled_outputs = get_tensor(pooling_layer.name)
pooled_outputs.shape

In [None]:
# Batch-mean of each pooled channel, the channels ranked from most to least active, and a
# histogram of those means. Note: this rebinds `ordering`.
activations = pooled_outputs.mean(axis=0)[0, 0]
ordering = (-activations).argsort()
plt.hist(activations)
ordering[:5]

In [None]:
# Step four ops upstream of the logits along first inputs; go_backwards prints each tensor
# visited. The result is treated below as a batch-norm op (its inputs are unpacked as
# activations, gamma, beta, mean, variance).
# NOTE(review): `previous_layer` is rebound to an ndarray a few cells later.
previous_layer = logit_layer
batchnorm_layer = go_backwards(previous_layer, 4)

In [None]:
# Input tensors of the batch-norm op: the incoming activations first, then its parameters.
batchnorm_inputs = list(batchnorm_layer.inputs)
batchnorm_inputs

In [None]:
# Evaluate every batch-norm input. `params` receives the parameter arrays; note that this
# rebinds `previous_layer` (an op name above) to the evaluated incoming activations.
previous_layer, *params = [get_tensor(layer.name) for layer in batchnorm_inputs]

In [None]:
# One histogram per batch-norm parameter, titled with the tensor it came from
# (batchnorm_inputs[0] is the activations, so parameters start at index 1).
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
axs = axs.flatten()
for idx, param in enumerate(params):
    ax = axs[idx]
    param_tensor = batchnorm_inputs[idx + 1]
    ax.hist(param.flatten(), bins=100)
    ax.set_title(param_tensor.name)

In [None]:
# Batch-norm parameters of main_feature. The unpack order assumes the FusedBatchNorm input
# order (scale/gamma, offset/beta, mean, variance) after the activations.
gamma, beta, mean, variance = np.stack(params, axis=0)[:,main_feature]
gamma, beta, mean, variance

In [None]:
# Human-readable batch-norm transform for main_feature. Batch norm computes
# gamma * (x - mean) / sqrt(variance + epsilon) + beta, so the divisor is the standard
# deviation sqrt(variance + epsilon), not the raw variance shown previously.
bn_epsilon = batchnorm_layer.get_attr("epsilon")
main_feature_std = np.sqrt(variance + bn_epsilon)
f"f(x) = {gamma:.2f} * (x {-mean:+.2f})/{main_feature_std:.2f} {beta:+.2f}"

In [None]:
# Batch-norm output of every channel for a constant input x = 1, using the full transform
# gamma * (x - mean) / sqrt(variance + epsilon) + beta. The previous version dropped gamma
# and divided by the raw variance.
bn_epsilon = batchnorm_layer.get_attr("epsilon")
gammas, betas, means, variances = np.stack(params, axis=0)
bn_std = np.sqrt(variances + bn_epsilon)
expected_features = gammas * (np.ones_like(gammas) - means) / bn_std + betas
plt.hist(expected_features, bins=100);

In [None]:
# Input value at which each channel's batch-norm output is exactly zero, i.e. the solution
# of gamma * (x - mean) / sqrt(variance + epsilon) + beta = 0
#   =>  x = mean - beta * sqrt(variance + epsilon) / gamma.
# The previous version had a dead first assignment and omitted gamma and the sqrt.
bn_epsilon = batchnorm_layer.get_attr("epsilon")
gammas, betas, means, variances = np.stack(params, axis=0)
with np.errstate(divide="ignore", invalid="ignore"):  # gamma == 0: no zero crossing
    expected_inputs = means - betas * np.sqrt(variances + bn_epsilon) / gammas
# plt.hist cannot auto-range over inf/nan, so plot only the finite thresholds.
plt.hist(expected_inputs[np.isfinite(expected_inputs)], bins=100);

In [None]:
# Show Microscope dataset examples for the top-5 class_id channels (important_features),
# looked up on the op that feeds the batch norm.
feature_layer = batchnorm_inputs[0].op.name
for rank in range(5):
    feature_id = important_features[rank]
    display(get_images(feature_layer, feature_id))

In [None]:
# In top-to-bottom execution `ordering` was last set from the batch-mean pooled activations,
# so ordering[-1] is the least active pooled channel.
display(get_images(feature_layer, ordering[-1]))

In [None]:
# Activations entering the batch norm (its first input); the next cells inspect their producer.
conv_out = batchnorm_inputs[0]
conv_out

In [None]:
# Input tensors of the op that produces conv_out.
conv_inputs = list(to_op(conv_out).inputs)
conv_inputs

In [None]:
# Name the two inputs of conv_out's producer, then step one op upstream of current_layer
# (go_backwards prints the tensor it steps to).
# NOTE(review): the op reached is unpacked as (relu, conv_weights) and its second input is
# indexed as a 4-D kernel later, so it is presumably the conv op itself, not a BiasAdd as
# `bias_add_inputs` suggests.
previous_unit, current_layer = conv_inputs
bias_add_inputs = list(go_backwards(current_layer).inputs)
bias_add_inputs

In [None]:
# Bias vector of the op producing `current_layer`. That op's first input is the conv output,
# so its second input is presumably the bias (TODO confirm the op type is BiasAdd).
# Previously this evaluated `current_layer` itself, i.e. the activations, and indexed the
# flattened activation array with main_feature.
bias_values = get_tensor(to_op(current_layer).inputs[1].name).flatten()
plt.hist(bias_values, bins=100)
bias_values[main_feature]

In [None]:
# Inputs of the op one step upstream of current_layer: the activations it consumes (`relu`)
# and its kernel (`conv_weights`).
relu, conv_weights = bias_add_inputs
conv_weights

In [None]:
# Kernel weights from every input channel into main_feature (presumably a 1x1 kernel shaped
# (1, 1, in_channels, out_channels)), ranked from most positive. Rebinds relevant_weights and
# ordering.
conv_kernel = get_tensor(conv_weights.name)
relevant_weights = conv_kernel[0, 0, :, main_feature]
plt.hist(relevant_weights, bins=100)
ordering = np.argsort(-relevant_weights)

In [None]:
# The five input channels with the largest weight into main_feature (rebinds important_features).
important_features = ordering[:5]
important_features

In [None]:
# Evaluate the activations entering that conv on the loaded images (rebinds activations).
activations = get_tensor(relu.name)
activations.shape

In [None]:
# Mean activation of each channel over all axes except the last, the channels ranked from
# most active, and a histogram of those means. Rebinds ordering.
average_activations = activations.mean(axis=(0, 1, 2))
ordering = np.argsort(-average_activations)
plt.hist(average_activations, bins=100)
ordering[:5]

In [None]:
# Mean activation of the five most active channels.
average_activations[ordering[:5]]

In [None]:
# Microscope examples for the five channels with the largest weight into main_feature.
relu_op_name = to_op(relu).name
for rank in range(5):
    display(get_images(relu_op_name, important_features[rank]))

In [None]:
# Microscope examples for the five most active channels feeding the conv.
relu_op_name = to_op(relu).name
for rank in range(5):
    display(get_images(relu_op_name, ordering[rank]))

In [None]:
# Third-ranked (index 2) highest-weighted input channel, picked for further study.
# NOTE(review): the reason for choosing index 2 is not stated in the visible code.
key_feature = important_features[2]
key_feature

In [None]:
# Op one step upstream of conv_out along its first input (go_backwards prints the tensor).
conv_in = go_backwards(conv_out)
conv_in

In [None]:
# NOTE(review): the same unpacking as `previous_unit, current_layer = conv_inputs` earlier,
# under different names — consider reusing one pair of names.
skip, add = conv_inputs