In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import gym
from main.maml_agent import MAMLAgent, MAMLDynamics, register_flags
register_flags()
from main.alpaca import *
from main.datagen import *
from main.dataViz import *

In [None]:
# Load the experiment configuration that is handed to every agent built below.
cfg_filename = 'configs/swerving-config.yml'
with open(cfg_filename, 'r') as ymlfile:
    # yaml.load() without an explicit Loader is deprecated (PyYAML 5.1+) and a
    # TypeError on PyYAML 6+; safe_load parses plain YAML and refuses to
    # construct arbitrary Python objects.
    config = yaml.safe_load(ymlfile)

In [None]:
# Load the pickled trajectory data; it is indexed by column label ('x1', 'y1', ...)
# in the dataset-building cell below.
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
import pandas as pd
trajs = pd.read_pickle("data/trajectories_slim.pkl")

In [None]:
# Build the trajectory dataset.
#   traj_mat[n, t, :]   -- the 8 state features of trajectory n at step t (t = 0..32)
#   traj_mat_p[n, t, :] -- the same features one step later (step t + 1)
# so traj_mat_p - traj_mat is the per-step state change used as regression target.
traj_mat = np.zeros((1106,33,8))
traj_mat_p = np.zeros((1106,33,8))
for i,label in enumerate(['x1', 'y1', 'xd1', 'yd1', 'x2', 'y2','xd2', 'yd2']):
    itr = -1
    # Keys are indexable; k[1] is used as the within-trajectory step index
    # (presumably a (trajectory, step) MultiIndex -- confirm against the pickle).
    for k,v in trajs[label].items():
        # A step index of 0 marks the start of the next trajectory.
        if k[1] == 0:
            itr += 1
        if k[1] < 33:
            traj_mat[itr,k[1],i] = v
        # Steps 1..33 are the "next state" of steps 0..32.
        if (not (k[1] == 0)) and k[1]<34:
            traj_mat_p[itr,k[1]-1,i] = v

# Shuffle the trajectories. A permutation is required here: the previous
# np.random.choice(n, n) sampled WITH replacement, which duplicates some
# trajectories, drops others, and lets the same trajectory land in both the
# training and the test split.
inds = np.random.permutation(traj_mat.shape[0])
traj_mat = traj_mat[inds,:,:]
traj_mat_p = traj_mat_p[inds,:,:]

# Train / test split, driven by the declared sizes instead of repeated literals.
N_train = 1000
N_test = 100
X = traj_mat[:N_train,:,:]
Y = traj_mat_p[:N_train,:,:] - X

X_test = traj_mat[N_train:N_train+N_test,:,:]
Y_test = traj_mat_p[N_train:N_train+N_test,:,:] - X_test

In [None]:
# Wrap the training split in a data generator; it is passed to maml_model.train below.
datagen = DataGenFromData(X,Y)

## Default agent

In [None]:
# The default agent gets its own graph and session so it stays isolated from
# the other models built in this notebook.
g1 = tf.Graph()
sess1 = tf.Session(
    graph=g1,
    config=tf.ConfigProto(log_device_placement=True),
)
agent = AdaptiveDynamics(config)

In [None]:
# Build the default agent's model inside its dedicated session/graph pair.
agent.construct_model(sess1,g1)

In [None]:
# Train the default agent on the training split; 4000 is presumably the number
# of training iterations -- confirm against AdaptiveDynamics.train.
# NOTE(review): Y is passed before X here, while agent.test(...) in the
# evaluation cell takes X first -- confirm this matches train's signature.
agent.train(sess1,Y,X,4000)

## No Meta Agent

In [None]:
# Second, independent graph/session pair for the "no meta" baseline.
g2 = tf.Graph()
sess2 = tf.Session(
    graph=g2,
    config=tf.ConfigProto(log_device_placement=True),
)
agent_nometa = AdaptiveDynamics(config)
# data_horizon is zeroed before the model is constructed.
# NOTE(review): if AdaptiveDynamics keeps a reference to `config` rather than a
# copy, this also changes the shared `config` used by later cells -- confirm.
agent_nometa.config['data_horizon'] = 0
agent_nometa.construct_model(sess2, g2)

In [None]:
# Train the no-meta baseline on the same training split; 2000 is presumably the
# number of training iterations -- confirm against AdaptiveDynamics.train.
agent_nometa.train(sess2, Y,X, 2000)

