# Import

In [1]:
import matplotlib
matplotlib.use("TkAgg")
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import tkinter as tk

import sys
if sys.version_info >= (3,0):
    from queue import Queue
else:
    from Queue import Queue

import gym
import numpy as np
import scipy.misc as misc
from scipy.ndimage.filters import gaussian_filter
from scipy.misc import imresize
import os
import re
import numpy as np
import tensorflow as tf

# Config

In [2]:
class Config:
    """Static settings read by the environment wrapper and the network."""

    # ---------------------------------------------------------------------
    # Game

    # Gym id of the game, version suffix included (e.g. PongDeterministic-v0)
    ATARI_GAME = 'PongDeterministic-v0'

    # When True the game is rendered on every step so the agent can be watched
    PLAY_MODE = False

    # ---------------------------------------------------------------------
    # DNN input, listed as (height, width, channels)

    IMAGE_HEIGHT = 84
    IMAGE_WIDTH = 84
    # number of preprocessed frames stacked into one state
    STACKED_FRAMES = 4
    
# Which network output the saliency is computed on; passed to score_frame as
# `mode`: 'actor' compares the policy outputs, 'critic' the value output.
MODE = 'actor'

# Range of recorded experiences to analyse:
# indices FIRST_FRAME .. FIRST_FRAME + NUM_FRAMES - 1.
FIRST_FRAME = 350
NUM_FRAMES = 100

# DENSITY: step (in pixels) between occlusion centres on the image grid.
DENSITY = 5
# RADIUS: sigma of the Gaussian blur that shapes each occlusion mask.
RADIUS = 5
# FUDGE_FACTOR: multiplier applied to the saliency map before it is added
# onto the frame for display.
FUDGE_FACTOR = 50


# Environment

In [3]:
class GameManager:
    """Thin wrapper around a gym environment with optional rendering."""

    def __init__(self, game_name, display):
        """Create the gym environment named `game_name` and reset it once.

        display: when truthy, the environment is rendered before every step.
        """
        self.game_name, self.display = game_name, display

        self.env = gym.make(game_name)
        self.reset()

    def reset(self):
        """Reset the wrapped environment and hand back whatever it returns."""
        first_observation = self.env.reset()
        return first_observation

    def step(self, action):
        """Render (if enabled), then advance the environment by `action`.

        Returns the (observation, reward, done, info) tuple from env.step.
        """
        self._update_display()
        obs, rew, finished, extra = self.env.step(action)
        return obs, rew, finished, extra

    def _update_display(self):
        """Render the environment when display mode is on; otherwise no-op."""
        if not self.display:
            return
        self.env.render()


class Environment:
    """Game wrapper that turns raw observations into stacks of preprocessed frames.

    The last `Config.STACKED_FRAMES` preprocessed frames are kept in a bounded
    queue; a state is only available once that queue is full.
    """

    def __init__(self):
        self.game = GameManager(Config.ATARI_GAME, display=Config.PLAY_MODE)
        self.nb_frames = Config.STACKED_FRAMES
        self.frame_q = Queue(maxsize=self.nb_frames)
        self.previous_state = self.current_state = None
        self.total_reward = 0

        self.reset()

    @staticmethod
    def _rgb2gray(rgb):
        """Collapse the first three channels of `rgb` into one weighted channel."""
        channel_weights = [0.299, 0.587, 0.114]
        colour_planes = rgb[..., :3]
        return np.dot(colour_planes, channel_weights)

    @staticmethod
    def _preprocess(image):
        """Grayscale, resize to the configured input size and rescale to float32."""
        gray = Environment._rgb2gray(image)
        target_size = [Config.IMAGE_HEIGHT, Config.IMAGE_WIDTH]
        resized = misc.imresize(gray, target_size, 'bilinear')
        return resized.astype(np.float32) / 128.0 - 1.0

    def _get_current_state(self):
        """Return the queued frames as one array with frames on the last axis.

        Returns None while the frame queue is not full yet.
        """
        if self.frame_q.full():
            stacked = np.array(self.frame_q.queue)
            # frame index goes from axis 0 to the last axis
            return np.transpose(stacked, [1, 2, 0])
        return None

    def _update_frame_q(self, frame):
        """Preprocess `frame` and append it, dropping the oldest frame if full."""
        if self.frame_q.full():
            self.frame_q.get()
        self.frame_q.put(Environment._preprocess(frame))

    def get_num_actions(self):
        """Number of discrete actions exposed by the wrapped gym environment."""
        action_space = self.game.env.action_space
        return action_space.n

    def reset(self):
        """Restart the game and seed the emptied frame queue with its first frame."""
        self.total_reward = 0
        self.frame_q.queue.clear()
        first_frame = self.game.reset()
        self._update_frame_q(first_frame)
        self.previous_state = None
        self.current_state = None

    def step(self, action):
        """Play `action`, update the frame queue and states; return (reward, done)."""
        observation, reward, done, _info = self.game.step(action)

        self.total_reward += reward
        self._update_frame_q(observation)

        # shift the state window: what was current becomes previous
        self.previous_state = self.current_state
        self.current_state = self._get_current_state()
        return reward, done


