In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import gym
from main.maml_agent import MAMLAgent, register_flags
register_flags()
from main.alpaca import *
from main.datagen import *
from main.dataViz import *

In [None]:
# Load the experiment configuration for the sinusoid task.
# yaml.safe_load replaces the bare yaml.load(stream): calling yaml.load
# without an explicit Loader is deprecated since PyYAML 5.1 and raises a
# TypeError in PyYAML 6. The safe loader also refuses to build arbitrary
# Python objects from the file.
cfg_filename = 'configs/sinusoid-config.yml'
with open(cfg_filename, 'r') as ymlfile:
    config = yaml.safe_load(ymlfile)

In [None]:
# Problem sizes: how many trajectories to sample for training and for
# evaluation, and the horizon passed to the sampler for each of them.
N_train = 5000
N_test = 500
test_horz = 30

DG = DataGenerator(config, None, 'Sinusoid')

# With return_lists=True the sampler returns five values; besides the (Y, X)
# samples these are unpacked as per-trajectory phase / frequency / amplitude
# lists (presumably the parameters of each generated sinusoid).
(Y, X,
 phase_list, freq_list, amp_list) = DG.sample_trajectories(
    None, test_horz, N_train, return_lists=True)

(Y_test, X_test,
 phase_list_test, freq_list_test, amp_list_test) = DG.sample_trajectories(
    None, test_horz, N_test, return_lists=True)

## Default agent

In [None]:
# Dedicated graph + session for the default ALPaCA agent. This one is an
# InteractiveSession, unlike the plain tf.Session used for the other agents.
g1 = tf.Graph()
sess1 = tf.InteractiveSession(
    graph=g1,
    config=tf.ConfigProto(log_device_placement=True),
)
agent = ALPaCA(config)

In [None]:
# Set up the default agent's model using session sess1 and graph g1
# (presumably this is where the TF ops/variables are created — see
# ALPaCA.construct_model to confirm).
agent.construct_model(sess1,g1)

In [None]:
# Train the default agent on the sampled training trajectories (Y, X).
# The last argument (3000) is presumably the number of training iterations —
# TODO confirm against ALPaCA.train.
agent.train(sess1,Y,X,3000)

## Agent without meta-training

In [None]:
# Ablation agent: same ALPaCA class and training data as the default agent,
# but with config['data_horizon'] forced to 0 before the model is built
# (presumably so that training sees no context data — TODO confirm).
g2 = tf.Graph()
sess2 = tf.Session(
    graph=g2,
    config=tf.ConfigProto(log_device_placement=True),
)
agent_nometa = ALPaCA(config)
# NOTE(review): if ALPaCA stores a reference to `config` rather than a copy,
# this write also changes the dict shared with the other agents — verify.
agent_nometa.config['data_horizon'] = 0
agent_nometa.construct_model(sess2, g2)
agent_nometa.train(sess2, Y, X, 3000)

## GP regression agent

In [None]:
# NOTE(review): this is imported as a top-level `gp_reg` module, while the
# other project modules in this notebook come from the `main` package —
# confirm the import path is intentional.
from gp_reg import *

# GP regression baseline, constructed with its default arguments.
GPR_agent = GPReg()

## MAML Agent

In [None]:
# Third graph/session pair, reserved for the MAML baseline.
g3 = tf.Graph()
sess3 = tf.Session(
    graph=g3,
    config=tf.ConfigProto(log_device_placement=True),
)
maml_agent = MAMLAgent(config, exp_string="sinusoid_test5")

In [None]:
# Set up the MAML agent's model using session sess3 and graph g3.
# NOTE(review): no train call for maml_agent appears in this notebook;
# presumably construct_model restores weights for the given exp_string —
# confirm before trusting the MAML numbers below.
maml_agent.construct_model(sess3, g3)

# Visualize all agents

In [None]:
# Qualitative comparison on one test trajectory: one row per context-set
# size, one column per model (ALPaCA, ALPaCA without meta-training, GP).
ind = 12                          # index of the test trajectory to plot
sample_size_list = [0,1,2,3,5]    # number of context points used in each row
n_rows = len(sample_size_list)
n_cols = 3

