# Visualizing CNN output

## Overhead

In [1]:
# DL framework
import tensorflow as tf

from datetime import datetime

# common packages
import numpy as np
import os # handling file i/o
import sys
import math
import time # timing epochs

# for ordered dict when building layer components
import collections

# plotting pretty figures
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import pyplot
from matplotlib import colors # making colors consistent
from mpl_toolkits.axes_grid1 import make_axes_locatable # colorbar helper

# read image
from scipy.misc import imread
# + data augmentation
from scipy import ndimage
from scipy import misc

# used for manually saving best params
import pickle

# for shuffling data batches
from sklearn.utils import shuffle

# const
SEED = 42

# Helper to make the output consistent
def reset_graph(seed=SEED):
    """Start from a fresh default TF graph and reseed TF and NumPy.

    Calling this before building a model makes runs reproducible.

    Args:
        seed: value handed to both ``tf.set_random_seed`` and
            ``np.random.seed`` (in that order).
    """
    tf.reset_default_graph()
    for apply_seed in (tf.set_random_seed, np.random.seed):
        apply_seed(seed)

# helper to create dirs if they don't already exist
def maybe_create_dir(dir_path):
    """Create ``dir_path`` (including any missing parents) unless it exists.

    Prints a one-line status message saying whether the directory was
    created or was already there.

    Args:
        dir_path: path of the directory to create.
    """
    try:
        os.makedirs(dir_path)
    except FileExistsError:
        # Attempting the creation and handling the failure avoids the race
        # between a separate exists() check and makedirs().
        print("{} already exists".format(dir_path))
    else:
        print("{} created".format(dir_path))
    
# Hide TensorFlow's C++ INFO and WARNING log messages but keep ERRORs
# ('0' = show all, '1' = hide INFO, '2' = also hide WARNING,
#  '3' = also hide ERROR).
# NOTE(review): messages logged while `import tensorflow` ran above may
# already have been printed; set this before that import to hide them too.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Important version information
print("Python: {}".format(sys.version_info[:]))
print('TensorFlow: {}'.format(tf.__version__))

# Check whether a GPU is available; query once and reuse the result.
gpu_name = tf.test.gpu_device_name()
if gpu_name:
    print('Default GPU Device: {}'.format(gpu_name))
else:
    print('No GPU found')

reset_graph()

Python: (3, 5, 4, 'final', 0)
TensorFlow: 1.4.0
Default GPU Device: /device:GPU:0


In [2]:
# Create the working directories used by the rest of the notebook:
#   saver/       - TensorFlow saver files
#   best_params/ - a pickled copy of the best params, kept as a backup in
#                  case something goes wrong with the saver files
#   tf_logs/     - logs to be viewed in TensorBoard
for dir_name in ("saver", "best_params", "tf_logs"):
    maybe_create_dir(dir_name)

saver createed
best_params createed
tf_logs createed


In [3]:
# These two functions (get_model_params and restore_model_params) are
# adapted from:
# https://github.com/ageron/handson-ml/blob/master/11_deep_learning.ipynb
def get_model_params(sess=None):
    """Return a snapshot of the current value of every global variable.

    Args:
        sess: session to evaluate the variables in. Defaults to
            ``tf.get_default_session()``, which is the original behavior.

    Returns:
        dict mapping each global variable's op name to the value returned
        by ``sess.run`` for it.

    Raises:
        ValueError: if ``sess`` is not given and there is no default session.
    """
    if sess is None:
        sess = tf.get_default_session()
        if sess is None:
            raise ValueError("get_model_params needs a session: pass `sess` "
                             "or call it inside a `with tf.Session()` / "
                             "`with sess.as_default()` block")
    # Read the collection from the session's own graph so that an explicitly
    # passed session for a non-default graph still works.
    gvars = sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    values = sess.run(gvars)
    return {gvar.op.name: value for gvar, value in zip(gvars, values)}

def restore_model_params(model_params, g, sess):
    """Write previously captured variable values back into graph ``g``.

    For every name in ``model_params``, this looks up the ``<name>/Assign``
    operation in ``g`` and feeds the stored value into input 1 of that op.
    All of the assign ops are then run in a single ``sess.run`` call.

    Args:
        model_params: dict of variable op name -> value, in the form
            returned by ``get_model_params``.
        g: graph that contains the variables.
        sess: session used to run the assign ops.
    """
    ops_by_name = {}
    feeds = {}
    for var_name, saved_value in model_params.items():
        assign_op = g.get_operation_by_name(var_name + "/Assign")
        ops_by_name[var_name] = assign_op
        feeds[assign_op.inputs[1]] = saved_value
    sess.run(ops_by_name, feed_dict=feeds)

