In [None]:
import os
import pickle

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from IPython.display import HTML, display
from matplotlib import animation, rc
from scipy.stats import norm

from APIAE import DynNet, GenNet, APIAE
# Run on CPU only: enumerate CUDA devices by PCI bus id (see issue #152) and expose none of them.
# NOTE(review): these are set after `import tensorflow`; this relies on TF not having
# initialised CUDA yet — consider moving them above the imports.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "",
})
# Disable matplotlib's "too many open figures" warning.
plt.rcParams['figure.max_open_warning'] = 0

# Seed NumPy and TensorFlow so repeated runs are reproducible.
SEED = 0
np.random.seed(SEED)
tf.set_random_seed(SEED)

In [None]:
# Build two APIAE models that share the observation/latent/control sizes, time step,
# update rate and learning rate, but differ in horizon and sampling budget:
#   * apiae_inference — fed an observed image sequence (see the inference cell);
#   * apiae_planner   — fed a target image sequence, built with isPlanner=True.
_shared_params = dict(
    n_x=16**2,  # dimension of x; observation (reshaped to 16x16 images downstream)
    n_z=2,      # dimension of z; latent space
    n_u=1,      # dimension of u; control
    dt=.1,      # time interval
    ur=.3,      # update rate
    lr=0.001,   # learning rate
)

params_inference = dict(
    _shared_params,
    K=10,   # the number of time steps
    L=100,  # the number of trajectories sampled
    R=10,   # the number of improvements
)
apiae_inference = APIAE(**params_inference)

params_planner = dict(
    _shared_params,
    K=50,     # the number of time steps
    L=10000,  # the number of trajectories sampled
    R=30,     # the number of improvements
    isPlanner=True,
)
apiae_planner = APIAE(**params_planner)

# Unpack the parameters used by later cells; each model's values come from its own config.
K_inference = params_inference['K']
K_planner = params_planner['K']
dt_inference = params_inference['dt']
dt_planner = params_planner['dt']
n_x = params_inference['n_x']
n_z = params_inference['n_z']

In [None]:
# Load the same trained weights file into both the inference and the planner model.
# NOTE(review): a .pkl weights file is presumably unpickled — only load trusted files.
for _model in (apiae_inference, apiae_planner):
    _model.restoreWeights(filename='./weights_demo.pkl')

In [None]:
# Initial sequences: the observed window (X0) and its ground-truth continuation (Xf).
# Stored frames are assumed to be DATA_DT seconds apart, so one model step of `dt`
# seconds corresponds to a stride of dt / DATA_DT stored frames.
DATA_DT = 0.01  # NOTE(review): sampling interval of the frames in pendulum_zero.pkl — confirm.


def _to_frames(duration):
    """Convert a duration in seconds to a whole number of stored frames.

    Rounds rather than truncates, so float error (e.g. 0.29 / 0.01 == 28.999...)
    cannot silently drop a frame.
    """
    return int(round(duration / DATA_DT))


stride_inference = _to_frames(dt_inference)
stride_planner = _to_frames(dt_planner)

# SECURITY: pickle.load can execute arbitrary code — only load trusted data files.
with open("pendulum_zero.pkl", 'rb') as data_file:
    DATA = pickle.load(data_file, encoding='latin1')

# Observe the first K_inference model steps of the first sequence.
X0 = DATA[0][:, :, 0:_to_frames(K_inference * dt_inference):stride_inference, :, :]
# Ground-truth continuation starting at stored frame (K_inference - 1) * dt_inference / DATA_DT;
# with the default parameters that is X0's last frame (index 90).
Xf = DATA[0][:, :, _to_frames((K_inference - 1) * dt_inference):-1:stride_planner, :, :]

# The guidance image: row i has value 16 / (i + 3), so intensity decreases from the
# top row downwards ("swing up"); then min-max normalised to [0, 1].
Xref = np.repeat((16. / (np.arange(16) + 3))[:, np.newaxis], 16, axis=1)
Xref = (Xref - np.min(Xref)) / (np.max(Xref) - np.min(Xref))

# Show guidance image
plt.close()
fig_ref, ax_ref = plt.subplots()
ax_ref.imshow(Xref)
ax_ref.set_title('Guidance image (swing up)')
plt.show()

In [None]:
# Observations are square images; recover their side length from n_x.
img_side = int(round(np.sqrt(n_x)))
assert img_side ** 2 == n_x, "observations are expected to be square images"


def _decode(model, z):
    """Decode latent states `z` into images with `model`'s generative network.

    Returns the network output reshaped to (-1, img_side, img_side).
    """
    x = model.sess.run(model.genNet.x_out, feed_dict={model.genNet.z_in: z})
    return np.reshape(x, (-1, img_side, img_side))


# 1) Inference of the latent space trajectory for the given sequence of observations.
#    Only the last entry of museq_list is used (presumably the final refined estimate).
museq0 = apiae_inference.sess.run(apiae_inference.museq_list,
                                  feed_dict={apiae_inference.xseq: X0})
museq0_reshape = np.reshape(museq0[-1][0, 0, :, :, 0], (K_inference, n_z))
xrecon = _decode(apiae_inference, museq0_reshape)