## MAML Agent

In [None]:
# Third graph/session pair, this one hosting the MAML dynamics model.
g3 = tf.Graph()
sess3 = tf.Session(
    graph=g3,
    config=tf.ConfigProto(log_device_placement=True),
)
maml_model = MAMLDynamics(config, exp_string="swerve_test")
maml_model.construct_model(sess3, g3)

In [None]:
# Train the MAML model from the data generator; 25000 is presumably the number
# of training iterations -- confirm against MAMLDynamics.train.
maml_model.train(sess3, datagen, 25000)

In [None]:
def plot_swerve(agent,sess,X,Y,N_samples=30,T=33,T_rollout=12, ind=0):
    """Plot sampled rollouts of both cars for trajectory `ind` on the current axes.

    For each context length Nu in (0, 10, 20) the last-layer posterior is
    updated with the first Nu transitions of trajectory `ind`, `N_samples`
    weight matrices are drawn from it, and each one is rolled forward from
    step Nu. Columns 0/1 and 4/5 of the state are plotted as the (x, y)
    position of car 1 and car 2 respectively.

    Args:
        agent: model exposing the tensors `K`, `L`, `SigEps`, `phi`, `x` and
            the method `batch_update_np`.
        sess: session used to evaluate the agent's tensors.
        X: states, indexed as [trajectory, step, feature].
        Y: per-step state changes, indexed like X.
        N_samples: number of sampled rollouts per context length.
        T: horizon; the prediction buffer holds T+1 steps.
        T_rollout: number of steps to roll out (and plot) after step Nu.
        ind: index of the trajectory to plot.
    """
    # Full ground-truth path of each car (dotted) as a backdrop.
    plt.plot(X[ind,:,0],X[ind,:,1],color='k',linestyle=':', alpha=0.5)
    plt.plot(X[ind,:,4],X[ind,:,5],color='k', linestyle=':', alpha=0.5)

    # Prior parameters and the features of the selected trajectory do not
    # depend on Nu, so evaluate them once. Only trajectory `ind` is fed through
    # the feature network; phi is evaluated point-by-point in the rollout below,
    # so slicing before or after the forward pass is equivalent.
    K0 = sess.run(agent.K)
    L0 = sess.run(agent.L)
    SigEps = sess.run(agent.SigEps)
    Phi = sess.run( agent.phi, {agent.x: X[ind:ind+1,:,:]} )

    colors = ['C0','C1']
    for Nu in [0,10,20]:
        if Nu > 0:
            # Condition on the first Nu observed transitions.
            Kn,Ln_inv = agent.batch_update_np(K0,L0,Phi[0,:Nu,:],Y[ind,:Nu,:])
        else:
            # No context data: sample from the prior.
            Kn = K0
            Ln_inv = np.linalg.inv(L0)

        x_pred = np.zeros([N_samples, T+1, X.shape[2]])
        x_pred[:,:Nu+1,:] = X[ind:ind+1, :Nu+1, :]

        # x_pred only has T+1 steps; never roll out past the end of the buffer.
        t_end = min(Nu+T_rollout, T)
        for j in range(N_samples):
            K = sampleMN(Kn,Ln_inv,SigEps)
            for t in range(Nu,t_end):
                phi_t = sess.run( agent.phi, {agent.x: x_pred[j:j+1, t:t+1, :]})
                # The model predicts the state change, so integrate it.
                x_pred[j,t+1,:] = x_pred[j,t,:] + np.squeeze( phi_t[0,:,:] @ K )

        # Sampled rollouts (faint, coloured per car) over the true segment (solid black).
        for j in range(N_samples):
            plt.plot(x_pred[j,Nu:Nu+T_rollout,0], x_pred[j,Nu:Nu+T_rollout,1], color=colors[0], alpha=5.0/N_samples)
            plt.plot(x_pred[j,Nu:Nu+T_rollout,4], x_pred[j,Nu:Nu+T_rollout,5], color=colors[1], alpha=5.0/N_samples)
        plt.plot(X[ind,Nu:Nu+T_rollout,0],X[ind,Nu:Nu+T_rollout,1],color='k',alpha=0.5)
        plt.plot(X[ind,Nu:Nu+T_rollout,4],X[ind,Nu:Nu+T_rollout,5],color='k',alpha=0.5)