# Per-trajectory values returned by sample_trajectories for trajectory `ind`
# (presumably the parameters of the underlying sinusoid); forwarded unchanged
# to the plotting helpers.
traj_freq = freq_list_test[ind]
traj_phase = phase_list_test[ind]
traj_amp = amp_list_test[ind]

plt.figure(figsize=(9,n_rows*1))
for i,num_pts in enumerate(sample_size_list):
    # Context set: the first num_pts points of the chosen trajectory.
    X_update = X_test[ind:(ind+1),:num_pts,:]
    Y_update = Y_test[ind:(ind+1),:num_pts,:]

    # (column title, zero-argument drawing call) for each panel of this row.
    # The callables are invoked immediately below, inside this iteration.
    panels = [
        ('ALPaCA',
         lambda: gen_sin_fig(agent, sess1, X_update, Y_update,
                             traj_freq, traj_phase, traj_amp, label=None)),
        ('ALPaCA (no meta)',
         lambda: gen_sin_fig(agent_nometa, sess2, X_update, Y_update,
                             traj_freq, traj_phase, traj_amp, label=None)),
        ('GPR',
         lambda: gen_sin_gp_fig(GPR_agent, X_update, Y_update,
                                traj_freq, traj_phase, traj_amp, label=None)),
    ]

    row_ax = None
    for col,(panel_title,draw_panel) in enumerate(panels):
        subplot_index = n_cols*i + col + 1
        if col == 0:
            # The left-most panel owns the y axis for the whole row.
            ax = plt.subplot(n_rows,n_cols,subplot_index)
            row_ax = ax
        else:
            ax = plt.subplot(n_rows,n_cols,subplot_index,sharey=row_ax)

        draw_panel()

        # Hide tick labels only after drawing, so the plotting helper cannot
        # re-create them afterwards.
        if col > 0:
            plt.setp(ax.get_yticklabels(), visible=False)
        if i == 0:
            plt.title(panel_title)
        if i < n_rows - 1:
            plt.setp(ax.get_xticklabels(), visible=False)

plt.tight_layout(w_pad=0.0,h_pad=0.2)
plt.savefig('figures/sinusoid_three.pdf')
plt.show()

In [None]:
# compute NLL for all three models
def gaussian_nll(y,mu,Sig):
    """Negative log-likelihood of a scalar observation under a Gaussian.

    Computes ``0.5 * (log(2*pi) + log(Sig) + (y - mu) * (1/Sig) * (y - mu))``.

    Args:
        y: observed value (Python or NumPy scalar).
        mu: predicted mean.
        Sig: predicted variance; it enters as ``log(Sig)`` and ``1/Sig``, so
            it must be positive for the result to be finite.

    Returns:
        The negative log-likelihood as a NumPy float.
    """
    n = 1  # dimensionality of the observation
    # np.asarray lets plain Python numbers through: `.T` exists only on NumPy
    # types, so taking it directly on `y - mu` failed for built-in floats.
    resid = np.asarray(y) - np.asarray(mu)
    logdet = np.log(Sig)
    quadform = resid.T * (1/Sig) * resid
    nll = n*np.log(2*np.pi) + logdet + quadform
    return 0.5*nll

# Per-horizon NLL statistics for each model. Entry j is computed with the
# models conditioned on the first j points of each test trajectory and
# evaluated on point j.
alpaca_nll_mean, alpaca_nll_var = [], []
lpaca_nll_mean, lpaca_nll_var = [], []
gp_nll_mean, gp_nll_var = [], []


def _nll_mean_var(nll_values, count):
    """Return (mean, variance with a count-1 denominator) of nll_values."""
    mean = sum(nll_values)/count
    var = sum([(nl - mean)**2 for nl in nll_values])/(count-1)
    return mean, var