# these two functions are used to manually save the best
# model params to disk
def save_obj(obj, name, dir_path='best_params'):
    """Pickle ``obj`` to ``<dir_path>/<name>.pkl`` using the highest protocol.

    Args:
        obj: any picklable object, e.g. the dict from ``get_model_params``.
        name: file name without the ``.pkl`` extension.
        dir_path: directory to write into. It must already exist. Defaults
            to ``best_params``.
    """
    with open(os.path.join(dir_path, name + '.pkl'), 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(name, dir_path='best_params'):
    """Unpickle and return the object stored in ``<dir_path>/<name>.pkl``.

    Security: ``pickle.load`` can execute arbitrary code. Only load files
    that were written by ``save_obj`` (or another trusted source).

    Args:
        name: file name without the ``.pkl`` extension.
        dir_path: directory to read from. Defaults to ``best_params``.

    Returns:
        The unpickled object.
    """
    with open(os.path.join(dir_path, name + '.pkl'), 'rb') as f:
        return pickle.load(f)

In [4]:
# helper to add an image to the plot and choose whether
# to include the color bar
def implot(mp, ax, SHOW_CB=False):
    """Draw ``mp`` on ``ax`` with a discrete viridis scale spanning [-0.01, 1].

    A colorbar axis is always appended to the right of ``ax``, so panels
    with and without a colorbar keep the same size. The colorbar is drawn
    only when ``SHOW_CB`` is true and ``mp`` is not constant; otherwise that
    extra axis is hidden.

    Args:
        mp: array accepted by ``ax.imshow``.
        ax: matplotlib Axes to draw into.
        SHOW_CB: whether to draw the colorbar.
    """
    viridis = plt.get_cmap('viridis')
    edges = np.linspace(-0.01, 1, 80)
    # Giving imshow a BoundaryNorm restricts it to the discrete colors
    # defined by `edges`.
    image_artist = ax.imshow(mp, interpolation='nearest', origin='lower',
                             cmap=viridis,
                             norm=colors.BoundaryNorm(edges, viridis.N))

    cb_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    # Short-circuit: the min/max scan only happens when a colorbar is wanted.
    if SHOW_CB and np.min(mp) != np.max(mp):
        plt.colorbar(image_artist, cax=cb_ax, format='%1.2f', boundaries=edges)
    else:
        cb_ax.set_axis_off()

    ax.set_axis_off()
    

def feat_plot(mp, ax):
    """Draw a feature map on ``ax`` with a discrete viridis scale over [-3, 3].

    No colorbar is drawn and the axis decorations are hidden.

    Args:
        mp: array accepted by ``ax.imshow``.
        ax: matplotlib Axes to draw into.
    """
    palette = plt.get_cmap('viridis')
    discrete_norm = colors.BoundaryNorm(np.linspace(-3, 3, 80), palette.N)
    ax.imshow(mp, interpolation='nearest', origin='lower',
              cmap=palette, norm=discrete_norm)
    ax.set_axis_off()

def show_masked_result(image, mask):
    """Show an image, its mask, and the image with masked-out pixels zeroed.

    Draws three side-by-side panels and displays the figure:
    ``image``, ``mask`` (with a colorbar), and a copy of ``image`` in which
    every pixel where ``mask == 0`` is set to 0.

    Args:
        image: image array. Its leading dimensions are indexed by ``mask``.
        mask: array matching the leading (height/width) dimensions of
            ``image``, as required for the boolean indexing below.
    """
    # Work on a copy of the argument so the caller's image is not modified.
    combined = np.copy(image)
    # A scalar broadcasts across any number of channels, including none
    # (a 2-D image), unlike a fixed 3-element RGB list.
    combined[mask == 0] = 0

    _, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, sharey=True,
                                      figsize=(12, 4))

    implot(image, ax1)
    implot(mask, ax2, True)
    implot(combined, ax3)

    # `grid` expects a bool; a non-empty string such as 'off' is not a
    # documented value and may be treated as truthy.
    plt.grid(False)
    plt.tight_layout()
    plt.show()

In [5]:
def plot_square_maps(all_imgs, plot_as_zero_one=False, img_type="misc", epoch_num=0, SAVE_FIG=False):
    """Plot every channel of ``all_imgs`` in a grid with up to four columns.

    Args:
        all_imgs: array indexed as ``all_imgs[:, :, k]``. Each index ``k``
            along the third axis is drawn in its own panel.
        plot_as_zero_one: draw panels with ``implot`` (0-1 scale) instead of
            ``feat_plot``. A single channel is always drawn with ``implot``.
        img_type: prefix of the saved file name.
        epoch_num: epoch number embedded in the saved file name.
        SAVE_FIG: if true, save the figure to ``./misc/`` (created if it is
            missing) at 300 dpi before showing it.

    Raises:
        ValueError: if ``all_imgs`` has no entries along its third axis.
    """
    n_imgs = all_imgs.shape[2]
    if n_imgs < 1:
        raise ValueError("expected at least one channel along axis 2, "
                         "got shape {}".format(all_imgs.shape))

    # Up to 4 panels per row. 1 or 2 channels give a single row of that width,
    # and multiples of 4 give full rows of 4.
    n_col = min(4, n_imgs)
    n_row = -(-n_imgs // n_col)  # ceiling division
    # squeeze=False always returns a 2-D array of Axes, so a single loop
    # handles the 1x1, 1xN and multi-row layouts alike.
    fig, axes = plt.subplots(nrows=n_row, ncols=n_col, sharey=True,
                             figsize=(n_col * 2, n_row * 2), squeeze=False)

    use_zero_one = plot_as_zero_one or n_imgs == 1
    for k, ax in enumerate(axes.flat):
        if k >= n_imgs:
            # Leftover panels in the last row when n_imgs is not a multiple of 4.
            ax.set_axis_off()
            continue
        cur_img = all_imgs[:, :, k]
        if use_zero_one:
            implot(cur_img, ax)
        else:
            feat_plot(cur_img, ax)

    plt.grid(False)
    if SAVE_FIG:
        # Keep the file-name suffix used previously: the index of the last
        # row for multi-row grids, otherwise the index of the last column.
        suffix = n_row - 1 if n_row > 1 else n_col - 1
        name_str = '{}_{}_image_{}.png'.format(img_type, epoch_num, suffix)
        os.makedirs("./misc", exist_ok=True)
        fig.savefig(os.path.join("./misc", name_str), dpi=300)
    plt.show()

## Dataset