def plot_swerve_maml(agent,sess,X,Y,N_samples=30,T=33,T_rollout=12,ind=0):
    """Plot one MAML rollout of both cars for trajectory `ind` on the current axes.

    For each context length Nu in (0, 10, 20) the agent is reset to its prior,
    adapted on the first Nu transitions of trajectory `ind`, and rolled out
    from step Nu. Columns 0/1 and 4/5 of the state are plotted as the (x, y)
    position of car 1 and car 2 respectively.

    Args:
        agent: model exposing `reset_to_prior`, `incorporate_transition` and
            `sample_rollout`.
        sess: session passed through to the agent's methods.
        X: inputs indexed as [trajectory, step, feature]; the first
            Y.shape[2] features are the state, any remaining ones the action.
        Y: per-step state changes, indexed like X.
        N_samples: unused; kept so the signature mirrors `plot_swerve`.
        T: horizon; the prediction buffer holds T+1 steps.
        T_rollout: number of steps plotted after step Nu.
        ind: index of the trajectory to plot.
    """
    # Full ground-truth path of each car (dotted) as a backdrop.
    plt.plot(X[ind,:,0],X[ind,:,1],color='k',linestyle=':', alpha=0.5)
    plt.plot(X[ind,:,4],X[ind,:,5],color='k', linestyle=':', alpha=0.5)

    x_dim = Y.shape[2]
    # Context and actions must come from the trajectory being plotted. They
    # were previously read from trajectory 0 regardless of `ind`, so the model
    # was adapted on a different trajectory than the one it was compared with.
    actions = X[ind,:,x_dim:]
    colors = ['C0','C1']
    for Nu in [0,10,20]:
        agent.reset_to_prior()
        for t in range(0,Nu):
            x = X[ind,t,:x_dim]
            u = X[ind,t,x_dim:]
            # Y stores state changes, so the next state is x + Y.
            xp = x + Y[ind,t,:]

            agent.incorporate_transition(sess, x,u,xp)

        x_pred = np.zeros([1, T+1, x_dim])
        x_pred[:,:Nu+1,:] = X[ind, :Nu+1, :x_dim]
        x_pred[:,Nu+1:,:] = agent.sample_rollout(sess, x_pred[:,Nu,:], actions[Nu:,:])

        # Predicted segment (coloured per car) over the true segment (solid black).
        plt.plot(x_pred[0,Nu:Nu+T_rollout,0], x_pred[0,Nu:Nu+T_rollout,1], color=colors[0], alpha=0.8)
        plt.plot(x_pred[0,Nu:Nu+T_rollout,4], x_pred[0,Nu:Nu+T_rollout,5], color=colors[1], alpha=0.8)
        plt.plot(X[ind,Nu:Nu+T_rollout,0],X[ind,Nu:Nu+T_rollout,1],color='k',alpha=0.5)
        plt.plot(X[ind,Nu:Nu+T_rollout,4],X[ind,Nu:Nu+T_rollout,5],color='k',alpha=0.5)

In [None]:
# Side-by-side rollout comparison: one row per test trajectory,
# ALPaCA in the left column and MAML in the right column.
N_examples = 4
plt.figure(figsize=(5.5,9))
# Draw distinct test trajectories: sampling without replacement avoids showing
# the same example twice, and the population size follows the test set.
example_inds = np.random.choice(X_test.shape[0], N_examples, replace=False)
# NOTE(review): the subplot grid has N_examples*2 rows but only the first
# N_examples rows are ever filled -- confirm the empty lower half is intended.
for i, ind in enumerate(example_inds):
    ax1 = plt.subplot(N_examples*2, 2, 2*i + 1)
    plot_swerve(agent,sess1,X_test,Y_test,T_rollout=10,ind=ind)
    plt.ylabel('Lane Position')
    if i == 0:
        plt.title('ALPaCA')
    if i < N_examples - 1:
        # Only the bottom row keeps its x tick labels.
        plt.setp(ax1.get_xticklabels(), visible=False)
    if i == N_examples - 1:
        plt.xlabel('Longitudinal Position')

    # Right panel shares the y axis with the left one, so its y tick labels are hidden.
    ax3 = plt.subplot(N_examples*2, 2, 2*i + 2, sharey=ax1)
    plot_swerve_maml(maml_model,sess3,X_test,Y_test,T_rollout=10,ind=ind)
    if i == 0:
        plt.title('MAML')
    if i < N_examples - 1:
        plt.setp(ax3.get_xticklabels(), visible=False)
    plt.setp(ax3.get_yticklabels(), visible=False)
    if i == N_examples - 1:
        plt.xlabel('Longitudinal Position')

