In [None]:
import os
import random
import joblib
import argparse

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_context('talk')
from sklearn.manifold import TSNE

from skimage.transform import rescale

import tensorflow as tf
from baselines.common.tf_util import load_variables
from baselines import deepq
from baselines.run import train
import gym

import sys
sys.path.append('..')
from gym_recorder.recorder import Recorder
from gym_recorder.utils import shortlist_operations, get_activations
from gym_recorder.perturb import perturb

%load_ext autoreload
%autoreload 2

In [None]:
# Build the argument namespace handed to `train`, then obtain the agent and
# its environment.

# Option names to register on the main parser. Only the keys are used below;
# the values listed here are never read.
arguments = dict(alg='deepq',
            env='PongNoFrameskip-v4', 
            gamestate=None,
            network=None,
            num_env=None, 
            num_timesteps=0.0, 
            play=True, 
            reward_scale=1.0, 
            save_path=None, 
            seed=None)

# Option names for the secondary parser (again, only the keys are used).
extra_arguments = dict(load_path='../../models/pong_1e6_dqn')

parser = argparse.ArgumentParser()
for key in arguments:
    parser.add_argument(f'--{key}')

extra_parser = argparse.ArgumentParser()
for key in extra_arguments:
    extra_parser.add_argument(f'--{key}')

# Parse an empty command line: every registered option starts out as None.
args = parser.parse_args([])
# Fixed: this used to call `parser.parse_args`, which produced a namespace with
# the main options instead of the `load_path` option registered on extra_parser.
extra_args = extra_parser.parse_args([])

# Actual run configuration.
args.alg = 'deepq'
args.env = 'EnduroNoFrameskip-v4'
args.gamestate = None
args.network = None
args.num_env = None
args.num_timesteps = 0.0
args.play = True
args.reward_scale = 1.0
args.save_path = None
args.seed = None

# The parsed namespace is replaced by a plain dict before the call.
# NOTE(review): presumably `train` expects a dict of extra keyword arguments
# and, with num_timesteps=0.0, only loads the saved model — confirm against
# baselines.run.
extra_args = dict(load_path='../models/enduro_0_dqn')

model, env = train(args, extra_args)

In [None]:
# Shared settings for the cells below: the TensorFlow session, the graph
# operations handed to the Recorder, and the two observation input tensors.
session = tf.get_default_session()
n_episodes = 1

# Operation names keyed by a short label.
operations = dict([
    ('q_values', 'deepq/q_func/action_value/fully_connected_1/MatMul'),
    ('2nd_to_last', 'deepq_1/q_func/convnet/Conv_2/Relu'),
])

# First output tensor of each named input operation.
input1, input2 = (
    session.graph.get_operation_by_name(op_name).outputs[0]
    for op_name in ('deepq/observation', 'deepq_1/obs_t')
)

In [None]:
# Record gameplay
# Wrap the loaded agent and environment in a Recorder and run it for
# `n_episodes` episodes, feeding both observation input tensors.
recorder = Recorder(act=model, env=env, operations=operations)
recorder.record(
    session=session,
    feed_operations=[input1, input2],
    # Uses the n_episodes setting from the parameters cell instead of a
    # separate literal 1, so the episode count is configured in one place.
    max_episodes=n_episodes,
#     max_steps=3000,
    sample_modulo=1
)

In [None]:
# Print the current working directory; the relative paths used in this
# notebook ('..', '../models/...', '../saliency2') are resolved against it.
!pwd

In [None]:
# Number of observations currently held by the recorder.
len(recorder.observations)

In [None]:
# Thin out the recorded observations: keep positions 0, 5, 10, ... below 500.
# This overwrites recorder.observations with a subset of itself, so running
# the cell a second time subsamples the already reduced list again.
recorder.observations = [
    observation
    for position, observation in enumerate(recorder.observations)
    if not (position % 5 or position >= 500)
]

In [None]:
from skimage.io import imsave
import glob

# Write the first ten recorded frames as zero-padded PNG files
# (00000000.png, 00000001.png, ...) into ../saliency2.
# The target folder is resolved once, outside the loop, and created up front
# so the first write cannot fail on a missing directory.
folder = os.path.abspath("../saliency2")
os.makedirs(folder, exist_ok=True)

for ix, frame in enumerate(recorder.frames[0:10]):
    filename = os.path.join(folder, '{:08}.png'.format(ix))
    imsave(filename, frame)