for j in range(test_horz):
    nll_list_alpaca, nll_list_lpaca, nll_list_gp = [], [], []

    for ind in range(N_test):
        traj = slice(ind, ind+1)      # keeps the leading axis, like ind:(ind+1)
        X_update = X_test[traj, :j, :]
        Y_update = Y_test[traj, :j, :]
        x_pt = X_test[traj, j:(j+1), :]
        y_pt = Y_test[traj, j:(j+1), :]

        y, s = agent.test(sess1, X_update, Y_update, x_pt)
        y_lpaca, s_lpaca = agent_nometa.test(sess2, X_update, Y_update, x_pt)
        y_gp, s_gp = GPR_agent.test(X_update, Y_update, x_pt)

        target = y_pt[0,0,0]
        nll_list_alpaca.append(gaussian_nll(target, y[0,0,0], s[0,0,0,0]))
        nll_list_lpaca.append(gaussian_nll(target, y_lpaca[0,0,0], s_lpaca[0,0,0,0]))
        # The GP agent's second return value is indexed with two axes here,
        # whereas the ALPaCA agents' is indexed with four.
        nll_list_gp.append(gaussian_nll(target, y_gp[0,0,0], s_gp[0,0]))

    # Reduce this horizon's samples and store them in the matching series.
    for samples, mean_series, var_series in (
        (nll_list_alpaca, alpaca_nll_mean, alpaca_nll_var),
        (nll_list_lpaca, lpaca_nll_mean, lpaca_nll_var),
        (nll_list_gp, gp_nll_mean, gp_nll_var),
    ):
        mean_j, var_j = _nll_mean_var(samples, N_test)
        mean_series.append(mean_j)
        var_series.append(var_j)
    


In [None]:
#compute MSE and time
import time

# NOTE: gaussian_nll is also defined in the NLL cell earlier in this
# notebook; after Run All this later definition is the one in effect, so the
# two must be kept in sync.
def gaussian_nll(y,mu,Sig):
    """Negative log-likelihood of a scalar observation under a Gaussian.

    Computes ``0.5 * (log(2*pi) + log(Sig) + (y - mu) * (1/Sig) * (y - mu))``.

    Args:
        y: observed value (Python or NumPy scalar).
        mu: predicted mean.
        Sig: predicted variance; it enters as ``log(Sig)`` and ``1/Sig``, so
            it must be positive for the result to be finite.

    Returns:
        The negative log-likelihood as a NumPy float.
    """
    n = 1  # dimensionality of the observation
    # np.asarray lets plain Python numbers through: `.T` exists only on NumPy
    # types, so taking it directly on `y - mu` failed for built-in floats.
    resid = np.asarray(y) - np.asarray(mu)
    logdet = np.log(Sig)
    quadform = resid.T * (1/Sig) * resid
    nll = n*np.log(2*np.pi) + logdet + quadform
    return 0.5*nll

def MSE(y,mu):
    """Squared error between a target ``y`` and a prediction ``mu``.

    No averaging happens here: the function returns the squared difference
    of its two arguments, and callers average the results themselves.
    """
    residual = y - mu
    return residual**2

def get_stats(meas, N=None):
    """Mean and sample variance of a sequence of measurements.

    Args:
        meas: sequence of numeric measurements.
        N: number of samples used as the denominator. Defaults to
            ``len(meas)``; an explicit value is still accepted so existing
            ``get_stats(meas, N)`` calls keep working.

    Returns:
        ``(mean, var)`` where ``var`` uses the ``N - 1`` denominator. With
        fewer than two samples the variance is undefined and NaN is returned
        instead of dividing by zero.
    """
    if N is None:
        N = len(meas)
    mean = sum(meas)/N
    if N < 2:
        return mean, float('nan')
    var = sum([(m - mean)**2 for m in meas])/(N-1)
    return mean, var

# Per-horizon accumulators filled by the loop below.
#
# The *_nll_* series are deliberately NOT re-created here: they are produced
# by the NLL cell earlier in the notebook and consumed by nll_plot later on.
# Resetting them to empty lists in this cell discarded those results.
alpaca_time_mean = []
alpaca_time_var = []
alpaca_mse_mean = []
alpaca_mse_var = []

maml_time_mean = []
maml_time_var = []
maml_mse_mean = []
maml_mse_var = []
maml5_mse_mean = []
maml5_mse_var = []

lpaca_time_mean = []
lpaca_time_var = []

gp_time_mean = []
gp_time_var = []

# N_test is not re-assigned here: it is defined next to X_test / Y_test, and
# a second hard-coded value could silently drift away from the data size.