# Network

In [4]:
class Network(object):
    """Inference-only actor-critic network built in its own TensorFlow graph.

    Architecture (see create_network): two convolution layers, one dense
    layer of 256 units, then a scalar value head and a policy head with
    `num_actions` outputs followed by a softmax.

    Args:
        device: TensorFlow device string the graph is placed on.
        model_name: name used when building checkpoint filenames.
        num_actions: number of outputs of the policy head.
    """

    def __init__(self, device, model_name, num_actions):
        self.device = device
        self.model_name = model_name
        self.num_actions = num_actions

        self.img_width = Config.IMAGE_WIDTH
        self.img_height = Config.IMAGE_HEIGHT
        self.img_channels = Config.STACKED_FRAMES

        self.graph = tf.Graph()
        with self.graph.as_default():
            with tf.device(self.device):
                self.create_placeholder()
                self.create_network()
                # No training op is created here: the graph is only used for
                # forward passes and for restoring saved weights.
                self.sess = tf.Session(
                    graph=self.graph,
                    config=tf.ConfigProto(
                        allow_soft_placement=True,
                        log_device_placement=False,
                        gpu_options=tf.GPUOptions(allow_growth=True)))
                self.sess.run(tf.global_variables_initializer())

                # The saver maps each variable's `.name` string to the variable.
                # NOTE(review): presumably this key format matches how the
                # checkpoints were written - confirm before changing it.
                trainable_vars = tf.trainable_variables()
                self.saver = tf.train.Saver(
                    {var.name: var for var in trainable_vars}, max_to_keep=0)

    def create_placeholder(self):
        """Create the input placeholder `self.x` of shape [batch, height, width, channels]."""
        self.x = tf.placeholder(
            tf.float32, [None, self.img_height, self.img_width, self.img_channels], name='X')

    def create_network(self):
        """Build the conv/dense trunk and the value (`logits_v`) and policy (`softmax_p`) heads."""
        # As implemented in A3C paper
        self.n1 = self.conv2d_layer(self.x, 8, 16, 'conv11', strides=[1, 4, 4, 1])
        self.n2 = self.conv2d_layer(self.n1, 4, 32, 'conv12', strides=[1, 2, 2, 1])
        self.action_index = tf.placeholder(tf.float32, [None, self.num_actions])
        _input = self.n2

        # Static shape as plain ints (same approach as the layer helpers below)
        # instead of reading the private `Dimension._value` attribute.
        conv_shape = _input.get_shape().as_list()
        nb_elements = conv_shape[1] * conv_shape[2] * conv_shape[3]

        self.flat = tf.reshape(_input, shape=[-1, nb_elements])
        self.d1 = self.dense_layer(self.flat, 256, 'dense1')

        # Value head: one linear unit per sample, squeezed to shape [batch].
        self.logits_v = tf.squeeze(self.dense_layer(self.d1, 1, 'logits_v', func=None), axis=[1])
        # Policy head: linear logits turned into probabilities by a softmax.
        self.logits_p = self.dense_layer(self.d1, self.num_actions, 'logits_p', func=None)
        self.softmax_p = tf.nn.softmax(self.logits_p)

    def dense_layer(self, input, out_dim, name, func=tf.nn.relu):
        """Fully connected layer `func(input @ w + b)` under variable scope `name`.

        Weights and biases are drawn uniformly from [-d, d] with
        d = 1 / sqrt(in_dim). Pass func=None for a linear layer.
        """
        in_dim = input.get_shape().as_list()[-1]
        d = 1.0 / np.sqrt(in_dim)
        with tf.variable_scope(name):
            w_init = tf.random_uniform_initializer(-d, d)
            b_init = tf.random_uniform_initializer(-d, d)
            w = tf.get_variable('w', dtype=tf.float32, shape=[in_dim, out_dim], initializer=w_init)
            b = tf.get_variable('b', shape=[out_dim], initializer=b_init)

            output = tf.matmul(input, w) + b
            if func is not None:
                output = func(output)

        return output

    def conv2d_layer(self, input, filter_size, out_dim, name, strides, func=tf.nn.relu):
        """Square 2-D convolution with 'SAME' padding under variable scope `name`.

        Weights and biases are drawn uniformly from [-d, d] with
        d = 1 / sqrt(filter_size * filter_size * in_dim). Pass func=None to
        skip the activation.
        """
        in_dim = input.get_shape().as_list()[-1]
        d = 1.0 / np.sqrt(filter_size * filter_size * in_dim)
        with tf.variable_scope(name):
            w_init = tf.random_uniform_initializer(-d, d)
            b_init = tf.random_uniform_initializer(-d, d)
            w = tf.get_variable('w',
                                shape=[filter_size, filter_size, in_dim, out_dim],
                                dtype=tf.float32,
                                initializer=w_init)
            b = tf.get_variable('b', shape=[out_dim], initializer=b_init)

            output = tf.nn.conv2d(input, w, strides=strides, padding='SAME') + b
            if func is not None:
                output = func(output)

        return output

    def predict_p_and_v_single(self, x):
        """Run one state through the network.

        A leading batch axis is added to `x` before feeding it; the first
        (only) row of the policy probabilities and of the value output is
        returned as the pair (p, v).
        """
        p, v = self.sess.run([self.softmax_p, self.logits_v], feed_dict={self.x: x[np.newaxis, :]})
        return p[0], v[0]

    def _checkpoint_filename(self, episode):
        """Checkpoint path for `episode`: 'checkpoints/<model_name>_<8-digit episode>'."""
        return 'checkpoints/%s_%08d' % (self.model_name, episode)

    def _get_episode_from_filename(self, filename):
        """Parse the episode number from a name built by `_checkpoint_filename`.

        The name is split on '/', '_' and '.', and the third token is read as
        an int, so it only works for a single directory level and a model
        name without underscores.
        """
        # TODO: hacky way of getting the episode. ideally episode should be stored as a TF variable
        # Raw string: '\.' in a normal string literal is an invalid escape sequence.
        return int(re.split(r'/|_|\.', filename)[2])

