In [None]:
%load_ext autoreload
%autoreload 2
import copy

import gym
import numpy as np

from main.alpaca import *
from main.datagen import *
from main.dataViz import *

In [None]:
cfg_filename = 'configs/pendulum-config.yml'
# Load the experiment configuration dict.
# yaml.safe_load is used because yaml.load without an explicit Loader is
# deprecated since PyYAML 5.1 (TypeError in PyYAML >= 6) and can construct
# arbitrary Python objects from the file.
with open(cfg_filename, 'r') as ymlfile:
    config = yaml.safe_load(ymlfile)

In [None]:
# Pendulum environment and the project's data generator wrapped around it.
env = gym.make('Pendulum-v0')
DG = DataGenerator(config, env, 'Pendulum-v0')

# Sizes passed to sample_trajectories (presumably number of trajectories and
# steps per trajectory -- TODO confirm against DataGenerator).
N_train = 5000
N_test = 500
test_horz = 100

# Training and held-out test sets, both sampled with pi_zero
# (presumably the data-collection policy).
Y, X = DG.sample_trajectories(pi_zero, test_horz, N_train, return_lists=False)
Y_test, X_test = DG.sample_trajectories(pi_zero, test_horz, N_test, return_lists=False)

## ALPaCA agent (meta-trained)

In [None]:
# Meta-trained ALPaCA model, built in its own graph with an interactive session.
g1 = tf.Graph()
sess1 = tf.InteractiveSession(
    graph=g1,
    config=tf.ConfigProto(log_device_placement=True),
)
agent = ALPaCA(config)

In [None]:
# Build the model in g1, the graph owned by sess1. The graph is passed
# explicitly, as for the other models, rather than relying on whichever graph
# happens to be the default when this cell runs.
agent.construct_model(sess1, g1)

In [None]:
# Train the meta-trained agent on the training set (5000 presumably = iterations).
agent.train(sess1, Y, X, 5000)

## ALPaCA agent without meta-training

In [None]:
# Baseline: same model trained without meta-learning (data_horizon = 0).
# It gets a private deep copy of the config. Setting data_horizon on a dict
# that is shared by reference would also change it for every other model built
# from `config` (e.g. AdaptiveDynamics below) and for any re-run of earlier cells.
g2 = tf.Graph()
sess2 = tf.Session(config=tf.ConfigProto(log_device_placement=True), graph=g2)
config_nometa = copy.deepcopy(config)
config_nometa['data_horizon'] = 0
agent_nometa = ALPaCA(config_nometa)
agent_nometa.construct_model(sess2, g2)
agent_nometa.train(sess2, Y, X, 10000)

## GP Agent

In [None]:
# Gaussian-process regression baseline.
# NOTE(review): GPR_agent is not used by any later cell shown here, and no GP
# predictions or timings are computed -- TODO evaluate it or drop this cell.
# NOTE(review): mid-notebook star import; consider moving it to the imports cell.
from gp_reg import *

GPR_agent = GPReg()

## Adaptive dynamics agent

In [None]:
# Adaptive dynamics model, built in its own graph/session.
g3 = tf.Graph()
sess3 = tf.Session(
    graph=g3,
    config=tf.ConfigProto(log_device_placement=True),
)
dyn_model = AdaptiveDynamics(config)
dyn_model.construct_model(sess3, g3)

In [None]:
# Train the adaptive dynamics model (1000 presumably = iterations).
dyn_model.train(sess3, Y, X, 1000)

In [None]:
# One-step prediction figures for test trajectory `ind`, one per value of Nu.
ind = 7
for Nu in (0, 3, 5, 10):
    gen_pendulum_onestep_fig(
        agent,
        sess1,
        X_test[ind:ind + 1, :, :],
        Y_test[ind:ind + 1, :, :],
        Nu=Nu,
        T=30,
    )

In [None]:
# Adaptive-dynamics rollouts on the first T steps of test trajectory `ind`,
# one figure per value of Nu (presumably the number of conditioning points).
ind = 7
T = 50
N_samples = 50
for Nu in (0, 15, 30):
    plt.figure()
    test_adaptive_dynamics(
        dyn_model,
        sess3,
        X_test[ind:ind + 1, :T, :],
        Y_test[ind:ind + 1, :T, :],
        N_samples,
        Nu,
    )
    plt.show()

In [None]:
# Grid of sampled rollouts on test trajectory `ind`: one row per value of Nu in
# sample_size_list, one column per model variant. Columns after the first share
# the first column's y-axis and hide their y tick labels. Only the bottom row
# keeps x tick labels.
ind = 0
sample_size_list = [0, 10, 20, 30]
T = 60  # NOTE: the NLL evaluation cell below also reads this T (test_horz1 = T)
N_samples = 50
T_rollout = 30