plt.tight_layout(w_pad=0.0,h_pad=-0.5)
plt.savefig('figures/swerving_rollouts.pdf')
plt.show()

In [None]:
# NLL computation
import time

def gaussian_nll(y,mu,Sig):
    """Time-averaged Gaussian negative log-likelihood of one trajectory.

    Only batch element 0 of every argument is used. The returned value is

        0.5 * ( n*log(2*pi) + mean_t[ log|Sig_t| + r_t^T Sig_t^{-1} r_t ] )

    with r_t = y_t - mu_t.

    Args:
        y: observations, indexed as [batch, T, n].
        mu: predicted means, indexed like y.
        Sig: predicted covariances, indexed as [batch, T, n, n].

    Returns:
        The scalar NLL; NaN if any covariance has a non-positive determinant.
    """
    _,T,n = y.shape
    resid = np.asarray(y)[0,:T,:] - np.asarray(mu)[0,:T,:]
    cov = np.asarray(Sig)[0,:T,:,:]

    # slogdet works in log space, so tiny or huge covariances do not
    # under/overflow the way log(det(.)) does.
    sign, logdet = np.linalg.slogdet(cov)
    logdet = np.where(sign > 0, logdet, np.nan)

    # Solve Sig_t z = r_t instead of forming the inverse explicitly.
    solved = np.linalg.solve(cov, resid[..., None])[..., 0]
    quadform = np.sum(resid * solved, axis=-1)

    nll = n*np.log(2*np.pi) + (np.sum(logdet) + np.sum(quadform))/T

    return 0.5*nll

def MSE(y,mu):
    """Squared error summed over the last axis, then averaged over all remaining axes."""
    residual = y - mu
    per_point = np.sum(np.square(residual), axis=-1)
    return np.mean(per_point)

def get_stats(meas, N):
    """Return (mean, variance) of `meas`, treating `N` as the sample count.

    The mean divides the sum by N; the variance divides the summed squared
    deviations by N-1.
    """
    avg = sum(meas) / N
    squared_devs = ((m - avg) ** 2 for m in meas)
    spread = sum(squared_devs) / (N - 1)
    return avg, spread

# Summary statistics, one entry appended per context length by the evaluation loop.

# ALPaCA: negative log-likelihood, timing, mean squared error.
alpaca_nll_mean, alpaca_nll_var = [], []
alpaca_time_mean, alpaca_time_var = [], []
alpaca_mse_mean, alpaca_mse_var = [], []

# MAML (default number of updates, and 5 updates).
maml_time_mean, maml_time_var = [], []
maml_mse_mean, maml_mse_var = [], []
maml5_mse_mean, maml5_mse_var = [], []

# No-meta agent.
lpaca_nll_mean, lpaca_nll_var = [], []
lpaca_time_mean, lpaca_time_var = [], []

# No-context baseline.
prior_nll_mean, prior_nll_var = [], []