for j in range(test_horz):
    mse_list_alpaca = []
    mse_list_maml = []
    mse_list_maml5 = []

    time_list_alpaca = []
    time_list_maml = []
    time_list_lpaca = []
    time_list_gp = []

    for ind in range(N_test):
        # Context: first j points of trajectory `ind`; query: point j.
        X_update = X_test[ind:(ind+1),:j,:]
        Y_update = Y_test[ind:(ind+1),:j,:]
        x_pt = X_test[ind:(ind+1),(j):(j+1),:]
        y_pt = Y_test[ind:(ind+1),(j):(j+1),:]

        # time.process_time() reports CPU time of this process, not wall time.
        t1_alpaca = time.process_time()
        y, s = agent.test(sess1, X_update, Y_update, x_pt)
        t2_alpaca = time.process_time()

        # The no-meta ALPaCA agent and the GP agent are timed as well:
        # time_plot is called with lpaca_time_* and gp_time_*, which were
        # previously left empty.
        t1_lpaca = time.process_time()
        agent_nometa.test(sess2, X_update, Y_update, x_pt)
        t2_lpaca = time.process_time()

        t1_gp = time.process_time()
        GPR_agent.test(X_update, Y_update, x_pt)
        t2_gp = time.process_time()

        t1_maml = time.process_time()
        y_maml, s_maml = maml_agent.test(sess3, X_update, Y_update, x_pt)
        t2_maml = time.process_time()
        # Second MAML evaluation with num_updates=5 (not timed).
        y_maml5, s_maml5 = maml_agent.test(sess3, X_update, Y_update, x_pt, num_updates=5)

        time_list_alpaca.append(t2_alpaca - t1_alpaca)
        time_list_lpaca.append(t2_lpaca - t1_lpaca)
        time_list_gp.append(t2_gp - t1_gp)
        time_list_maml.append(t2_maml - t1_maml)

        mse_list_alpaca.append(MSE(y_pt[0,0,0],y[0,0,0]))
        mse_list_maml.append(MSE(y_pt[0,0,0], y_maml[0,0,0]))
        mse_list_maml5.append(MSE(y_pt[0,0,0], y_maml5[0,0,0]))

    # Reduce this horizon's samples and store them in the matching series.
    for samples, mean_series, var_series in (
        (time_list_alpaca, alpaca_time_mean, alpaca_time_var),
        (time_list_lpaca, lpaca_time_mean, lpaca_time_var),
        (time_list_gp, gp_time_mean, gp_time_var),
        (time_list_maml, maml_time_mean, maml_time_var),
        (mse_list_alpaca, alpaca_mse_mean, alpaca_mse_var),
        (mse_list_maml, maml_mse_mean, maml_mse_var),
        (mse_list_maml5, maml5_mse_mean, maml5_mse_var),
    ):
        mean_j, var_j = get_stats(samples, N_test)
        mean_series.append(mean_j)
        var_series.append(var_j)

In [None]:
# NLL figure: per-horizon mean/variance series for ALPaCA, ALPaCA without
# meta-training and the GP baseline, saved as a PDF.
plt.figure(figsize=(3.5, 3))
nll_plot(
    alpaca_nll_mean, alpaca_nll_var,
    lpaca_nll_mean, lpaca_nll_var,
    gp_nll_mean, gp_nll_var,
    N_test,
    legend=True,
)
plt.savefig('figures/nll_sinusoid.pdf')
plt.show()

In [None]:
# MSE figure: per-horizon mean/variance series for ALPaCA, MAML with its
# default number of updates, and MAML with num_updates=5, saved as a PDF.
plt.figure(figsize=(3.5, 3))
mse_plot(
    alpaca_mse_mean, alpaca_mse_var,
    maml_mse_mean, maml_mse_var,
    maml5_mse_mean, maml5_mse_var,
    N_test,
    legend=True,
)
plt.tight_layout()
plt.savefig('figures/mse_sinusoid.pdf')
plt.show()

In [None]:
# Timing figure: per-horizon mean/variance of the measured test-call times
# for ALPaCA, ALPaCA without meta-training and the GP baseline, saved as a PDF.
plt.figure(figsize=(3.5, 3))
time_plot(
    alpaca_time_mean, alpaca_time_var,
    lpaca_time_mean, lpaca_time_var,
    gp_time_mean, gp_time_var,
    N_test,
    legend=True,
)
plt.savefig('figures/time_sinusoid.pdf')
plt.show()