# Attribution for multiple images at once

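This notebook computes attribution maps (saliency and gradient × input) for many images at once. For each image, the least confident body part of the first detected person is selected, and its confidence map is thresholded at the 0.995 quantile to build the mask used as the attribution target.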

In [1]:
%load_ext autoreload
%autoreload 2
# imports
import glob
import logging
import os
import sys

# Must be set BEFORE TensorFlow is first imported (directly, or indirectly via
# deepexplain / tf_pose below): the TF C++ runtime reads it when it is loaded,
# so setting it after the imports has no effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Parent directories are put on the path so the project-local packages
# (deepexplain, plot_utils, tf_pose) can be imported; presumably they live there.
sys.path.insert(0, os.path.abspath('..'))
sys.path.insert(0, os.path.abspath('../..'))

import cv2
import numpy as np
import tensorflow as tf
from PIL import Image
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

from deepexplain.tf.v1_x import DeepExplain
# Only `plot` is needed from plot_utils; `plt` is already matplotlib.pyplot above
# (importing it again here would silently shadow that name).
from plot_utils import plot
from tf_pose import common
from tf_pose.common import CocoPart
from tf_pose.estimator import BodyPart, TfPoseEstimator
from tf_pose.networks import get_graph_path


# Silence Python-side log output of TensorFlow and the libraries used here.
logging.getLogger("tensorflow").setLevel(logging.CRITICAL)
logging.getLogger('TfPoseEstimatorRun').setLevel(logging.ERROR)
logging.getLogger('DeepExplain').setLevel(logging.ERROR)
logging.getLogger('TfPoseEstimator').setLevel(logging.ERROR)

Using tf version = 1.15.0


  import pandas.util.testing as tm


The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.




In [2]:
# Estimator and input configuration.
w, h = 432, 368  # target_size (width, height) handed to the estimator
model = 'cmu'
resize_out_ratio = 2.0  # passed to e.inference as upsample_size
e = TfPoseEstimator(get_graph_path(model), target_size=(w, h), trt_bool=False)

# Glob pattern of the images to explain. A dedicated name avoids clashing with
# the `image_path` loop variable used when iterating over the files.
IMAGE_GLOB = '../data/images/highFrequencyData/*.jpg'
# glob() returns files in filesystem order; sort for a reproducible run order.
test_data = sorted(glob.glob(IMAGE_GLOB))
# Fail early with a clear message instead of an IndexError when plotting.
assert test_data, f"no images match {IMAGE_GLOB}"

# If True, each confidence map is divided by its sum before thresholding.
NORMALIZATION_FLAG = True


def return_uncertain_part(humans):
    """Return the least confident body part of the first detected human.

    Args:
        humans: sequence of humans as returned by ``TfPoseEstimator.inference``;
            only ``humans[0].body_parts`` (a dict whose values have a ``score``)
            is inspected.

    Returns:
        The body part of ``humans[0]`` with the lowest ``score`` (the first one
        in iteration order on ties). If there is no human, or the first human
        has no body parts, a zero-score right-shoulder ``BodyPart`` placeholder
        is returned, so callers can always read ``.part_idx`` from the result.
    """
    # The old `cur_min = 1.0` sentinel returned None for an empty body_parts
    # dict or when every score was >= 1.0; handle both explicitly.
    if not humans or not humans[0].body_parts:
        return BodyPart(0, CocoPart.RShoulder.value, 0, 0, 0.0)
    return min(humans[0].body_parts.values(), key=lambda part: part.score)

In [3]:
results = []
# Mask threshold: keep the heatmap pixels above the 0.995 quantile (top 0.5%).
quantile = 0.995
# Size (width, height) of the raw, non-upsampled `e.tensor_heatMat` the mask is
# fed against. 54x46 equals (432/8, 368/8) for the default input; presumably the
# network output stride is 8 - confirm before changing w/h.
heatmap_w, heatmap_h = 54, 46

for image_path in tqdm(test_data):
    current_stats = {}
    # read file
    current_stats["image"] = common.read_imgfile(image_path, w, h)
    # compute humans and draw them onto the image
    humans = e.inference(current_stats["image"], resize_to_default=(
        w > 0 and h > 0), upsample_size=resize_out_ratio)
    current_stats["image_result"] = TfPoseEstimator.draw_humans(
        current_stats["image"], humans, imgcopy=True)
    # attribute the least confident body part of the first human
    current_stats["part"] = return_uncertain_part(humans)
    part_idx = current_stats["part"].part_idx

    # Confidence map of that part. `.copy()` because the slice is a view of
    # `e.heatMat`, and the in-place normalization below would otherwise
    # modify the estimator's own buffer.
    current_stats["heatmap"] = e.heatMat[:, :, part_idx].copy()

    # Normalize the map to sum to 1; skipped for an all-zero map, which would
    # otherwise turn into NaNs.
    if NORMALIZATION_FLAG:
        total_confidence_value = np.sum(current_stats["heatmap"])
        if total_confidence_value != 0:
            current_stats["heatmap"] /= total_confidence_value

    # Binary (0/255) mask of the pixels above the quantile, resized to the raw
    # heatmap resolution. LANCZOS is the filter the removed (Pillow 10)
    # ANTIALIAS constant was an alias for.
    quant = np.quantile(current_stats["heatmap"], quantile)
    mask_image = Image.fromarray(
        np.uint8((current_stats["heatmap"] > quant) * 255))
    current_stats["mask"] = np.array(
        mask_image.resize((heatmap_w, heatmap_h), Image.LANCZOS))

    # get the current session
    sess = e.persistent_sess

    # Since we will explain it, the model has to be wrapped in a DeepExplain context.
    # NOTE(review): slicing e.tensor_heatMat adds ops to e.graph on every
    # iteration (and presumably de.explain adds gradient ops too), so the graph
    # grows with the number of images.
    with DeepExplain(session=sess, graph=e.graph) as de:
        input_tensor = e.tensor_image
        output_tensor = e.tensor_heatMat[:, :, :, part_idx]
        # Batch of one image. np.expand_dims replaces
        # tf.expand_dims(...).eval(), which created a new graph op per image.
        xs = np.expand_dims(current_stats["image"], 0).astype('float64')

        Y_shape = [None, 1, heatmap_h, heatmap_w]  # size of heatmaps
        # mask reshaped to Y_shape: (1, 1, heatmap_h, heatmap_w)
        ys = current_stats["mask"][np.newaxis, np.newaxis]

        current_stats['Saliency maps'] = de.explain(
            'saliency', T=output_tensor, X=input_tensor, xs=xs, ys=ys, Y_shape=Y_shape)
        current_stats['Gradient * Input'] = de.explain(
            'grad*input', T=output_tensor, X=input_tensor, xs=xs, ys=ys, Y_shape=Y_shape)

    results.append(current_stats)

HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

KeyboardInterrupt: 

In [None]:
# Plot attributions
%matplotlib inline

assert results, "`results` is empty - run the attribution cell first"

# 'humans' and 'part' are metadata rather than images, so they get no panel
# (counting them in n_cols used to leave an empty axes per row).
METADATA_KEYS = {'humans', 'part'}
panel_keys = [key for key in results[0] if key not in METADATA_KEYS]
n_cols = len(panel_keys)
n_rows = len(results)
fig_scale = 3
title_font = {'fontsize': 20}
# squeeze=False keeps `axes` 2-D even for a single row or column.
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False, figsize=(
    3*n_cols*fig_scale, 3*n_rows*fig_scale))

for i, result in enumerate(results):
    part_name = result["part"].get_part_name()
    # Image scaled to [0, 1], used as the background of the attribution overlays;
    # a constant image is left at 0 instead of dividing by zero.
    xi = (result['image'] - np.min(result['image'])).astype('float64')
    xi_max = np.max(xi)
    if xi_max > 0:
        xi /= xi_max
    for j, key in enumerate(panel_keys):
        value = result[key]
        axj = axes[i, j]
        if key in ('image', 'image_result'):
            # colour images are treated as BGR (cv2 order) and converted for matplotlib
            axj.imshow(cv2.cvtColor(value, cv2.COLOR_BGR2RGB))
            axj.set_title(f'{key} {part_name}', fontdict=title_font)
            axj.axis('off')
        elif key == 'mask':
            # 2-D single-channel array: COLOR_BGR2RGB would raise on it
            axj.imshow(value, cmap='gray')
            axj.set_title(key, fontdict=title_font)
        elif key == 'heatmap':
            heat_image = axj.imshow(value, cmap=plt.cm.hot, alpha=1.0)
            axj.set_title(key, fontdict=title_font)
            fig.colorbar(heat_image, ax=axj, shrink=0.63)
        else:
            # attribution maps carry a leading batch axis of size 1
            plot(value[0], xi=xi, axis=axj, dilation=.5, percentile=99,
                 alpha=.2).set_title(key, fontdict=title_font)