# Sweep the number of context transitions j and, for each j, evaluate every
# model on all test trajectories; then append per-j mean/variance summaries.
N_test = 100
data_horz = 30
for j in range(0,data_horz):
    # Per-trajectory measurements for this context length.
    nll_list_alpaca = []
    nll_list_lpaca = []
    nll_list_prior = []

    mse_list_alpaca = []
    mse_list_maml = []
    mse_list_maml5 = []

    time_list_alpaca = []
    time_list_maml = []
    time_list_lpaca = []

    for ind in range(N_test):
        # Context: the first j transitions of this trajectory.
        X_update = X_test[ind:(ind+1),:j,:]
        Y_update = Y_test[ind:(ind+1),:j,:]

        # Zero-length context for the no-update baseline.
        X_empty = X_test[ind:(ind+1),:0,:]
        Y_empty = Y_test[ind:(ind+1),:0,:]

        # Query points / targets: the whole trajectory.
        x_pt = X_test[ind:(ind+1),:,:]
        y_pt = Y_test[ind:(ind+1),:,:]

        t1_alpaca = time.process_time()
        y, s = agent.test(sess1, X_update, Y_update, x_pt)
        t2_alpaca = time.process_time()

        # NOTE(review): this baseline is evaluated with agent_nometa (not agent),
        # yet it is plotted under the label 'ALPaCA (no update)' -- confirm.
        y_prior,s_prior = agent_nometa.test(sess2, X_empty, Y_empty, x_pt)

        t1_maml = time.process_time()
        y_maml, _ = maml_model.test(sess3, X_update, Y_update, x_pt)
        t2_maml = time.process_time()
        y_maml5, _ = maml_model.test(sess3, X_update, Y_update, x_pt, num_updates=5)

        t1_lpaca = time.process_time()
        y_lpaca,s_lpaca = agent_nometa.test(sess2, X_update, Y_update, x_pt)
        t2_lpaca = time.process_time()

        time_list_alpaca.append(t2_alpaca - t1_alpaca)
        time_list_maml.append(t2_maml - t1_maml)
        time_list_lpaca.append(t2_lpaca - t1_lpaca)

        nll_list_alpaca.append(gaussian_nll(y_pt,y,s))
        nll_list_lpaca.append(gaussian_nll(y_pt,y_lpaca,s_lpaca))
        nll_list_prior.append(gaussian_nll(y_pt,y_prior,s_prior))

        mse_list_alpaca.append(MSE(y_pt,y))
        mse_list_maml.append(MSE(y_pt, y_maml))
        mse_list_maml5.append(MSE(y_pt, y_maml5))

    # Reduce the per-trajectory lists to (mean, variance) for this j.
    nll_mean_alpaca, nll_var_alpaca = get_stats(nll_list_alpaca,N_test)
    nll_mean_lpaca, nll_var_lpaca = get_stats(nll_list_lpaca,N_test)
    nll_mean_prior, nll_var_prior = get_stats(nll_list_prior,N_test)

    time_mean_alpaca, time_var_alpaca = get_stats(time_list_alpaca,N_test)
    time_mean_lpaca, time_var_lpaca = get_stats(time_list_lpaca,N_test)
    # MAML timings were collected but never summarised, which left
    # maml_time_mean / maml_time_var empty; reduce them like the others.
    time_mean_maml, time_var_maml = get_stats(time_list_maml,N_test)

    mse_mean_alpaca, mse_var_alpaca = get_stats(mse_list_alpaca,N_test)
    mse_mean_maml, mse_var_maml = get_stats(mse_list_maml,N_test)
    mse_mean_maml5, mse_var_maml5 = get_stats(mse_list_maml5,N_test)

    alpaca_nll_mean.append(nll_mean_alpaca)
    alpaca_nll_var.append(nll_var_alpaca)

    lpaca_nll_mean.append(nll_mean_lpaca)
    lpaca_nll_var.append(nll_var_lpaca)

    prior_nll_mean.append(nll_mean_prior)
    prior_nll_var.append(nll_var_prior)

    alpaca_time_mean.append(time_mean_alpaca)
    alpaca_time_var.append(time_var_alpaca)

    lpaca_time_mean.append(time_mean_lpaca)
    lpaca_time_var.append(time_var_lpaca)

    maml_time_mean.append(time_mean_maml)
    maml_time_var.append(time_var_maml)

    alpaca_mse_mean.append(mse_mean_alpaca)
    alpaca_mse_var.append(mse_var_alpaca)

    maml_mse_mean.append(mse_mean_maml)
    maml_mse_var.append(mse_var_maml)

    maml5_mse_mean.append(mse_mean_maml5)
    maml5_mse_var.append(mse_var_maml5)

In [None]:
# MSE versus context length: ALPaCA against MAML and MAML with 5 updates.
plt.figure(figsize=(3.5, 3))
mse_plot(
    alpaca_mse_mean, alpaca_mse_var,
    maml_mse_mean, maml_mse_var,
    maml5_mse_mean, maml5_mse_var,
    N_test,
    legend=True,
)
plt.tight_layout()
plt.savefig('figures/mse_swerving.pdf')
plt.show()

In [None]:
# NLL versus context length: ALPaCA, the no-meta agent, and the no-context baseline.
plt.figure(figsize=(3.5, 3))
nll_plot(
    alpaca_nll_mean, alpaca_nll_var,
    lpaca_nll_mean, lpaca_nll_var,
    prior_nll_mean, prior_nll_var,
    N_test,
    legend=True,
    last_legend_label=r'ALPaCA (no update)',
)
plt.tight_layout()
plt.savefig('figures/nll_swerving.pdf')
plt.show()