## Introduction

This notebook shows how to use the XAITK-Saliency API to gain insight into the behavior of a deep reinforcement learning (RL) agent in an Atari 2600 environment.

This example is based upon [this paper](https://arxiv.org/abs/1711.00138) and corresponding [github page](https://github.com/greydanus/visualize_atari).
The authors use the Asynchronous Advantage Actor Critic (A3C) algorithm with an LSTM-CNN policy network to train several agents for automated gameplay of different Atari 2600 games.
They also implement a method for generating saliency maps using image perturbation.
We will show here a recreation of their results using the XAITK-Saliency API, focusing on the Breakout environment.
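
For reference, the perturbation saliency in the paper works roughly as follows. A Gaussian mask $M(i, j)$ centered on pixel $(i, j)$ blends frame $I_t$ with a blurred copy $A(I_t, \sigma_A)$, and the saliency is the squared change in the actor's logits $\pi_u$; an analogous score $S_V$ uses the critic's value. This is a paraphrase of the paper's definition, so check the original before relying on the exact form:

$$
\Phi(I_t, i, j) = I_t \odot \bigl(1 - M(i, j)\bigr) + A(I_t, \sigma_A) \odot M(i, j)
$$

$$
S_\pi(t, i, j) = \tfrac{1}{2} \bigl\lVert \pi_u(I_{1:t}) - \pi_u(I'_{1:t}) \bigr\rVert^2,
\qquad
I'_k = \begin{cases} \Phi(I_k, i, j) & k = t \\ I_k & k \ne t \end{cases}
$$

The XAITK-Saliency recreation in this notebook replaces the Gaussian blur with sliding-window occlusion, so its maps are comparable in spirit but not identical.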

## Install Dependencies

In [1]:
!pip install -q "xaitk-saliency"
!pip install -q "torchvision"
!pip install -q "gym[atari]"
!pip install -q "opencv-python"
!pip install -q "gdown"

## Download Pretrained Model

The authors provide pretrained agents for each of the environments they used.
We will use the Breakout agent.

Because of Google Drive permissions, the automated download below may fail, as it does in the output shown.
If it does, open the Google Drive link in a browser, download the zip file manually, and move it to the current working directory.

In [2]:
import gdown

url = 'https://drive.google.com/u/0/uc?export=download&confirm=gzOH&id=0B-HNE76mR97FaHYtX202WFZSRXc'
output = 'pretrained.zip'
gdown.download(url, output, quiet=False)

Access denied with the following error:

	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses.

You may still be able to access the file from the browser:

	https://drive.google.com/u/0/uc?export=download&confirm=gzOH&id=0B-HNE76mR97FaHYtX202WFZSRXc

### Extract Model

In [3]:
import zipfile

zip_fname = 'pretrained.zip'
output_dir = 'pretrained_agents/'

try:
    with zipfile.ZipFile(zip_fname, 'r') as zip_ref:
        zip_ref.extractall(output_dir)
except FileNotFoundError:
    print("Can't find zip file. Make sure you have downloaded and moved it to the cwd.")
    raise

## Create Atari Environment

Here we create the Breakout environment for our agent using Gym.

Our agent has 4 actions to choose from (listed with the action indices Gym uses):

0. `NOOP`: do nothing
1. `FIRE`: fire
2. `RIGHT`: move right
3. `LEFT`: move left

In [4]:
import gym

env_name = "Breakout-v0"
env = gym.make(env_name)
env.seed(1)

action_space = env.unwrapped.get_action_meanings()
print(f"Action space: {action_space}")

Action space: ['NOOP', 'FIRE', 'RIGHT', 'LEFT']


A.L.E: Arcade Learning Environment (version +a54a328)
[Powered by Stella]


## Define Policy Network

This policy network implementation is taken directly from the authors' own implementation.
It consists of four convolutional layers, an LSTM layer, and two separate fully-connected layers for the value and policy function predictions.

In [5]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import glob
import numpy as np

class NNPolicy(torch.nn.Module):  # an actor-critic neural network
    def __init__(self, channels, num_actions):
        super(NNPolicy, self).__init__()
        self.conv1 = nn.Conv2d(channels, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.lstm = nn.LSTMCell(32 * 5 * 5, 256)
        self.critic_linear, self.actor_linear = nn.Linear(256, 1), nn.Linear(
            256, num_actions
        )

    def forward(self, inputs):
        inputs, (hx, cx) = inputs
        x = F.elu(self.conv1(inputs))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        x = F.elu(self.conv4(x))
        x = x.view(-1, 32 * 5 * 5)
        hx, cx = self.lstm(x, (hx, cx))
        return self.critic_linear(hx), self.actor_linear(hx), (hx, cx)

    def try_load(self, save_dir, checkpoint="*.tar"):
        paths = glob.glob(save_dir + checkpoint)
        step = 0
        if len(paths) > 0:
            ckpts = [int(s.split(".")[-2]) for s in paths]
            ix = np.argmax(ckpts)
            step = ckpts[ix]
            self.load_state_dict(torch.load(paths[ix]))
        if step == 0:
            print("\tno saved models")
        else:
            print("\tloaded model: {}".format(paths[ix]))
        return step
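
The `32 * 5 * 5` input size of the LSTM cell follows from the convolution arithmetic. Each of the four layers uses a 3x3 kernel with stride 2 and padding 1, which halves the 80x80 preprocessed frame at every layer. A quick check using the standard `Conv2d` output-size formula (dilation 1); the helper `conv2d_out` is hypothetical and only for illustration:

```python
def conv2d_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 80  # prepro() resizes each frame to 80x80
sizes = [size]
for _ in range(4):  # conv1..conv4 all use kernel 3, stride 2, padding 1
    size = conv2d_out(size)
    sizes.append(size)

print(sizes)             # [80, 40, 20, 10, 5]
print(32 * size * size)  # 800 features per frame fed to the LSTMCell
```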

## Load Pretrained Model

Here we load the downloaded model into an instance of the policy function class.

In [6]:
load_dir = '{}{}/'.format(output_dir, env_name.lower())

model = NNPolicy(channels=1, num_actions=env.action_space.n)
_ = model.try_load(load_dir, checkpoint='*.tar')

torch.manual_seed(1)

	loaded model: pretrained_agents/breakout-v0/strong.40.tar


<torch._C.Generator at 0x7f3d05dd79b0>

## Define Rollout Function

This function runs the pretrained agent's policy for up to a given number of frames in our Breakout environment.
At each step, the current (preprocessed) game frame is run through the policy model, and the agent greedily takes the action with the highest predicted probability.
The game frame, the LSTM hidden state, and the model outputs are stored after each step.

In [7]:
import cv2

prepro = (
    lambda img: cv2.resize(src=img[35:195].mean(2), dsize=(80, 80))
    .astype(np.float32)
    .reshape(1, 80, 80)
    / 255.0
)

def rollout(model, env, max_ep_len):
    history = {"ins": [], "logits": [], "values": [], "outs": [], "hx": [], "cx": []}

    state = torch.Tensor(prepro(env.reset()))  # get first state
    episode_length, epr, eploss, done = 0, 0, 0, False  # bookkeeping
    hx, cx = torch.zeros(1, 256), torch.zeros(1, 256)

    # iterate through each frame in episode
    while not done and episode_length <= max_ep_len:
        episode_length += 1
        
        # get game state
        model_inp = (state.view(1, 1, 80, 80), (hx, cx))
        
        # run through model
        value, logit, (hx, cx) = model(model_inp)
        hx, cx = hx.data, cx.data
        
        # action probabilities
        prob = F.softmax(logit, dim=-1)

        # best action
        action = prob.max(1)[1].data
        
        # take best action
        obs, reward, done, info = env.step(action.numpy()[0])
        
        state = torch.Tensor(prepro(obs))
        epr += reward

        # save state
        history["ins"].append(obs) # game state after taking action
        history["hx"].append(hx.squeeze(0).data.numpy()) # LSTM hx output
        history["cx"].append(cx.squeeze(0).data.numpy()) # LSTM cx output
        history["logits"].append(logit.data.numpy()[0]) # actor output
        history["values"].append(value.data.numpy()[0]) # critic output
        history["outs"].append(prob.data.numpy()[0]) # action probabilities
        print("\tstep # {}, reward {:.0f}".format(episode_length, epr), end="\r")

    return history
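
The rollout acts greedily: `prob.max(1)[1]` picks the action with the highest softmax probability rather than sampling from the policy. A small NumPy sketch with made-up logits (not real model output) shows that, because softmax is monotonic, this is the same as taking the argmax of the raw logits:

```python
import numpy as np

ACTIONS = ["NOOP", "FIRE", "RIGHT", "LEFT"]  # order printed by get_action_meanings()

def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Hypothetical actor logits for a single frame
logits = np.array([0.2, -1.0, 2.5, 1.1])
probs = softmax(logits)
greedy = int(np.argmax(probs))

# softmax preserves ordering, so both argmaxes agree
assert greedy == int(np.argmax(logits))
print(ACTIONS[greedy])  # RIGHT
```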

### Play Breakout

Our pretrained agent will now play the game for up to 3,000 frames; in the run below, the episode ends after 2,716 steps.
We then write frames 1000 to 1200 of the recorded game to a short video clip so we can see the agent in action.

In [8]:
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore")

# Play game
print("Rolling out policy...")
history = rollout(model, env, max_ep_len=3e3)

# Create video from frames
print("\nCreating video...")

w = history['ins'][0].shape[1]
h = history['ins'][0].shape[0]
fps = 30

out = cv2.VideoWriter("breakout.mp4", cv2.VideoWriter_fourcc(*'vp09'), fps, (w,h))

start_frame = 1000
end_frame = 1200
for i in range(start_frame, end_frame+1):
    frame = cv2.cvtColor(history['ins'][i], cv2.COLOR_RGB2BGR) # OpenCV expects BGR channel order
    out.write(frame)
out.release()

print("Done")

Rolling out policy...
	step # 2716, reward 376
Creating video...
Done


In [9]:
%%HTML
<div align="middle">
<video width="50%" controls>
      <source src="breakout.mp4" type="video/mp4">
</video></div>

## Defining the Application

Our saliency application has four parameters:

- `start_frame`: the first frame to generate saliency maps for
- `end_frame`: the last frame to generate saliency maps for
- `perturber`: the `PerturbImage` implementation to use
- `saliency_gen`: the `GenerateClassifierConfidenceSaliency` implementation to use

The application creates saliency maps for both the actor (policy function) and the critic (value function) for each frame from `start_frame` to `end_frame`, using the image perturber and saliency generator passed to it.
To show which regions of each frame influence the agent's decisions, the saliency maps are post-processed in two ways.
First, the distinction between negative and positive saliency is removed by taking the absolute value of each map.
Second, the actor's per-action maps are summed and normalized by their maximum, giving a single actor map per frame; the critic has only one output (the state value), so it already yields a single map per frame.
The result shows, for each frame, where each model is looking when it makes its prediction.
Salient regions are highlighted in red, and two videos are written: one from the actor-highlighted frames and one from the critic-highlighted frames.
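
A minimal NumPy sketch of this post-processing and highlighting, using synthetic values rather than real model output. The helpers `combine_actor_maps` and `overlay_red` are hypothetical. Note that adding two `uint8` arrays wraps around modulo 256, so the overlay converts to float and clips before casting back, letting bright pixels saturate at 255:

```python
import numpy as np

def combine_actor_maps(per_action_sal: np.ndarray) -> np.ndarray:
    """Collapse [num_actions, H, W] signed saliency into one [H, W] map in [0, 1].

    Absolute value, sum over actions, then divide by the maximum
    (guarded against an all-zero map).
    """
    combined = np.abs(per_action_sal).sum(axis=0)
    peak = combined.max()
    return combined / peak if peak > 0 else combined

def overlay_red(frame_rgb: np.ndarray, sal: np.ndarray, strength: float = 200.0) -> np.ndarray:
    """Add scaled saliency to the red channel of a uint8 RGB frame, saturating at 255."""
    out = frame_rgb.copy()
    red = out[:, :, 0].astype(np.float32) + strength * sal
    out[:, :, 0] = np.clip(red, 0, 255).astype(np.uint8)
    return out

# Toy example: 4 actions on a 2x3 frame
rng = np.random.default_rng(0)
sal = combine_actor_maps(rng.normal(size=(4, 2, 3)))
frame = np.full((2, 3, 3), 240, dtype=np.uint8)  # bright frame: a plain uint8 add would wrap
highlighted = overlay_red(frame, sal)
print(highlighted[:, :, 0])
```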

To speed up this process, the application runs one thread per frame. Each thread holds its full batch of perturbed frames in memory at the same time, so memory use grows with the number of frames processed; use caution with long frame ranges.
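
If memory is a concern, a bounded pool is one alternative to starting one thread per frame. A sketch using the standard library's `concurrent.futures.ThreadPoolExecutor`, where `process_frame` is a hypothetical stand-in for `gen_sal_maps`:

```python
from concurrent.futures import ThreadPoolExecutor

def process_frame(frame_idx: int) -> int:
    """Hypothetical stand-in for gen_sal_maps(); returns a dummy per-frame result."""
    return frame_idx * frame_idx

frame_indices = range(1039, 1044)

# max_workers caps how many frames are processed (and held in memory) at once
with ThreadPoolExecutor(max_workers=2) as pool:
    results = list(pool.map(process_frame, frame_indices))

print(results)
```

`pool.map` returns results in input order and re-raises any worker exception in the calling thread, whereas an exception inside a plain `threading.Thread` is only printed and the corresponding map silently stays `None`.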

In [10]:
from typing import Callable
from xaitk_saliency import PerturbImage, GenerateClassifierConfidenceSaliency
from xaitk_saliency.utils.masking import occlude_image_batch
import threading

actor_sal_maps = []
critic_sal_maps = []

def app(
    start_frame: int,
    end_frame: int,
    perturber: PerturbImage,
    saliency_gen: GenerateClassifierConfidenceSaliency
):
    global actor_sal_maps
    global critic_sal_maps
    
    # Initialize map arrays to correct size
    actor_sal_maps = [None] * (end_frame - start_frame +1)
    critic_sal_maps = [None] * (end_frame - start_frame +1)
    
    threads = []
    
    for img_idx in range(start_frame, end_frame+1):
        
        # Create threads
        threads.append(
            threading.Thread(
                target=gen_sal_maps,
                args=[img_idx, (img_idx-start_frame), perturber, saliency_gen]
            )
        )
    
    # Start threads
    for t in threads:
        t.start()
        
    # Wait for threads to finish
    for t in threads:
        t.join()
    
    # Write out videos
    print("Writing actor video")
    fps = 1
    actor_vid_writer = cv2.VideoWriter("breakout_actor_saliency.mp4", cv2.VideoWriter_fourcc(*'vp09'), fps, (w,h))
    critic_vid_writer = cv2.VideoWriter("breakout_critic_saliency.mp4", cv2.VideoWriter_fourcc(*'vp09'), fps, (w,h))
    
    for img_idx in range(start_frame, end_frame+1):
        sal_idx = img_idx-start_frame
        
        actor_img = history['ins'][img_idx].copy()
        critic_img = history['ins'][img_idx].copy()
        
        # Highlight salient locations in red
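        # (clip so bright pixels saturate at 255; in-place uint8 addition would wrap around)
        actor_img[:,:,0] = np.clip(actor_img[:,:,0] + 200.0*actor_sal_maps[sal_idx], 0, 255).astype("uint8")
        critic_img[:,:,0] = np.clip(critic_img[:,:,0] + 200.0*critic_sal_maps[sal_idx], 0, 255).astype("uint8")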
        
        # Convert to BGR to meet cv2 standard
        actor_frame = cv2.cvtColor(actor_img, cv2.COLOR_RGB2BGR)
        critic_frame = cv2.cvtColor(critic_img, cv2.COLOR_RGB2BGR)
        
        actor_vid_writer.write(actor_frame)
        critic_vid_writer.write(critic_frame)
        
    actor_vid_writer.release()
    critic_vid_writer.release()
    
    print("Done.")
    
def gen_sal_maps(img_idx, sal_idx, perturber, saliency_gen):
    global actor_sal_maps
    global critic_sal_maps
    
    # Score reference frame
    print(f"[{img_idx}]Scoring frame")
    
    ref_img = history['ins'][img_idx]
    
    ref_img_proc = prepro(ref_img)

    ref_img_tensor = torch.tensor(ref_img_proc.reshape(1, 1, 80, 80))
    hx = torch.tensor(history['hx'][img_idx-1]).view(1, -1)
    cx = torch.tensor(history['cx'][img_idx-1]).view(1, -1)

    model_inp = (ref_img_tensor, (hx, cx))

    ref_value, ref_logit, _ = model(model_inp)

    ref_value = ref_value.detach().numpy()[0]
    ref_logit = ref_logit.detach().numpy()[0]
        
    # Get image perturbations
    print(f"[{img_idx}]Perturbing image")

    pert_masks = perturber(ref_img)
        
    pert_imgs = occlude_image_batch(ref_img, pert_masks)

    pert_values = []
    pert_logits = []
    
    # Score perturbations
    print(f"[{img_idx}]Scoring perturbations")

    for pert_img in pert_imgs:

        pert_img_proc = prepro(pert_img)

        pert_img_tensor = torch.tensor(pert_img_proc.reshape(1, 1, 80, 80))

        model_inp = (pert_img_tensor, (hx, cx))

        pert_value, pert_logit, _ = model(model_inp)

        pert_values.append(pert_value.detach().numpy()[0])
        pert_logits.append(pert_logit.detach().numpy()[0])
        
    # Generate actor saliency maps
    print(f"[{img_idx}]Generating actor saliency maps")
    actor_sal = saliency_gen(ref_logit, pert_logits, pert_masks)
    actor_sal = np.sum(np.abs(actor_sal), axis=0)
    actor_sal = actor_sal / actor_sal.max()
    actor_sal_maps[sal_idx] = actor_sal
        
    # Generate critic saliency maps
    print(f"[{img_idx}]Generating critic saliency maps")
    critic_sal = saliency_gen(ref_value, pert_values, pert_masks)[0]
    critic_sal = np.abs(critic_sal)
    critic_sal_maps[sal_idx] = critic_sal

## Perturbation and Saliency Implementations

For this example we will use the `SlidingWindow` perturbation implementation with a window size of (5,5) and stride of (5,5).
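
To make the perturbation concrete, here is a simplified NumPy sketch of how sliding-window occlusion masks can be built. The function `sliding_window_masks` is hypothetical and is not the `SlidingWindow` source; its edge handling and mask convention (here `False` marks the occluded window) are assumptions. It uses the raw 210x160 Breakout frame size, since `perturber(ref_img)` receives the unprocessed observation:

```python
import numpy as np

def sliding_window_masks(height: int, width: int, window=(5, 5), stride=(5, 5)) -> np.ndarray:
    """Return [N, H, W] boolean masks; each is False inside one window, True elsewhere.

    Simplified sketch: only windows that fit entirely inside the image are used.
    """
    wh, ww = window
    sh, sw = stride
    tops = range(0, height - wh + 1, sh)
    lefts = range(0, width - ww + 1, sw)
    masks = np.ones((len(tops) * len(lefts), height, width), dtype=bool)
    for i, (top, left) in enumerate((t, l) for t in tops for l in lefts):
        masks[i, top:top + wh, left:left + ww] = False  # False = occluded region
    return masks

masks = sliding_window_masks(210, 160)
print(masks.shape)  # (1344, 210, 160)
```

Because 5 divides both 210 and 160, the 1,344 windows in this sketch tile the frame exactly, so every pixel is occluded by exactly one perturbation. Each perturbed frame must be scored by the model, which is why per-frame cost and memory use are significant.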

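The `OcclusionScoring` saliency generator is appropriate here because it only needs a vector of scores per image: it can be applied both to the actor's per-action logits and to the critic's single value estimate, even though the critic is a regression output rather than a classifier.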
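
The core idea of occlusion scoring can be sketched in a few lines of NumPy: for every pixel, average how much each score drops across the perturbations that occluded it. This is a simplified illustration, not the `OcclusionScoring` source; the library's exact weighting and normalization may differ, and the mask convention (1 = kept, 0 = occluded) is an assumption of the sketch. The same function handles actor logits (C = 4) and the critic's value (C = 1):

```python
import numpy as np

def occlusion_saliency(ref_scores, pert_scores, masks) -> np.ndarray:
    """Simplified occlusion scoring.

    ref_scores:  [C]        scores for the unperturbed image
    pert_scores: [N, C]     scores for each perturbed image
    masks:       [N, H, W]  1 where the image was kept, 0 where occluded
    returns:     [C, H, W]  mean score drop over the perturbations occluding each pixel
    """
    ref_scores = np.asarray(ref_scores, dtype=np.float64)
    pert_scores = np.asarray(pert_scores, dtype=np.float64)
    occluded = 1.0 - np.asarray(masks, dtype=np.float64)  # [N, H, W]
    drop = ref_scores[None, :] - pert_scores              # [N, C]
    # Weight each perturbation's score drop by the pixels it occluded
    weighted = np.einsum("nc,nhw->chw", drop, occluded)
    counts = occluded.sum(axis=0)                         # [H, W]
    return weighted / np.maximum(counts, 1.0)

# Toy critic example (C = 1) on a 2x2 image; mask i occludes pixel i (row-major)
masks = 1.0 - np.eye(4).reshape(4, 2, 2)
ref_value = np.array([1.0])
pert_values = np.array([[0.2], [1.0], [0.9], [1.3]])
sal = occlusion_saliency(ref_value, pert_values, masks)
print(sal[0])  # approximately [[0.8, 0.0], [0.1, -0.3]]
```

A positive value means occluding that pixel lowered the score; the notebook then discards the sign by taking the absolute value.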

In [11]:
from xaitk_saliency.impls.perturb_image.sliding_window import SlidingWindow
from xaitk_saliency.impls.gen_classifier_conf_sal.occlusion_scoring import OcclusionScoring

window_perturber = SlidingWindow(window_size=(5,5), stride=(5,5))
sal_gen = OcclusionScoring()

## Calling the Application

Five consecutive frames (1039 to 1043), chosen arbitrarily from within the clip shown above, are used for saliency generation.
After the application finishes, the resulting videos are displayed side by side: actor on the left, critic on the right.

In [12]:
app(
    start_frame=1039,
    end_frame=1043,
    perturber=window_perturber,
    saliency_gen=sal_gen
)

[1039]Scoring frame
[1040]Scoring frame
[1041]Scoring frame
[1042]Scoring frame[1043]Scoring frame

[1039]Perturbing image
[1040]Perturbing image
[1043]Perturbing image
[1041]Perturbing image
[1042]Perturbing image
[1039]Scoring perturbations
[1040]Scoring perturbations
[1043]Scoring perturbations
[1042]Scoring perturbations
[1041]Scoring perturbations
[1040]Generating actor saliency maps
[1039]Generating actor saliency maps
[1043]Generating actor saliency maps
[1042]Generating actor saliency maps
[1041]Generating actor saliency maps
[1040]Generating critic saliency maps
[1039]Generating critic saliency maps
[1043]Generating critic saliency maps
[1042]Generating critic saliency maps
[1041]Generating critic saliency maps
Writing actor video
Done.


In [13]:
%%HTML
<div align="left">
<video width="45%" controls>
    <source src="breakout_actor_saliency.mp4" type="video/mp4">
</video>
<video width="45%" controls>
    <source src="breakout_critic_saliency.mp4" type="video/mp4">
</video>
</div>