# 2) Plan: K_planner latent steps towards the guidance image, starting from the last
#    inferred latent state. The objective repeats the guidance image at every step.
Xobjective = np.tile(Xref.reshape((1, 1, 1, n_x, 1)), (1, 1, K_planner, 1, 1))
museq_planner = apiae_planner.sess.run(
    apiae_planner.museq_list,
    feed_dict={apiae_planner.xseq: Xobjective,
               apiae_planner.mu0: museq0[-1][:, :, -1:, :, :]})
museq_planner_reshape = np.reshape(museq_planner[-1][0, 0, :, :, 0], (K_planner, n_z))
xplan = _decode(apiae_planner, museq_planner_reshape)

# 3) Prediction: roll the learned dynamics forward from the same initial latent state
#    with explicit Euler steps, z[t+1] = z[t] + zdot(z[t]) * dt.
z_pred = np.zeros((K_planner, n_z))
z_pred[0, :] = museq0[-1][0, 0, -1, :, 0]
for t in range(K_planner - 1):
    dz_pred = apiae_planner.sess.run(apiae_planner.dynNet.zdot_out,
                                     feed_dict={apiae_planner.dynNet.z_in: z_pred[t:t + 1, :]}) * dt_planner
    z_pred[t + 1, :] = z_pred[t, :] + dz_pred
xpred = _decode(apiae_planner, z_pred)

# Draw the three latent trajectories (first two latent dims); a dot marks each start.
plt.close()
fig_latent, ax_latent = plt.subplots()
for z_traj, color, label in [(museq0_reshape, 'k', 'inference'),
                             (museq_planner_reshape, 'r', 'planning'),
                             (z_pred, 'b', 'prediction (learned dynamics)')]:
    ax_latent.plot(z_traj[0, 0], z_traj[0, 1], color + '.')
    ax_latent.plot(z_traj[:, 0], z_traj[:, 1], color + '-', label=label)
ax_latent.grid()
ax_latent.set(xlabel='z1', ylabel='z2', title='Latent trajectories')
ax_latent.legend()
plt.show()

In [None]:
# Draw Animation: ground truth vs. model outputs, one bordered panel per image source.

def _make_panel_figure(border_colors):
    """Create a row of tick-free image panels, each framed by a coloured rectangle.

    Args:
        border_colors: one matplotlib colour per panel.

    Returns:
        (fig, axes) where `axes` is a 1-D array with one Axes per colour.
    """
    n_panels = len(border_colors)
    fig, axes = plt.subplots(1, n_panels, figsize=(15 * n_panels, 15), squeeze=False)
    axes = axes[0]
    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()
    for ax, color in zip(axes, border_colors):
        # Thick frame drawn just inside the edges of a 16x16 imshow image.
        ax.add_patch(patches.Rectangle((-0.35, -0.4), 15.7, 15.8,
                                       fill=False, edgecolor=color, linewidth=20))
    return fig, axes


def _animate_panels(fig, axes, panels, frame_range, filename):
    """Build an ArtistAnimation of image sequences side by side and save it.

    Args:
        fig, axes: output of `_make_panel_figure`.
        panels: one `(images, label)` pair per axis; `images[t]` is the 2-D image at step t.
        frame_range: the time indices to render, in order.
        filename: path passed to `Animation.save`.

    Returns:
        (frames, anim): the per-frame artist lists and the animation object.
    """
    frames = []
    for t in frame_range:
        artists = []
        for ax, (images, label) in zip(axes, panels):
            artists.append(ax.imshow(images[t], animated=True))
            artists.append(ax.text(.2, 1, label, fontsize=70, color='white'))
        frames.append(artists)
    anim = animation.ArtistAnimation(fig, frames, interval=200, repeat_delay=100)
    anim.save(filename)
    # Close the figure so the notebook shows only the video, not a static copy of it.
    plt.close(fig)
    return frames, anim


# Inference video: observed frames vs. reconstruction.
# NOTE(review): the third panel repeats the reconstruction, presumably so both videos
# share the same 3-panel layout — confirm this is intended.
fig_inference, axes_inference = _make_panel_figure(["blue", "blue", "blue"])
_, im_ani_inference = _animate_panels(
    fig_inference, axes_inference,
    [(X0[0, 0, :, :, 0].reshape(-1, 16, 16), 'Ground Truth'),
     (xrecon, 'Reconstruction'),
     (xrecon, 'Reconstruction')],
    range(K_inference), './inference.mp4')
# Only a cell's last expression is rendered automatically, so display this one explicitly.
display(HTML(im_ani_inference.to_html5_video()))

# Planning video: ground-truth continuation vs. learned-dynamics prediction vs. plan.
# Step 0 is skipped because it is the shared initial state.
fig_planner, (ax1, ax2, ax3) = _make_panel_figure(["blue", "red", "red"])
ims_planner, im_ani_planner = _animate_panels(
    fig_planner, (ax1, ax2, ax3),
    [(Xf[0, 0, :, :, 0].reshape(-1, 16, 16), 'Ground Truth'),
     (xpred, 'Prediction'),
     (xplan, 'Planning')],
    range(1, K_planner), './planning.mp4')
HTML(im_ani_planner.to_html5_video())