# Saliency

In [5]:
def occlude(img, mask):
    """Blend every channel of `img` with a blurred copy of itself.

    Where `mask` is 1 the Gaussian-blurred (sigma=3) channel is used, where it
    is 0 the original channel is kept; values in between mix the two.
    The result has the same shape and dtype as `img`.
    """
    blended = np.zeros_like(img)
    n_channels = img.shape[2]
    for ch in range(n_channels):
        plane = img[:, :, ch]
        sharp_part = plane * (1 - mask)
        blurred_part = gaussian_filter(plane, sigma=3) * mask
        blended[:, :, ch] = sharp_part + blurred_part
    return blended

def get_mask(center, size, r):
    """Build a soft occlusion mask of shape `size` centred on `center`.

    The pixels within distance 1 of `center` (row, col) are set to 1, the
    result is blurred with a Gaussian of sigma `r`, and the mask is divided
    by its maximum so that its peak value is 1.
    """
    row0, col0 = center[0], center[1]
    n_rows, n_cols = size[0], size[1]
    # coordinate grids measured relative to the centre pixel
    y, x = np.ogrid[-row0:n_rows - row0, -col0:n_cols - col0]

    seed = np.zeros(size)
    seed[x * x + y * y <= 1] = 1  # small disc of pixels around the centre

    blurred = gaussian_filter(seed, sigma=r)
    return blurred / blurred.max()