In [None]:
# Ask the recorder to compute saliencies for the named operation, feeding
# both observation input tensors.
# NOTE(review): the meaning of step_size and mode='clipping' is defined by
# Recorder.get_saliencies, which is not visible here.
step_size = 4
recorder.get_saliencies(
    session=session,
    feed_operations=[input1, input2],
    operation_name='deepq/q_func/action_value/fully_connected_1/MatMul',
    mode='clipping',
    step_size=step_size,
)

In [None]:
# List the contents of the current working directory.
!ls

In [None]:
# Run the shell's `open` command on the parent directory.
# NOTE(review): presumably this is the macOS `open` (shows the folder in
# Finder); it is unlikely to work the same way on other platforms — confirm.
! open ..

In [None]:
from skimage.io import imsave

# Write every saliency map as a zero-padded PNG file under
# ../saliencies_random. The folder is created first so the write cannot fail
# on a missing directory.
output_dir = '../saliencies_random'
os.makedirs(output_dir, exist_ok=True)

for ix, frame in enumerate(recorder.saliencies):
    filename = os.path.join(output_dir, '{:08}.png'.format(ix))
    imsave(filename, frame)

In [None]:
# Show each saliency heatmap in its own figure, sized at one inch per
# 20 array cells.
for heatmap in recorder.saliencies:
    height, width = heatmap.shape[:2]
    # figsize is (width, height) in inches while an array's shape starts with
    # (rows, columns); the original passed them in shape order, which swaps the
    # figure's proportions for non-square heatmaps.
    plt.figure(figsize=(width / 20, height / 20))
    plt.imshow(heatmap)
    plt.show()

In [None]:
# Show each saliency heatmap in its own figure, sized at one inch per
# 15 array cells (a larger rendering than the 20-cells-per-inch variant).
for heatmap in recorder.saliencies:
    height, width = heatmap.shape[:2]
    # figsize is (width, height) in inches while an array's shape starts with
    # (rows, columns); the original passed them in shape order, which swaps the
    # figure's proportions for non-square heatmaps.
    plt.figure(figsize=(width / 15, height / 15))
    plt.imshow(heatmap)
    plt.show()

In [None]:
# Line plot of the recorder's `actions` attribute.
plt.plot(recorder.actions)

In [None]:
# Line plot of the recorder's `episode_rewards` attribute.
plt.plot(recorder.episode_rewards)

In [None]:
# Reset the environment; as the last expression in the cell, the call's
# return value is shown as the cell output.
env.reset()

## t-SNE

In [None]:
# Thin out the episode rewards: keep positions 0, 5, 10, ... below 500.
# This overwrites recorder.episode_rewards with a subset of itself, so running
# the cell a second time subsamples the already reduced list again.
recorder.episode_rewards = [
    reward
    for position, reward in enumerate(recorder.episode_rewards)
    if not (position % 5 or position >= 500)
]

In [None]:
# Re-check the number of observations held by the recorder.
len(recorder.observations)

In [None]:
# Quantities used to colour the t-SNE scatter plots, keyed by the label that
# appears in each plot title.
mapping = dict(
    n_steps=recorder.n_steps,
    episode_reward=recorder.episode_rewards,
)
# Frame count quoted in the plot titles; hard-coded here instead of being
# taken from len(recorder.frames).
n_frames = 3000

In [None]:
# Embed the recorded '2nd_to_last' activations in two dimensions with t-SNE.
# t-SNE is stochastic; fixing random_state makes re-running the cell
# reproduce the same layout.
embedder = TSNE(random_state=0)
activations = recorder.activations['2nd_to_last']
# One flat feature vector per recorded activation: drop singleton axes, take
# index 0 along the last axis, then flatten.
# NOTE(review): this assumes each squeezed activation is 3-D, and it keeps only
# the first slice of the last axis (presumably one channel of the layer) —
# confirm that discarding the remaining slices is intended.
vectors = [np.squeeze(array)[:, :, 0].flatten() for array in activations]
embedding = embedder.fit_transform(vectors)
# Split the (n_points, 2) embedding into separate x and y coordinate tuples.
x, y = zip(*embedding)

In [None]:
# Draw one scatter plot of the embedding per entry in `mapping`, colouring
# the points by that entry's values and naming the entry in the title.
for variable, colours in mapping.items():
    plt.figure(figsize=(10, 5))
    plt.scatter(x, y, c=colours, alpha=0.3)
    plt.colorbar()
    plt.title(f't-SNE embeddings of {n_frames} frames\ncolored by {variable}')
    plt.show()