# (model, session, column title, extra kwargs for gen_pendulum_sample_fig)
comparison_columns = [
    (agent, sess1, 'ALPaCA', {}),
    (agent_nometa, sess2, 'ALPaCA (no meta)', {}),
    (agent_nometa, sess2, 'ALPaCA (no update)', {'no_update': True}),
]
n_rows = len(sample_size_list)
n_cols = len(comparison_columns)

plt.figure(figsize=(9, 5))
for i, Nu in enumerate(sample_size_list):
    row_first_ax = None
    for col, (model, model_sess, col_title, extra_kwargs) in enumerate(comparison_columns):
        ax = plt.subplot(n_rows, n_cols, n_cols * i + col + 1, sharey=row_first_ax)
        gen_pendulum_sample_fig(model, model_sess, X_test[ind:ind+1, :, :], Y_test[ind:ind+1, :, :], Nu,
                                N_samples=N_samples, T=T, T_rollout=T_rollout, **extra_kwargs)
        if row_first_ax is None:
            row_first_ax = ax
        else:
            # Each non-first column hides its OWN y tick labels. The original
            # third column hid ax2's labels instead of ax3's.
            plt.setp(ax.get_yticklabels(), visible=False)
        if i == 0:
            plt.title(col_title)
        if i < n_rows - 1:
            plt.setp(ax.get_xticklabels(), visible=False)

plt.tight_layout(w_pad=0.0, h_pad=0.2)
plt.savefig('figures/pendulum_three.pdf')
plt.show()

In [None]:
# compute NLL for all three models
import time

def gaussian_nll(y, mu, Sig):
    """Negative log-likelihood of ``y`` under the Gaussian N(mu, Sig).

    Computes 0.5 * (n*log(2*pi) + log|Sig| + (y-mu)^T Sig^{-1} (y-mu)).

    Args:
        y: observed vector, shape (n,).
        mu: predicted mean, same shape as ``y``.
        Sig: predicted covariance, shape (n, n); expected positive definite.

    Returns:
        The NLL as a scalar, or NaN if ``Sig`` has a non-positive determinant.
    """
    residual = np.asarray(y, dtype=float) - np.asarray(mu, dtype=float)
    # Dimension taken from the data rather than hard-coded to 2.
    n = residual.shape[0]
    # slogdet/solve are numerically safer than log(det(.)) and an explicit inverse.
    sign, logdet = np.linalg.slogdet(Sig)
    if sign <= 0:
        # Not a valid covariance (log of a non-positive determinant).
        return np.nan
    quadform = residual.T @ np.linalg.solve(Sig, residual)
    nll = n * np.log(2 * np.pi) + logdet + quadform
    return 0.5 * nll

def MSE(y, mu):
    """Element-wise squared error between ``y`` and ``mu`` (not averaged)."""
    err = y - mu
    return err ** 2

def get_stats(meas, N=None):
    """Sample mean and unbiased (N-1 denominator) sample variance of ``meas``.

    Args:
        meas: sequence of numbers.
        N: number of measurements; defaults to ``len(meas)``. Still accepted
           explicitly for backward compatibility.

    Returns:
        (mean, var). ``var`` is NaN when N < 2, because the unbiased estimator
        is undefined there.

    Raises:
        ZeroDivisionError: if N is 0.
    """
    if N is None:
        N = len(meas)
    mean = sum(meas) / N
    if N < 2:
        return mean, float('nan')
    var = sum((m - mean) ** 2 for m in meas) / (N - 1)
    return mean, var

# Per-prefix-length evaluation over the test set. One entry per j in each list:
#   alpaca -> `agent` (sess1)
#   lpaca  -> `agent_nometa` (sess2)
#   prior  -> `agent_nometa` queried with zero conditioning points
alpaca_nll_mean = []
alpaca_nll_var = []
alpaca_time_mean = []
alpaca_time_var = []

lpaca_nll_mean = []
lpaca_nll_var = []
lpaca_time_mean = []
lpaca_time_var = []

prior_nll_mean = []
prior_nll_var = []

# NOTE(review): re-binds N_test (also set when sampling the test data); it must
# not exceed the number of test trajectories in X_test / Y_test.
N_test = 500