def score_frame(network, experiences, frame_id, radius, density, mode='actor'):
    """Compute a perturbation saliency map for one recorded state.

    The state `experiences[frame_id].state` is occluded around a grid of
    centres spaced `density` pixels apart. For each centre the score is
    0.5 * sum((L - l) ** 2), where L is the network output on the original
    state and l the output on the occluded state. The coarse score grid is
    then resized to the configured image size and rescaled so that its
    maximum equals the maximum raw score.

    Args:
        network: object exposing predict_p_and_v_single(state) -> (p, v).
        experiences: sequence of objects with a `state` attribute.
        frame_id: index into `experiences`.
        radius: blur sigma forwarded to get_mask.
        density: spacing in pixels between occlusion centres.
        mode: 'actor' to score the policy output, 'critic' for the value.

    Returns:
        Array of shape [Config.IMAGE_HEIGHT, Config.IMAGE_WIDTH]. An all-zero
        map is returned when the resized scores are all zero.

    Raises:
        ValueError: if `mode` is neither 'actor' nor 'critic'.
    """
    if mode not in ('actor', 'critic'):
        # Previously an unknown mode only failed later with an unbound `L`.
        raise ValueError("mode must be 'actor' or 'critic', got %r" % (mode,))
    # predict_p_and_v_single returns (policy, value); pick the requested one.
    output_index = 0 if mode == 'actor' else 1

    state = experiences[frame_id].state
    height, width = Config.IMAGE_HEIGHT, Config.IMAGE_WIDTH

    # with original state
    L = network.predict_p_and_v_single(state)[output_index]

    scores = np.zeros((int(height / density) + 1, int(width / density) + 1))
    for row, i in enumerate(range(0, height, density)):
        for col, j in enumerate(range(0, width, density)):
            mask = get_mask(center=[i, j], size=[height, width], r=radius)
            # with occluded state
            l = network.predict_p_and_v_single(occlude(state, mask))[output_index]
            scores[row, col] = np.square(L - l).sum() * 0.5

    pmax = scores.max()
    # NOTE(review): per the warning in this notebook's output, `imresize` is
    # deprecated in SciPy 1.0.0 and slated for removal; it is kept here so the
    # produced maps stay unchanged.
    scores = imresize(scores, size=[height, width], interp='bilinear').astype(np.float32)

    resized_max = scores.max()
    if resized_max == 0:
        # Nothing to normalise; dividing would turn the whole map into NaN.
        return scores
    return pmax * scores / resized_max

# Visualize

In [None]:
class Experience(object):
    """Plain record of one step: state, chosen action, network prediction, reward, done flag."""

    def __init__(self, state, action, prediction, reward, done):
        self.state, self.action = state, action
        self.prediction = prediction
        self.reward, self.done = reward, done

# --- Build the environment and the network, then load trained weights -------
env = Environment()
network = Network("cpu:0", "network", env.get_num_actions())
# Only the two games below have a checkpoint path wired in.
if Config.ATARI_GAME == 'PongDeterministic-v0':
    network.saver.restore(network.sess, './checkpoints/pong/network_00029000')
elif Config.ATARI_GAME == 'BreakoutDeterministic-v0':
    network.saver.restore(network.sess, './checkpoints/breakout/network_00097000')
else:
    raise NotImplementedError

# --- Play one episode greedily and record every step ------------------------
env.reset()
done = False
experiences = []

while not done:
    # Until the frame queue is full there is no state yet: play NOOPs.
    # NOTE(review): the `done` flag returned by these warm-up steps is ignored.
    if env.current_state is None:
        env.step(0) # 0 == NOOP
        continue

    prediction, value = network.predict_p_and_v_single(env.current_state)
    action = np.argmax(prediction)  # greedy action
    reward, done = env.step(action)
    # After the step, env.previous_state is the state the action was chosen in.
    exp = Experience(env.previous_state, action, prediction, reward, done)
    experiences.append(exp)

