## VGG16 Filter Visualization

#### author info
* by Aven Le ZHOU (https://www.aven.cc)
* artMachines & NYU Shanghai
* aiarts, spring 2020
* https://github.com/artmachines/aiarts2020

#### credits
* Original code implementation from the Keras team [link](https://keras.io/examples/conv_filter_visualization/)
* Adapted by Aven for the aiarts 2020 course


In [35]:
use_plaidML_backend = True
use_colab = False

if use_plaidML_backend and (not use_colab): 
    import os
    os.environ["KERAS_BACKEND"]="plaidml.keras.backend"

In [27]:
from keras import layers
from keras import backend as K
from keras.preprocessing.image import save_img
from keras.applications import vgg16
from PIL import Image as pil_image

import time
import numpy as np

In [28]:
def normalize(x): return (x / (K.sqrt(K.mean(K.square(x))) + K.epsilon()))
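
The `normalize` helper above divides a tensor by its root mean square (RMS), so the gradient has roughly unit RMS regardless of its original scale. A minimal NumPy sketch of the same idea follows; `normalize_np` and the `eps` default of `1e-7` are illustrative assumptions, not part of the notebook or the Keras backend.

```python
import numpy as np


def normalize_np(x, eps=1e-7):
    """Divide x by its root mean square (eps is an assumed stand-in for K.epsilon())."""
    return x / (np.sqrt(np.mean(np.square(x))) + eps)


# A toy "gradient" whose RMS is sqrt((9 + 16) / 4) = 2.5
grads = np.array([[3.0, -4.0], [0.0, 0.0]])
unit = normalize_np(grads)

# After normalization the RMS is approximately 1
rms = np.sqrt(np.mean(unit ** 2))
print(unit, rms)
```

Keeping the gradient at unit RMS means the `step` parameter alone controls how far each ascent update moves the image.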

In [31]:
def decode_image(x):
    """convert a float array into a valid uint8 image. x: A numpy-array representing the generated image."""
    # normalize tensor: center on 0., ensure std is 0.25
    x -= x.mean()
    x /= (x.std() + K.epsilon())
    x *= 0.25

    # clip to [0, 1]
    x += 0.5
    x = np.clip(x, 0, 1)

    # convert to RGB array
    x *= 255
    if K.image_data_format() == 'channels_first': x = x.transpose((1, 2, 0))
    x = np.clip(x, 0, 255).astype('uint8')
    return x

In [34]:
def encode_image(x, former):
    """convert a valid uint8 image back into a float array. Reverses `decode_image`."""
    if K.image_data_format() == 'channels_first': x = x.transpose((2, 0, 1))
    return (x / 255 - 0.5) * 4 * former.std() + former.mean()
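
To illustrate how `decode_image` and `encode_image` relate, here is a small NumPy-only round trip. It is a sketch under stated assumptions: `channels_last` data, `EPS = 1e-7` in place of `K.epsilon()`, and hypothetical helpers `decode_image_np` and `encode_image_np`. Unlike the notebook's `decode_image`, the sketch does not modify its input, and it passes the mean and standard deviation captured before decoding explicitly.

```python
import numpy as np

EPS = 1e-7  # assumed stand-in for K.epsilon()


def decode_image_np(x):
    """Map a float array to uint8: center on 0.5 with std 0.25, clip, scale to 255."""
    x = (x - x.mean()) / (x.std() + EPS) * 0.25 + 0.5  # builds a new array
    x = np.clip(x, 0, 1) * 255
    return x.astype("uint8")


def encode_image_np(img, former_mean, former_std):
    """Approximately invert decode_image_np using the statistics of the original array."""
    return (img / 255 - 0.5) * 4 * former_std + former_mean


rng = np.random.default_rng(0)
former = (rng.random((8, 8, 3)) - 0.5) * 20 + 128  # gray image with noise
img = decode_image_np(former)
restored = encode_image_np(img, former.mean(), former.std())

# The round trip is lossy only through uint8 quantization (and clipping, if any)
max_error = np.abs(restored - former).max()
print(img.dtype, max_error)
```

The factor `4` in the encoder undoes the `0.25` scaling of the decoder, which is why the two functions are close to inverses when no values were clipped.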

### visualize_layer: visualizes the most relevant filters of one conv layer in a given model.
* model: The model containing layer_name.
* layer_name: The name of the layer to be visualized. Must be a layer of model.
* step: Step size for gradient ascent.
* epochs: Number of gradient ascent iterations per upscaling step.
* upscaling_steps: Number of upscaling steps. With the defaults, the starting image is about 80 x 80 (412 / 1.2^9 ≈ 79.8, truncated to 79).
* upscaling_factor: Factor by which the image is gradually enlarged towards output_dim.
* output_dim: [img_width, img_height] The output image dimensions.
* filter_range: Tuple[lower, upper]
              Determines which filter indices are computed.
              If the second value is `None`,
              the last filter is used as the upper bound.
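
The interaction of `upscaling_steps`, `upscaling_factor`, and `output_dim` is easier to see when the intermediate sizes are listed. The sketch below reuses the same arithmetic as `visualize_layer`; the function name `upscaling_schedule` is hypothetical and exists only for illustration.

```python
def upscaling_schedule(output_dim=(412, 412),
                       upscaling_factor=1.2,
                       upscaling_steps=9):
    """Return the starting size followed by the size after each upscaling step."""
    # Starting size: output_dim shrunk by upscaling_factor ** upscaling_steps
    sizes = [tuple(int(x / (upscaling_factor ** upscaling_steps))
                   for x in output_dim)]
    # After each round of gradient ascent the image is resized to the next size
    for up in reversed(range(upscaling_steps)):
        sizes.append(tuple(int(x / (upscaling_factor ** up))
                           for x in output_dim))
    return sizes


schedule = upscaling_schedule()
for size in schedule:
    print(size)
```

With the default arguments the schedule starts at (79, 79) and ends at (412, 412), for ten sizes in total.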


In [7]:
def visualize_layer(model,
                    layer_name,
                    step=1.,
                    epochs=15,
                    upscaling_steps=9,
                    upscaling_factor=1.2,
                    output_dim=(412, 412),
                    filter_range=(0, None)):


    def _generate_filter_image(input_img,
                               layer_output,
                               filter_index):
        """
            Generates image for one particular filter.
            input_img: The input-image Tensor.
            layer_output: The output-image Tensor.
            filter_index: Index of the filter to process. Assumed to be valid.
            Returns: a tuple of the image (array) and the last loss, or None if no image could be generated.
        """
        s_time = time.time()

        # we build a loss function that maximizes the activation of the nth filter of the layer considered
        if K.image_data_format() == 'channels_first':
            loss = K.mean(layer_output[:, filter_index, :, :])
        else:
            loss = K.mean(layer_output[:, :, :, filter_index])

        # we compute the gradient of the input picture wrt this loss
        grads = K.gradients(loss, input_img)[0]

        # normalization trick: we normalize the gradient
        grads = normalize(grads)

        # this function returns the loss and grads given the input picture
        iterate = K.function([input_img], [loss, grads])

        # we start from a gray image with some random noise
        intermediate_dim = tuple(
            int(x / (upscaling_factor ** upscaling_steps)) for x in output_dim)
        if K.image_data_format() == 'channels_first':
            input_img_data = np.random.random(
                (1, 3, intermediate_dim[0], intermediate_dim[1]))
        else:
            input_img_data = np.random.random(
                (1, intermediate_dim[0], intermediate_dim[1], 3))
        input_img_data = (input_img_data - 0.5) * 20 + 128

        # Slowly upscaling towards the final size prevents high-frequency
        # patterns from dominating the visualized structure, as would happen
        # if we optimized the 412 x 412 image directly.
        # Each scale serves as a better starting point for the next one
        # and therefore helps avoid poor local optima.
        for up in reversed(range(upscaling_steps)):
            # we run gradient ascent for `epochs` steps (15 by default)
            for _ in range(epochs):
                loss_value, grads_value = iterate([input_img_data])
                input_img_data += grads_value * step

                # some filters get stuck to 0, we can skip them
                if loss_value <= K.epsilon():
                    return None

            # Calculate upscaled dimension
            intermediate_dim = tuple(
                int(x / (upscaling_factor ** up)) for x in output_dim)
            # Upscale
            img = decode_image(input_img_data[0])
            img = np.array(pil_image.fromarray(img).resize(intermediate_dim,
                                                           pil_image.BICUBIC))
            input_img_data = np.expand_dims(
                encode_image(img, input_img_data[0]), 0)

        # decode the resulting input image
        img = decode_image(input_img_data[0])
        e_time = time.time()
        print('Costs of filter {:3}: {:5.0f} ( {:4.2f}s )'.format(filter_index,
                                                                  loss_value,
                                                                  e_time - s_time))
        return img, loss_value

    def _draw_filters(filters, n=None):
        """Draw the best filters in a nxn grid.

        # Arguments
            filters: A List of generated images and their corresponding losses
                     for each processed filter.
            n: dimension of the grid.
               If none, the largest possible square will be used
        """
        if n is None:
            n = int(np.floor(np.sqrt(len(filters))))

        # the filters that have the highest loss are assumed to be better-looking.
        # we will only keep the top n*n filters.
        filters.sort(key=lambda x: x[1], reverse=True)
        filters = filters[:n * n]

        # build a black picture with enough space for
        # e.g. our 8 x 8 filters of size 412 x 412, with a 5px margin in between
        MARGIN = 5
        width = n * output_dim[0] + (n - 1) * MARGIN
        height = n * output_dim[1] + (n - 1) * MARGIN
        stitched_filters = np.zeros((width, height, 3), dtype='uint8')

        # fill the picture with our saved filters
        for i in range(n):
            for j in range(n):
                img, _ = filters[i * n + j]
                width_margin = (output_dim[0] + MARGIN) * i
                height_margin = (output_dim[1] + MARGIN) * j
                stitched_filters[
                    width_margin: width_margin + output_dim[0],
                    height_margin: height_margin + output_dim[1],
                    :] = img

        # save the result to disk
        save_img('vgg_{0:}_{1:}x{1:}.png'.format(layer_name, n), stitched_filters)

    # this is the placeholder for the input images
    assert len(model.inputs) == 1
    input_img = model.inputs[0]

    # get the symbolic outputs of each "key" layer (we gave them unique names).
    layer_dict = dict([(layer.name, layer) for layer in model.layers[1:]])

    output_layer = layer_dict[layer_name]
    assert isinstance(output_layer, layers.Conv2D)

    # Compute the range of filters to process
    filter_lower = filter_range[0]
    filter_upper = (filter_range[1]
                    if filter_range[1] is not None
                    else len(output_layer.get_weights()[1]))
    assert(filter_lower >= 0
           and filter_upper <= len(output_layer.get_weights()[1])
           and filter_upper > filter_lower)
    print('Compute filters {:} to {:}'.format(filter_lower, filter_upper))

    # iterate through each filter and generate its corresponding image
    processed_filters = []
    for f in range(filter_lower, filter_upper):
        img_loss = _generate_filter_image(input_img, output_layer.output, f)

        if img_loss is not None:
            processed_filters.append(img_loss)

    print('{} filter processed.'.format(len(processed_filters)))
    # Finally draw and store the best filters to disk
    _draw_filters(processed_filters)
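
The inner loop of `_generate_filter_image` is plain gradient ascent with a normalized gradient. The toy below mimics that loop in NumPy without Keras. It is only a sketch: `weights` is a hypothetical stand-in for a filter, the "activation" is a simple linear function rather than a real network layer, and `eps` is an assumed value.

```python
import numpy as np


def normalize_np(g, eps=1e-7):
    """Scale g to roughly unit root mean square."""
    return g / (np.sqrt(np.mean(np.square(g))) + eps)


rng = np.random.default_rng(0)
weights = rng.normal(size=(4, 4))               # hypothetical toy "filter"
image = (rng.random((4, 4)) - 0.5) * 20 + 128   # gray image with noise


def loss_and_grads(img):
    """Toy activation: mean of weights * img, with its normalized gradient."""
    loss = np.mean(weights * img)
    grads = weights / weights.size              # d(loss) / d(img)
    return loss, normalize_np(grads)


step, epochs = 1.0, 15
losses = []
for _ in range(epochs):
    loss_value, grads_value = loss_and_grads(image)
    losses.append(loss_value)
    image += grads_value * step                 # ascent: move along the gradient

print(losses[0], losses[-1])
```

Because the update adds the gradient instead of subtracting it, the recorded loss grows at every step, which is the behavior the notebook relies on to maximize a filter's activation.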


In [8]:
if __name__ == '__main__':
    # the name of the layer we want to visualize
    # (see model definition at keras/applications/vgg16.py)
    LAYER_NAME = 'block5_conv1'

    # build the VGG16 network with ImageNet weights
    vgg = vgg16.VGG16(weights='imagenet', include_top=False)
    print('Model loaded.')
    vgg.summary()

    # example function call
    visualize_layer(vgg, LAYER_NAME)

INFO:plaidml:Opening device "metal_amd_radeon_pro_560x.0"


Model loaded.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, None, None, 3)     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, None, None, 64)    1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, None, None, 64)    36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, None, None, 64)    0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, None, None, 128)   73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, None, None, 128)   147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, None, None, 128)   0      

Costs of filter 227:   438 ( 52.32s )
Costs of filter 233:   385 ( 51.77s )
Costs of filter 234:   697 ( 51.82s )
Costs of filter 235:   783 ( 52.33s )
Costs of filter 236:   563 ( 50.16s )
Costs of filter 237:   567 ( 53.14s )
Costs of filter 238:   591 ( 52.09s )
Costs of filter 240:   518 ( 51.74s )
Costs of filter 241:   699 ( 50.93s )
Costs of filter 243:   759 ( 51.03s )
Costs of filter 245:  1028 ( 51.58s )
Costs of filter 246:   348 ( 53.21s )
Costs of filter 247:   516 ( 49.59s )
Costs of filter 249:   572 ( 50.47s )
Costs of filter 250:  1017 ( 50.36s )


INFO:plaidml:Analyzing Ops: 79 of 93 operations complete


Costs of filter 251:   480 ( 53.58s )
Costs of filter 255:   653 ( 54.41s )
Costs of filter 256:   448 ( 53.82s )
Costs of filter 257:   576 ( 50.82s )
Costs of filter 258:   607 ( 49.20s )
Costs of filter 262:   496 ( 48.93s )
Costs of filter 263:   557 ( 52.42s )
Costs of filter 264:   598 ( 51.22s )
Costs of filter 265:   827 ( 50.46s )
Costs of filter 267:   393 ( 50.30s )
Costs of filter 268:   484 ( 51.64s )
Costs of filter 269:   627 ( 49.95s )
Costs of filter 270:  1375 ( 50.61s )
Costs of filter 271:   603 ( 50.29s )
Costs of filter 272:   658 ( 50.81s )
Costs of filter 275:   546 ( 50.07s )
Costs of filter 278:   636 ( 49.84s )
Costs of filter 279:   662 ( 49.56s )
Costs of filter 280:   269 ( 50.66s )
Costs of filter 281:   579 ( 50.36s )
Costs of filter 282:   470 ( 50.91s )
Costs of filter 285:   613 ( 50.45s )
Costs of filter 286:  1008 ( 62.05s )
Costs of filter 287:   455 ( 50.78s )
Costs of filter 289:   467 ( 49.36s )
Costs of filter 291:   417 ( 50.06s )
Costs of fil

INFO:plaidml:Analyzing Ops: 79 of 93 operations complete
INFO:plaidml:Analyzing Ops: 79 of 93 operations complete


Costs of filter 442:   715 ( 60.47s )
Costs of filter 445:   423 ( 56.33s )
Costs of filter 446:   800 ( 56.98s )
Costs of filter 448:   469 ( 55.24s )
Costs of filter 449:   542 ( 56.16s )
Costs of filter 452:   703 ( 58.34s )
Costs of filter 453:   622 ( 58.60s )
Costs of filter 457:   621 ( 57.18s )
Costs of filter 458:   709 ( 55.61s )
Costs of filter 460:   726 ( 55.28s )
Costs of filter 461:   441 ( 53.11s )
Costs of filter 462:   467 ( 53.96s )


INFO:plaidml:Analyzing Ops: 79 of 93 operations complete


Costs of filter 465:   494 ( 58.49s )
Costs of filter 467:   455 ( 57.15s )
Costs of filter 470:   426 ( 55.58s )
Costs of filter 473:   726 ( 56.95s )
Costs of filter 474:   437 ( 58.14s )
Costs of filter 475:   445 ( 60.83s )
Costs of filter 476:   520 ( 58.98s )


INFO:plaidml:Analyzing Ops: 86 of 93 operations complete


Costs of filter 478:   442 ( 51.10s )
Costs of filter 481:   635 ( 50.34s )
Costs of filter 482:   316 ( 48.91s )
Costs of filter 483:   517 ( 51.79s )
Costs of filter 484:   327 ( 53.29s )
Costs of filter 485:   977 ( 50.52s )
Costs of filter 487:   595 ( 50.52s )
Costs of filter 489:   626 ( 51.94s )
Costs of filter 490:   634 ( 50.24s )
Costs of filter 493:   576 ( 52.19s )
Costs of filter 494:   801 ( 50.55s )
Costs of filter 495:   624 ( 53.31s )
Costs of filter 496:   472 ( 50.43s )
Costs of filter 499:   656 ( 51.67s )
Costs of filter 500:   779 ( 53.74s )
Costs of filter 501:   626 ( 54.83s )
Costs of filter 502:   493 ( 53.40s )
Costs of filter 503:   458 ( 50.64s )
Costs of filter 504:   994 ( 48.24s )
Costs of filter 505:   873 ( 50.72s )
Costs of filter 506:   453 ( 49.25s )
Costs of filter 509:   460 ( 52.90s )
Costs of filter 510:   551 ( 50.58s )
Costs of filter 511:   901 ( 49.49s )
308 filter processed.
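
With 308 filters processed, `_draw_filters` keeps the largest square grid that fits. The sketch below repeats that arithmetic outside the notebook; `grid_geometry` is a hypothetical helper, and the defaults mirror `output_dim=(412, 412)` and `MARGIN = 5` from the code above.

```python
import math


def grid_geometry(num_filters, output_dim=(412, 412), margin=5):
    """Return grid side n and the stitched image dimensions in pixels."""
    n = int(math.floor(math.sqrt(num_filters)))
    width = n * output_dim[0] + (n - 1) * margin
    height = n * output_dim[1] + (n - 1) * margin
    return n, width, height


n, width, height = grid_geometry(308)
# Same filename pattern as the save_img call in _draw_filters
filename = 'vgg_{0:}_{1:}x{1:}.png'.format('block5_conv1', n)
print(n, width, height, filename)
```

For 308 filters this gives a 17 x 17 grid (289 images kept, the 19 lowest-loss ones dropped) and a stitched canvas of 7084 x 7084 pixels.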
