# Random Goal Exploration with Full Information


In [2]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import ipywidgets
%matplotlib notebook
import pickle
import os
import json
import pprint
import warnings
warnings.filterwarnings('ignore')
from itertools import *
import scipy.stats
from tqdm import tqdm_notebook as tqdm
import skl_groups
from skl_groups.divergences import KNNDivergenceEstimator
from skl_groups.features import Features
import matplotlib.lines

## Config

In [171]:
# Directory of the experiment run analysed in this notebook; the .npy / .pkl
# files loaded in later cells are resolved relative to it.
path = "Rge-Rep Vae Armball 2017-12-20 15:59:49.019636/"
config_file = os.path.join(path, "config.json")
with open(config_file) as f:
    config = json.load(f)
# Echo the run settings so they sit next to the figures produced below.
pprint.pprint(config)

{u'deformation': 0.0,
 u'distractor': False,
 u'embedding': u'vae',
 u'environment': u'armball',
 u'explo_ratio': 0.05,
 u'name': u'Rge-Rep Vae Armball 2017-12-20 15:59:49.019636',
 u'nb_bins_exploration_ratio': 10,
 u'nb_exploration_iterations': 5000,
 u'nb_period_callback': 10,
 u'nb_samples': 10000,
 u'nb_samples_divergence': 1000,
 u'nb_samples_manifold': 1000,
 u'nb_samples_mse': 100,
 u'nlatents': 10,
 u'noise': 0.0,
 u'outliers': 0.0,
 u'path': u'results/Rge-Rep Vae Armball 2017-12-20 15:59:49.019636',
 u'sampling': u'normal',
 u'test': False,
 u'verbose': False}


### Training States

In [172]:
# Load the state samples stored with the run and report the size of their
# last axis.
states_file = os.path.join(path, 'samples_states.npy')
states = np.load(states_file)
print("State Space Size: %s" % states.shape[-1])
# Scatter the first two state columns against each other.
fig = plt.figure(figsize=(9, 9))
plt.scatter(states[:, 0], states[:, 1])

State Space Size: 2


<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x7f527e4301d0>

In [160]:
# Pairwise scatter matrix of the state dimensions: the panel at [row, col]
# plots state column `row` (x) against state column `col` (y).
n_states = states.shape[-1]
# squeeze=False keeps `ax` two-dimensional even when n_states == 1, so the
# [row, col] indexing below is always valid.
fig, ax = plt.subplots(n_states, n_states, figsize=(9, 9), squeeze=False)
for i in range(n_states**2):
    # divmod uses floor division, so the indices stay integers on both
    # Python 2 and Python 3 (plain `/` yields floats on Python 3 and makes
    # the array indexing fail).
    col, row = divmod(i, n_states)
    ax[row, col].scatter(states[:, row], states[:, col], s=.05)
    ax[row, col].axis('off')

<IPython.core.display.Javascript object>

### Training Geodesics

In [173]:
# Load the geodesic samples stored with the run and report the size of their
# last axis.
geodesics_file = os.path.join(path, 'samples_geodesics.npy')
geodesics = np.load(geodesics_file)
print("Embedded Size: %s" % geodesics.shape[-1])
# Scatter the first two embedded columns against each other.
fig = plt.figure(figsize=(9, 9))
plt.scatter(geodesics[:, 0], geodesics[:, 1])
# Full array shape, printed after the figure is created.
print(geodesics.shape)

Embedded Size: 2


<IPython.core.display.Javascript object>

(1000, 2)


## Exploration Dynamics

### Location Explored

In [174]:
# Read back the pickled exploration history; later cells index it by
# callback number.
# NOTE(review): pickle.load can execute arbitrary code contained in the file,
# so only point `path` at result folders you trust.
history_file = os.path.join(path, "explored_states_history.pkl")
with open(history_file, 'rb') as f:
    explored_states_history = pickle.load(f)
print("Number of callbacks: %s" % len(explored_states_history))

Number of callbacks: 499


In [175]:
# Interactive view of the explored states: a slider selects the callback whose
# first two state columns are shown in the scatter plot.
# Derive the last valid index from the data instead of hard-coding it, so the
# cell keeps working for runs with a different number of callbacks.
last_epoch = len(explored_states_history) - 1
fig = plt.figure(figsize=(9, 9))
ax = fig.add_subplot(1, 1, 1)
scatt = ax.scatter(explored_states_history[last_epoch][:, 0],
                   explored_states_history[last_epoch][:, 1])


def update(epoch):
    """Show the states explored at callback `epoch` in the scatter plot."""
    scatt.set_offsets(explored_states_history[epoch][:, 0:2])
    # Redraw through the artist's own figure: the global `fig` is rebound by
    # later cells, which would otherwise redirect the redraw to another figure.
    scatt.figure.canvas.draw()


ipywidgets.interact(update, epoch=(0, last_epoch));

<IPython.core.display.Javascript object>

A Jupyter Widget