# --- Compute a saliency overlay for the selected range of frames ------------
# NOTE(review): experiences[frame_id] raises IndexError if the episode yielded
# fewer than FIRST_FRAME + NUM_FRAMES experiences.
frames = []
perturbation_maps = []
for frame_id in range(FIRST_FRAME, FIRST_FRAME + NUM_FRAMES):
    sailency = score_frame(network, experiences, frame_id, RADIUS, DENSITY, mode=MODE)
    pmax = sailency.max()

    # Shift the map so its minimum is 0, then rescale and amplify it.
    # NOTE(review): divides by sailency.max(); a constant map would give 0/0.
    sailency -= sailency.min() ; sailency = FUDGE_FACTOR * pmax * sailency / sailency.max()
    # Channel 3 is the last channel of the stacked state.
    frames.append(experiences[frame_id].state[:, :, 3])
    perturbation_maps.append(experiences[frame_id].state[:, :, 3] + sailency)
    print(' [ %d / %d ] processing perturbation_map ... ' % (frame_id - FIRST_FRAME, NUM_FRAMES))

# --- Visualize: Tk window with the raw frame (left) and the overlay (right) --
fig = plt.Figure()

root = tk.Tk()

label = tk.Label(root, text="Video")
label.grid(column=0, row=0)

canvas = FigureCanvasTkAgg(fig, master=root)
canvas.get_tk_widget().grid(column=0, row=1)

ax_1 = fig.add_subplot(121)
ax_2 = fig.add_subplot(122)


def vedio(i):
    """Animation callback: draw the next frame and its saliency overlay.

    The argument `i` is not used; both lists are rotated (the first element
    is popped and appended at the end) so successive calls cycle through
    them indefinitely.
    """
    frame = frames.pop(0)
    frames.append(frame)
    ax_1.clear()
    ax_1.imshow(frame, vmin=0, vmax=1, cmap='gray')
    p_map = perturbation_maps.pop(0)
    perturbation_maps.append(p_map)
    ax_2.clear()
    ax_2.imshow(p_map, vmin=0, vmax=1, cmap='gray') #actor_sailency)

# The third positional argument (1) is FuncAnimation's `frames`; playback
# relies on the list rotation in `vedio`, redrawn every 200 ms.
# `ani` keeps a reference to the animation object while the Tk loop runs.
ani = animation.FuncAnimation(fig, vedio, 1, interval=200)
tk.mainloop()




INFO:tensorflow:Restoring parameters from ./checkpoints/pong/network_00029000


`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.


 [ 0 / 100 ] processing perturbation_map ... 
 [ 1 / 100 ] processing perturbation_map ... 
 [ 2 / 100 ] processing perturbation_map ... 
 [ 3 / 100 ] processing perturbation_map ... 
 [ 4 / 100 ] processing perturbation_map ... 
 [ 5 / 100 ] processing perturbation_map ... 
 [ 6 / 100 ] processing perturbation_map ... 
 [ 7 / 100 ] processing perturbation_map ... 
 [ 8 / 100 ] processing perturbation_map ... 
 [ 9 / 100 ] processing perturbation_map ... 
 [ 10 / 100 ] processing perturbation_map ... 
 [ 11 / 100 ] processing perturbation_map ... 
 [ 12 / 100 ] processing perturbation_map ... 
 [ 13 / 100 ] processing perturbation_map ... 
 [ 14 / 100 ] processing perturbation_map ... 
 [ 15 / 100 ] processing perturbation_map ... 
 [ 16 / 100 ] processing perturbation_map ... 
 [ 17 / 100 ] processing perturbation_map ... 
 [ 18 / 100 ] processing perturbation_map ... 
 [ 19 / 100 ] processing perturbation_map ... 
 [ 20 / 100 ] processing perturbation_map ... 
 [ 21 / 100 ] processin