In [None]:
import keras
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load VGG16 with ImageNet weights; include_top=False leaves out the dense
# classification head so only the convolutional blocks are built.
model = keras.applications.vgg16.VGG16(include_top=False, weights='imagenet')
model.summary()

In [None]:
# Name of the layer to visualise and a default filter (channel) index.
layer_name, filter_index = 'block5_conv3', 0
# Symbolic output of the chosen layer, used below to build a feature extractor.
layer_output = model.get_layer(layer_name).output

In [None]:
from keras import models
# Model mapping the network's inputs to the activations of the selected layer.
feature_extractor = models.Model(outputs=layer_output, inputs=model.inputs)
def compute_loss(image, filter_index):
  """Return the mean activation of one filter for the given image batch.

  Args:
    image: Batch of images fed to `feature_extractor`.
    filter_index: Index along the last axis of the activation to average.

  Returns:
    Scalar tensor: mean of the selected filter's activation.
  """
  return tf.reduce_mean(feature_extractor(image)[:, :, :, filter_index])

In [None]:
@tf.function
def gradient_ascent_step(image, filter_index, learning_rate):
  """Take one gradient-ascent step on `image` to increase the filter loss.

  Args:
    image: Image tensor being optimised.
    filter_index: Filter whose mean activation is maximised.
    learning_rate: Multiplier applied to the normalised gradient.

  Returns:
    The updated image tensor.
  """
  with tf.GradientTape() as tape:
    # Watch explicitly so gradients with respect to `image` are recorded.
    tape.watch(image)
    loss = compute_loss(image, filter_index)
  # L2-normalise the gradient so the step size is governed by learning_rate.
  step_direction = tf.math.l2_normalize(tape.gradient(loss, image))
  return image + learning_rate * step_direction

In [None]:
# Side lengths (in pixels) of the randomly initialised input image.
img_width, img_height = 200, 200

def generate_filter_pattern(filter_index, iterations=30, learning_rate=10.):
  """Generate an input image that maximises the response of one filter.

  Starts from uniform noise in [0.4, 0.6) and repeatedly applies
  `gradient_ascent_step`.

  Args:
    filter_index: Index of the filter to visualise.
    iterations: Number of gradient-ascent steps to run.
    learning_rate: Step size passed to `gradient_ascent_step`.

  Returns:
    NumPy array of shape (img_width, img_height, 3) with the optimised image.
  """
  # Pass the index as a tensor: a Python int argument is treated as a constant
  # by @tf.function, which would force a fresh trace for every filter.
  filter_index = tf.convert_to_tensor(filter_index, dtype=tf.int32)
  image = tf.random.uniform(
      minval=0.4,
      maxval=0.6,
      shape=(1, img_width, img_height, 3))
  for _ in range(iterations):
    image = gradient_ascent_step(image, filter_index, learning_rate)
  return image[0].numpy()

In [None]:
def deprocess_image(image, border=25):
  """Convert a raw optimised image into a displayable uint8 RGB array.

  The values are centred on 128 with a standard deviation of 64, clipped to
  [0, 255], cast to uint8, and `border` pixels are cropped from each side of
  the first two axes.

  Args:
    image: Array-like of shape (rows, cols, channels). It is not modified.
    border: Number of pixels removed from every edge; 0 disables cropping.

  Returns:
    uint8 NumPy array of shape (rows - 2*border, cols - 2*border, channels).
  """
  image = np.asarray(image)
  # astype copies, so the caller's array is left untouched; promoting to a
  # float dtype also lets integer inputs go through the in-place arithmetic.
  image = image.astype(np.result_type(image.dtype, np.float32))
  image -= image.mean()
  std = image.std()
  # A constant image has zero spread; skip the division instead of producing NaNs.
  if std > 0:
    image /= std
  image *= 64
  image += 128
  image = np.clip(image, 0, 255).astype('uint8')
  if border > 0:
    image = image[border:-border, border:-border, :]
  return image

In [None]:
# Generate and post-process the pattern for each of the first 64 filters.
# The loop variable is deliberately not called `filter_index`, so it does not
# overwrite the module-level `filter_index` set in an earlier cell.
all_images = []
for filter_no in range(64):
  print(f"Processing filter {filter_no}")
  all_images.append(deprocess_image(generate_filter_pattern(filter_no)))

# Lay the generated patterns out on an n x n grid separated by `margin` pixels.
margin = 2
n = 8
# Each tile has 25 pixels cropped from every edge by deprocess_image.
cropped_width = img_width - 25 * 2
cropped_height = img_height - 25 * 2
width = n * cropped_width + (n - 1) * margin
height = n * cropped_height + (n - 1) * margin
# uint8 keeps the 0-255 pixel values in the integer range plt.imshow expects
# for RGB data; the default float64 canvas would be clipped to [0, 1].
stitched_filters = np.zeros((width, height, 3), dtype='uint8')

for i in range(n):
  for j in range(n):
    image = all_images[i * n + j]
    row_start = (cropped_width + margin) * i
    col_start = (cropped_height + margin) * j
    stitched_filters[row_start:row_start + cropped_width,
                     col_start:col_start + cropped_height, :] = image


In [None]:
plt.figure(figsize=(20, 20))
plt.title(layer_name, fontsize=50)
# imshow interprets float RGB data as lying in [0, 1] and clips anything above;
# the grid holds 0-255 pixel values, so show it as uint8.
plt.imshow(np.clip(stitched_filters, 0, 255).astype('uint8'))
plt.show()

In [None]:
# Write the stitched grid to a PNG named after the visualised layer.
output_path = f'filters_for_layer_{layer_name}.png'
keras.utils.save_img(output_path, stitched_filters)
