In [None]:
import tensorflow as tf
import seaborn as sns
import numpy as np
import matplotlib

from IPython.display import Image, display
from typing import Union
from glob import glob

%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (20, 8)

In [None]:
# Frozen model to analyse and the output class to explain.
# class_id 773 = safety pin (classes are 1-indexed).
model_path, class_id = "ResnetV2_50.pb", 773

In [None]:
# Deserialize the frozen GraphDef protobuf stored at `model_path`.
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile(model_path, mode='rb') as pb_file:
    graph_def.ParseFromString(pb_file.read())

In [None]:
# Import the frozen network into its own dedicated tf.Graph instead of
# TensorFlow's global default graph. A fresh graph per run makes this cell
# idempotent. Ops imported into the global default graph elsewhere (or by
# re-running the cell) can then never end up in the graph that later cells
# inspect with `graph.get_operations()` / `get_operation_by_name`.
graph = tf.Graph()
with graph.as_default():
    tf.compat.v1.import_graph_def(graph_def, name='')

In [None]:
def get_tensor(
    tensor_name: str,
    input_name: str = "input:0",
    input_shape: tuple = (1, 224, 224, 3),
    seed: int = 0,
) -> np.ndarray:
    """Evaluate one tensor of the frozen model and return its value.

    The module-level ``graph_def`` is imported into a fresh, private ``tf.Graph``
    on every call. Repeated calls therefore do not keep appending duplicate
    copies of the network to TensorFlow's global default graph.

    Args:
        tensor_name: Full tensor name to evaluate, e.g. ``"some/op:0"``.
        input_name: Name of the network's input tensor that is fed.
        input_shape: Shape of the random input batch fed to ``input_name``.
        seed: Seed for the random input, so evaluated activations are reproducible.

    Returns:
        The evaluated tensor as a NumPy array. Tensors that do not depend on the
        input (e.g. weights) are unaffected by the fed noise.
    """
    isolated_graph = tf.Graph()
    with isolated_graph.as_default():
        tf.compat.v1.import_graph_def(graph_def, name='')

    with tf.compat.v1.Session(graph=isolated_graph) as sess:
        input_tensor = isolated_graph.get_tensor_by_name(input_name)
        output_tensor = isolated_graph.get_tensor_by_name(tensor_name)

        # Seeded Gaussian noise stands in for a real image; it only matters for
        # tensors downstream of the input.
        input_data = np.random.default_rng(seed).standard_normal(input_shape)
        return sess.run(output_tensor, feed_dict={input_tensor: input_data})

In [None]:
def get_images(tensor_name: str, feature_id: int) -> Image:
    """Return the OpenAI Microscope dataset-example image for one channel.

    Args:
        tensor_name: Name of the op whose output channel to look up. A full
            tensor name such as ``"block/op:0"`` is also accepted. Its
            ``":<index>"`` suffix is dropped because the URL always appends
            output ``:0`` itself.
        feature_id: Channel index within that op's output.

    Returns:
        An ``IPython.display.Image`` referencing the remote PNG by URL.
    """
    op_name = tensor_name.split(':', 1)[0]
    # The op name sits inside an already-encoded query string, so '/' is written
    # percent-encoded twice ('%252F'), just like the ':0' suffix ('%253A0').
    encoded_op = op_name.replace('/', '%252F')
    return Image(url=(
        "https://openaipublic.blob.core.windows.net/microscopeprod/2020-07-25/2020-07-25/"
        "resnetv2_50_slim/lucid.dataset_examples/_dataset_examples/"
        f"dataset%3Dimagenet%26op%3D{encoded_op}%253A0/channel_{feature_id}_40.png"
    ))

In [None]:
def to_op(tensor: Union[tf.Tensor, tf.Operation]) -> tf.Operation:
    """Normalize a graph element to an Operation.

    A ``tf.Tensor`` is mapped to the op that produces it; anything else is
    returned unchanged.
    """
    return tensor.op if isinstance(tensor, tf.Tensor) else tensor