In [176]:
# Re-reads the same history file as the load in the "Location Explored"
# section, so the metrics below start from the data as stored on disk.
# NOTE(review): pickle.load can execute arbitrary code contained in the file,
# so only point `path` at result folders you trust.
history_file = os.path.join(path, "explored_states_history.pkl")
with open(history_file, 'rb') as f:
    explored_states_history = pickle.load(f)

In [177]:
# Exploration metrics evaluated on every 10th stored callback:
#   kls[i]  -- discretized_kl_div of the states explored at callback i*10
#   expl[i] -- number of those states lying more than 1e-3 (L2 norm) away
#              from the point (0.6, 0.6)
n_points = 49
stride = 10
kls = np.zeros(n_points)
expl = np.zeros(n_points)
# NOTE(review): X_real is not used anywhere else in this cell -- confirm
# whether another cell needs it before removing the call.
X_real = sample_in_attainable(explored_states_history[498].shape[0], 'armball')
# Presumably the position the object occupies while it has not been moved --
# TODO confirm against the environment definition.
reference_point = np.array([0.6, 0.6])
for i in tqdm(range(n_points)):
    explored = explored_states_history[i * stride]
    kls[i] = discretized_kl_div(explored)
    distances = np.linalg.norm(explored - reference_point, axis=1, ord=2)
    # Raw count; divide by explored.shape[0] to get a ratio instead.
    expl[i] = np.sum(distances > 1e-3)

A Jupyter Widget




In [178]:
# Background picture drawn behind the inset scatter plots of the summary figure.
# plt.imread replaces scipy.misc.imread, which has been removed from SciPy
# (and `scipy.misc` is never imported explicitly in this notebook).
# Note: for PNG files plt.imread returns floats in [0, 1] rather than uint8;
# imshow renders both the same way.
arm = plt.imread('test.png')
cmap='Blues'
# Newer matplotlib releases ship the seaborn styles under a 'seaborn-v0_8-'
# prefix; fall back to that name when the legacy one is not available.
style_name = 'seaborn-darkgrid'
if style_name not in plt.style.available:
    style_name = 'seaborn-v0_8-darkgrid'
plt.style.use(style_name)

In [179]:
# Summary figure: the `kls` curve (left axis) and the `expl` counts (right
# axis) over exploration, with five inset snapshots of the explored states.

# Indices into kls/expl that are highlighted on the curve and get an inset.
points = [5, 15, 25, 35, 45]
# kls[i] is computed from explored_states_history[i * 10], so the same factor
# maps a highlighted point to the snapshot shown in its inset.
history_stride = 10
# Left edge (figure coordinates) of each inset, one per highlighted point.
inset_lefts = [0.01, 0.175, 0.35, 0.525, 0.70]

fig = plt.figure(figsize=(9.5, 3))
ax1 = fig.add_subplot(1, 1, 1)
# Set the title on ax1 itself: calling plt.title() before add_subplot()
# implicitly creates a separate axes that newer matplotlib versions keep as a
# second, overlapping axes.
ax1.set_title("RGE-VAE - ArmBall - 10 Latents")

# Left axis: kls curve.
ax1.plot(kls, linewidth=1., color='royalblue')
ax1.tick_params(axis='y', colors='royalblue')
ax1.set_ylim(3.5, 10.)
ax1.set_yticks(np.linspace(3.5, 10., 6))
ax1.set_xlim(0., 50.)
ax1.set_xlabel("Exploration epochs (x100)")
ax1.set_ylabel("KLC", color='royalblue')
# Vertical connector from just above each highlighted point up to y=7.
for point in points:
    ax1.add_line(matplotlib.lines.Line2D([point, point], [kls[point] + .05, 7], linewidth=1))
ax1.scatter(points, kls[points])

# Right axis: expl counts.
# The former locator_params(axis='y', nticks=6) call was dropped: `nticks` is
# not a locator parameter (newer matplotlib raises on it) and the explicit
# set_yticks below fixes the six tick positions anyway.
ax1_2 = ax1.twinx()
ax1_2.plot(expl, linewidth=1., c='mediumorchid')
ax1_2.tick_params(axis='y', colors='mediumorchid')
ax1_2.set_ylim(0., 3000.)
ax1_2.set_yticks(np.linspace(0, 3000., 6))
ax1_2.set_ylabel("Nb. of Object Motion", color='mediumorchid')
ax1_2.set_xlim(0., 50.)

# Insets: explored states at each highlighted point, drawn over the `arm` image.
for left, point in zip(inset_lefts, points):
    snapshot = explored_states_history[point * history_stride]
    ax = fig.add_axes([left, 0.55, .3, .3])
    ax.imshow(arm, extent=[-1, 1, -1, 1], alpha=.85)
    ax.scatter(snapshot[:, 0], snapshot[:, 1], s=.5, alpha=.2)
    ax.axis("off")

plt.tight_layout()
# Create the output folder on first use so savefig does not fail on a fresh
# checkout.
if not os.path.isdir("Figures"):
    os.makedirs("Figures")
plt.savefig("Figures/exploration_plot_rge_vae_armball_10.pdf")

<IPython.core.display.Javascript object>