# NOTE(review): T is not defined in this cell. It is whatever an earlier cell
# last set (the three-column figure cell sets T = 60). Index j must stay within
# the sampled horizon `test_horz`.
test_horz1 = T
for j in range(test_horz1):
    nll_list_alpaca = []
    nll_list_lpaca = []
    nll_list_prior = []
    
    time_list_alpaca = []
    time_list_lpaca = []
    
    for ind in range(N_test):
        # Conditioning data: the first j points of trajectory `ind`.
        # (Assumes arrays are (trajectory, time step, dim) -- TODO confirm.)
        X_update = X_test[ind:(ind+1),:j,:]
        Y_update = Y_test[ind:(ind+1),:j,:]
        # Query point: index j, the first point not used for conditioning.
        x_pt = X_test[ind:(ind+1),(j):(j+1),:]
        y_pt = Y_test[ind:(ind+1),(j):(j+1),:]
    
        # Zero-length slices: no conditioning data at all.
        X_empty = X_test[ind:(ind+1),:0,:]
        Y_empty = Y_test[ind:(ind+1),:0,:]

        # time.process_time() is process CPU time (user + system), not wall-clock.
        t1_alpaca = time.process_time()
        y, s = agent.test(sess1, X_update, Y_update, x_pt)
        t2_alpaca = time.process_time()
        
        t1_lpaca = time.process_time()
        y_lpaca,s_lpaca = agent_nometa.test(sess2, X_update, Y_update, x_pt)
        t2_lpaca = time.process_time()
        
        # Prior prediction of the no-meta model (not timed).
        y_prior,s_prior = agent_nometa.test(sess2, X_empty, Y_empty, x_pt)

        time_list_alpaca.append(t2_alpaca - t1_alpaca)
        time_list_lpaca.append(t2_lpaca - t1_lpaca)
            
        # Score each predicted mean/covariance against the true next value.
        nll_list_alpaca.append(gaussian_nll(y_pt[0,0,:],y[0,0,:],s[0,0,:,:]))
        nll_list_lpaca.append(gaussian_nll(y_pt[0,0,:],y_lpaca[0,0,:],s_lpaca[0,0,:,:]))
        nll_list_prior.append(gaussian_nll(y_pt[0,0,:],y_prior[0,0,:],s_prior[0,0,:,:]))
        
    # Mean and unbiased (N-1) variance across the N_test trajectories.
    nll_mean_alpaca, nll_var_alpaca = get_stats(nll_list_alpaca,N_test)
    nll_mean_lpaca, nll_var_lpaca = get_stats(nll_list_lpaca,N_test)
    nll_mean_prior, nll_var_prior = get_stats(nll_list_prior,N_test)
    
    time_mean_alpaca, time_var_alpaca = get_stats(time_list_alpaca,N_test)
    time_mean_lpaca, time_var_lpaca = get_stats(time_list_lpaca,N_test)
    
    alpaca_nll_mean.append(nll_mean_alpaca)
    alpaca_nll_var.append(nll_var_alpaca)
    
    lpaca_nll_mean.append(nll_mean_lpaca)
    lpaca_nll_var.append(nll_var_lpaca)
    
    prior_nll_mean.append(nll_mean_prior)
    prior_nll_var.append(nll_var_prior)
    
    alpaca_time_mean.append(time_mean_alpaca)
    alpaca_time_var.append(time_var_alpaca)
    
    lpaca_time_mean.append(time_mean_lpaca)
    lpaca_time_var.append(time_var_lpaca)

In [None]:
# Plot the per-step NLL statistics, including the no-update (prior) baseline.
plt.figure(figsize=(3.5, 3))
nll_plot(
    alpaca_nll_mean, alpaca_nll_var,
    lpaca_nll_mean, lpaca_nll_var,
    prior_nll_mean, prior_nll_var,
    N_test,
    legend=True,
    last_legend_label=r'ALPaCA (no update)',
)
plt.tight_layout()
plt.savefig('figures/nll_pendulum.pdf')
plt.show()

In [None]:
# Per-step NLL comparison without the prior / no-update baseline.
plt.figure(figsize=(3.5, 3))
nll_plot(alpaca_nll_mean, alpaca_nll_var, lpaca_nll_mean, lpaca_nll_var, None, None, N_test, legend=True)
plt.tight_layout()
# Saved under its own name. Reusing 'figures/nll_pendulum.pdf' would silently
# overwrite the with-prior figure that the previous cell saves there.
plt.savefig('figures/nll_pendulum_no_prior.pdf')
plt.show()

In [None]:
# Per-step prediction time of ALPaCA vs. the no-meta model.
plt.figure(figsize=(3.5, 3))
# No GP timings are computed in this notebook: gp_time_mean / gp_time_var are
# never defined, so referencing them raised NameError. Pass None for the GP
# series, the same way nll_plot is called without its third series.
# NOTE(review): assumes time_plot accepts None for that series like nll_plot
# does -- TODO confirm in main.dataViz.
time_plot(alpaca_time_mean, alpaca_time_var, lpaca_time_mean, lpaca_time_var, None, None, N_test, legend=True)
# tight_layout runs after plotting, so it lays out THIS figure rather than the
# previously active one.
plt.tight_layout(w_pad=0.0, h_pad=0.2)
plt.savefig('figures/time_multistep.pdf')
plt.show()