In [None]:
def go_backwards(
    layer: Union[tf.Operation, tf.Tensor, str],
    num_layers: int = 1,
    verbose: bool = True,
) -> tf.Operation:
    """Walk ``num_layers`` steps up the graph, always following an op's first input.

    Args:
        layer: Starting point. It may be an ``Operation``, a ``Tensor`` (its
            producing op is used), or an op name looked up in the module-level
            ``graph``.
        num_layers: Number of first-input hops to take; ``0`` returns the start op.
        verbose: If true, print each tensor visited along the path.

    Returns:
        The ``Operation`` reached after ``num_layers`` hops.

    Raises:
        ValueError: If an op along the path has no inputs to follow.
    """
    if isinstance(layer, str):
        layer = graph.get_operation_by_name(layer)
    current_op = to_op(layer)
    for _ in range(num_layers):
        if len(current_op.inputs) == 0:
            raise ValueError(
                f"Cannot go further back: op {current_op.name!r} has no inputs"
            )
        # Only the first input is followed; other inputs (e.g. weights, or the
        # second branch of an add) are ignored.
        previous_tensor = current_op.inputs[0]
        if verbose:
            print(previous_tensor)
        current_op = previous_tensor.op
    return current_op


In [None]:
# Peek at the last ten ops of the imported graph (its output end).
ops_tail = graph.get_operations()[-10:]
ops_tail

In [None]:
# Start from the very last op in the graph and take five first-input hops back.
# Presumably this reaches the final linear/classifier op; its second input is
# read as the weight matrix below.
output = graph.get_operations()[-1]
linear_layer = go_backwards(output, num_layers=5)
linear_layer

In [None]:
# All tensors feeding the linear layer (activations and weights).
[inp for inp in linear_layer.inputs]

In [None]:
# The linear op's second input holds its weights; evaluate it to a NumPy array.
weights_tensor_name = linear_layer.inputs[1].name
weight_matrix = get_tensor(weights_tensor_name)
weight_matrix.shape

In [None]:
# Weights from every input feature into the target class
# (assumes axis 2 indexes input features and axis 3 indexes classes).
relevant_weights = weight_matrix[0, 0, :, class_id]
# Feature indices sorted from largest to smallest weight.
ordering = np.argsort(-relevant_weights)
(np.min(relevant_weights), np.max(relevant_weights))

In [None]:
# Distribution of the weights feeding the target class; a few large positive
# weights stand out as the features worth inspecting.
weights_ax = sns.histplot(relevant_weights, bins=100)
weights_ax.set(title=f"Weights into class {class_id}", xlabel="weight", ylabel="number of features");

In [None]:
n_top_features = 5  # how many of the highest-weighted features to inspect
important_features = ordering[:n_top_features]
important_features

In [None]:
# Weight values of the top features, in the same (descending) order.
np.take(relevant_weights, important_features)

In [None]:
# Index of the single highest-weighted feature for the target class
# (the first entry of the descending ordering).
main_feature = important_features[0]
main_feature

In [None]:
# Four more first-input hops back from the linear layer; presumably this lands on
# the last residual-stream output (TODO confirm against the printed path).
residual_stream = go_backwards(linear_layer, num_layers=4)
residual_stream

In [None]:
# Dataset examples that most activate `main_feature` at this point in the stream.
display(get_images(tensor_name=residual_stream.name, feature_id=main_feature))

In [None]:
# One hop further back than the previous residual-stream cell (4 + 1 = 5 hops from
# the linear layer). Deriving it from `linear_layer` rather than from
# `residual_stream` keeps the cell idempotent: re-running it does not keep
# walking further up the graph.
residual_stream = go_backwards(linear_layer, 5)
residual_stream

In [None]:
# Same feature, visualized one step earlier in the network.
display(get_images(tensor_name=residual_stream.name, feature_id=main_feature))

In [None]:
# Two hops further back than the first residual-stream cell (4 + 2 = 6 hops from
# the linear layer). Deriving it from `linear_layer` keeps the cell idempotent
# instead of stepping back again on every re-run.
residual_stream = go_backwards(linear_layer, 6)
residual_stream

In [None]:
# Same feature, visualized at the earliest point reached so far.
display(get_images(tensor_name=residual_stream.name, feature_id=main_feature))