In [1]:
import time
import gym
import numpy as np
import control as ct
import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow.keras.regularizers as reg
from tensorflow_probability import distributions as tfd
import matplotlib.pyplot as plt
import os
import pickle as pkl
tf.enable_eager_execution()


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [2]:
def split_forward_filter_fn(A,B,u,g,C,sigma,l_a_posteriori,P_a_posteriori,z):
    '''Run one Kalman-filter predict/update step for a single scalar measurement.

    Predict:
        l_prior = A l + B u
        P_prior = A P A^T + g g^T
    Update (innovation z - C l_prior, innovation variance S = sigma^2 + C P_prior C^T):
        K      = P_prior C^T S^-1
        l_post = l_prior + K (z - C l_prior)
        P_post = (I - K C) P_prior (I - K C)^T + K sigma sigma^T K^T   (Joseph form)

    Args:
        A: square transition matrix; its row count sizes the identity used in I - K C.
        B, u: control matrix and control input; B @ u must match A @ l in shape.
        g: process-noise factor; g g^T is added to the prior covariance.
        C: observation row (presumably shape (1, n) — the callers use one-hot rows).
        sigma: measurement-noise scale, used as sigma^2 and sigma sigma^T.
        l_a_posteriori, P_a_posteriori: previous posterior mean and covariance.
        z: the scalar measurement; reshaped to (1, 1).

    Returns:
        (l_post, P_post, z reshaped to (1, 1), pred = C l_post,
         squared_error = squeezed (z - pred)^T (z - pred)).
    '''
    identity = tf.eye(int(A.shape[0]), dtype = tf.float64)
    z = tf.reshape(z, [1, 1])

    # Predict step.
    l_prior = tf.matmul(A, l_a_posteriori) + tf.matmul(B, u)
    P_prior = (tf.matmul(tf.matmul(A, P_a_posteriori), A, transpose_b = True)
               + tf.matmul(g, g, transpose_b = True))

    # Update step.
    innovation = z - tf.matmul(C, l_prior)
    innovation_var = tf.square(sigma) + tf.matmul(tf.matmul(C, P_prior), C, transpose_b = True)
    gain = tf.matmul(tf.matmul(P_prior, C, transpose_b = True), tf.math.reciprocal(innovation_var))
    l_post = l_prior + tf.matmul(gain, innovation)
    correction = identity - tf.matmul(gain, C)
    noise_term = tf.matmul(tf.matmul(gain, tf.matmul(sigma, sigma, transpose_b = True)),
                           gain, transpose_b = True)
    P_post = tf.matmul(tf.matmul(correction, P_prior), correction, transpose_b = True) + noise_term

    # Posterior prediction and its residual against the measurement.
    pred = tf.matmul(C, l_post)
    residual = z - pred
    squared_error = tf.squeeze(tf.matmul(residual, residual, transpose_a = True))

    return l_post, P_post, z, pred, squared_error

In [3]:
class split_KF_Model(object):
    def __init__(self, model_name, thetaacc_error = 0, env_params_variation = [0,0,0,0], initial_state_variation = [0,0,0,0], control = False):
        self.m = 4
        self.dim_z = self.m
        self.n = 4
        self.r = 1
        self.sigma_upper_bound = 1
        self.sigma_lower_bound = 0
        self.g_upper_bound = 1
        self.g_lower_bound = 0.1
        self.mu_0_upper_bound = 1
        self.mu_0_lower_bound = 0
        self.Sigma_0_upper_bound = 1
        self.Sigma_0_lower_bound = 0
        self.weight_beta = .1
        self.bias_beta = .1
        self.thetaacc_error = thetaacc_error
        self.global_epoch = 0
        self.model_name = model_name
        
        self.mu_0_NN = tf.keras.Sequential(name = 'mu_0_NN', layers = [layers.Dense(self.m*8, activation = tf.sigmoid, kernel_regularizer = reg.l2(self.weight_beta),
                                                            bias_regularizer = reg.l2(self.bias_beta),name = 'mu_0_dense1'),
                                               layers.Dense(self.m, activation = tf.nn.leaky_relu, kernel_regularizer = reg.l2(self.weight_beta),
                                                            bias_regularizer = reg.l2(self.bias_beta),name = 'mu_0_dense2')])
        self.Sigma_0_NN = tf.keras.Sequential(name='Sigma_0_NN',layers=[layers.Dense(self.m*8, activation = tf.sigmoid, kernel_regularizer = reg.l2(self.weight_beta),
                                                               bias_regularizer = reg.l2(self.bias_beta),name = 'Sigma_0dense1'),
                                                  layers.Dense(self.m, activation = tf.nn.leaky_relu, kernel_regularizer = reg.l2(self.weight_beta),
                                                               bias_regularizer = reg.l2(self.bias_beta),name = 'Sigma_0dense2')])
        
        self.A_NN = tf.keras.Sequential(name='A_dense_NN',layers=[layers.Dense(self.m*self.n*8, activation = tf.sigmoid, kernel_regularizer = reg.l1(self.weight_beta),
                                                         bias_regularizer = reg.l2(self.bias_beta),name = 'A_dense1'),
                                            layers.Dense(self.m*self.n, activation = tf.nn.leaky_relu, kernel_regularizer = reg.l1(self.weight_beta),
                                                         bias_regularizer = reg.l2(self.bias_beta),name = 'A_dense2')])
        self.g_NN = tf.keras.Sequential(name='g_dense_NN',layers=[layers.Dense(self.m*8, activation = tf.sigmoid, kernel_regularizer = reg.l1(self.weight_beta),
                                                         bias_regularizer = reg.l2(self.bias_beta),name = 'g_dense1'),
                                            layers.Dense(self.m, activation = tf.nn.leaky_relu, kernel_regularizer = reg.l1(self.weight_beta),
                                                         bias_regularizer = reg.l2(self.bias_beta),name = 'g_dense2')])
        self.sigma1_NN = tf.keras.Sequential(name='sigma1_dense_NN',layers=[layers.Dense(1*8, activation = tf.sigmoid, kernel_regularizer = reg.l1(self.weight_beta),
                                                              bias_regularizer = reg.l2(self.bias_beta),name = 'sigma1_dense1'),
                                                 layers.Dense(1, activation = tf.nn.leaky_relu, kernel_regularizer = reg.l1(self.weight_beta),
                                                              bias_regularizer = reg.l2(self.bias_beta),name = 'sigma1_dense2')])
        self.sigma2_NN = tf.keras.Sequential(name='sigma2_dense_NN',layers=[layers.Dense(1*8, activation = tf.sigmoid, kernel_regularizer = reg.l1(self.weight_beta),
                                                              bias_regularizer = reg.l2(self.bias_beta),name = 'sigma2_dense1'),
                                                 layers.Dense(1, activation = tf.nn.leaky_relu, kernel_regularizer = reg.l1(self.weight_beta),
                                                              bias_regularizer = reg.l2(self.bias_beta),name = 'sigma2_dense2')])
        self.sigma3_NN = tf.keras.Sequential(name='sigma3_dense_NN',layers=[layers.Dense(1*8, activation = tf.sigmoid, kernel_regularizer = reg.l1(self.weight_beta),
                                                              bias_regularizer = reg.l2(self.bias_beta),name = 'sigma3_dense1'),
                                                 layers.Dense(1, activation = tf.nn.leaky_relu, kernel_regularizer = reg.l1(self.weight_beta),
                                                              bias_regularizer = reg.l2(self.bias_beta),name = 'sigma3_dense2')])
        self.sigma4_NN = tf.keras.Sequential(name='sigma4_dense_NN',layers=[layers.Dense(1*8, activation = tf.sigmoid, kernel_regularizer = reg.l1(self.weight_beta),
                                                              bias_regularizer = reg.l2(self.bias_beta),name = 'sigma4_dense1'),
                                                 layers.Dense(1, activation = tf.nn.leaky_relu, kernel_regularizer = reg.l1(self.weight_beta),
                                                              bias_regularizer = reg.l2(self.bias_beta),name = 'sigma4_dense2')])
        self.NNs = [self.mu_0_NN,self.Sigma_0_NN,self.A_NN,self.g_NN,
                    self.sigma1_NN,self.sigma2_NN,self.sigma3_NN,self.sigma4_NN]
        
        self.lstm_sizes = [256,128]        
        lstms = [tf.contrib.rnn.LSTMCell(size, reuse=tf.get_variable_scope().reuse) for size in self.lstm_sizes]
        dropouts = [tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob = 0.5) for lstm in lstms]

        self.cell = tf.contrib.rnn.MultiRNNCell(dropouts)
        dummy_state = self.cell.get_intitial_state(batch_size=1,dtype = tf.float64)
        dummy_input = tf.zeros([1,8],dtype = tf.float64)
        dummy_output,dummy_state=self.cell(inputs=dummy_input,state=dummy_state)
        
        self.C_1 = tf.Variable(np.array([[1,0,0,0]]), dtype = tf.float64, trainable=False)
        self.C_2 = tf.Variable(np.array([[0,1,0,0]]), dtype = tf.float64, trainable=False)
        self.C_3 = tf.Variable(np.array([[0,0,1,0]]), dtype = tf.float64, trainable=False)
        self.C_4 = tf.Variable(np.array([[0,0,0,1]]), dtype = tf.float64, trainable=False)
        
        '''Temporary LQR variables'''
        self.Q = np.eye(4)*[1,1,100,1]
        self.R = 100
        self.u_clip_value = tf.Variable(10, dtype = tf.float64, trainable = False)

        self.initial_variance_estimate = tf.Variable(np.array([[1]]), dtype = tf.float64, trainable=False) 

        self.env = gym.make('Custom_CartPole-v0', thetaacc_error=self.thetaacc_error, env_params_var=env_params_variation, initial_state_var=initial_state_variation)
        gravity = self.env.gravity
        cart_mass = self.env.masscart
        pole_mass = self.env.masspole
        pole_length = self.env.length
        self.env_params = tf.expand_dims(np.array([gravity, cart_mass,pole_mass,pole_length],
                                             dtype=np.float64),0)
        self.control = control
#         self.variables = []
        
    def build_LSTM(self):
        lstms = [tf.contrib.rnn.LSTMCell(size, reuse=tf.get_variable_scope().reuse) for size in self.lstm_sizes]
        dropouts = [tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob = 0.5) for lstm in lstms]

        self.cell = tf.contrib.rnn.MultiRNNCell(dropouts)
        return self
    
#     def get_variables(self):
#         return self.variables
#     def reset_variables(self):
#         self.variables = []
#         return self
    def set_control(self, control, control_type):
        self.control = control
        self.control_type = control_type
        return self
    def set_env(self,env_params_variation,initial_state_variation):
        self.env = gym.make('Custom_CartPole-v0', thetaacc_error=self.thetaacc_error, env_params_var=env_params_variation, initial_state_var=initial_state_variation)
        return self
    def set_u_clip(self,u_clip_value):
        self.u_clip_value = tf.Variable(u_clip_value, dtype = tf.float64, trainable = False)
        return self
    def set_LQR_params(self,Q, R):
        self.Q = np.eye(4)*Q
        self.R = R
        return self
    
    def likelihood_fn(self, params, inputs):
        A, B, u, g, C, sigma, l_filtered, P_filtered = inputs
        mu_1, Sigma_1 = params
#         print('A',len(A))
#         print('B',len(B))
#         print('u',len(u))
#         print('C',len(C))
#         print('g',len(g))
#         print('sigma',len(sigma))
#         print('l_filtered',len(l_filtered))
#         print('p_filtered',len(P_filtered))
#         print('mu_1',mu_1.shape)
#         print('Sigma_1',Sigma_1.shape)
        mu = [mu_1]
        Sigma = [Sigma_1]
        assert(len(A)==len(B) and len(B)==len(u) and len(u)==len(sigma) and 
               len(sigma)==len(l_filtered) and len(l_filtered)==len(P_filtered)),"Not all sequences are same length"
        for i in range(len(A)):
            mu.append(tf.matmul(C, tf.add(tf.matmul(A[i],l_filtered[i]), tf.matmul(B[i],u[i]))))
            temp = tf.matmul(tf.matmul(A[i], P_filtered[i]), A[i], transpose_b=True) + \
                        tf.matmul(g[i], g[i], transpose_b=True)
            Sigma.append(tf.matmul(tf.matmul(C, temp), C, transpose_b=True) + \
                        tf.matmul(sigma[i],sigma[i],transpose_b=True))
        return mu,Sigma
    
    def step(self, output_single):
        '''Calculate SSM parameters from LSTM output'''
#         A = layers.Dense(output_single, self.m*self.n, kernel_regularizer = reg.l1(self.weight_beta),
#                             bias_regularizer = reg.l2(self.bias_beta),
#                             name = 'A_dense', reuse = True)
#         g = layers.Dense(output_single, self.m, kernel_regularizer = reg.l1(self.weight_beta),
#                                 bias_regularizer = reg.l2(self.bias_beta),
#                                 name = 'g_dense', reuse = True)
#         sigma1 = layers.Dense(output_single, 1, kernel_regularizer = reg.l1(self.weight_beta),
#                                  bias_regularizer = reg.l2(self.bias_beta),
#                                  name = 'sigma1_dense', reuse = True)
#         sigma2 = layers.Dense(output_single, 1, kernel_regularizer = reg.l1(self.weight_beta),
#                                  bias_regularizer = reg.l2(self.bias_beta),
#                                  name = 'sigma2_dense', reuse = True)
#         sigma3 = layers.Dense(output_single, 1, kernel_regularizer = reg.l1(self.weight_beta),
#                                  bias_regularizer = reg.l2(self.bias_beta),
#                                  name = 'sigma3_dense', reuse = True)
#         sigma4 = layers.Dense(output_single, 1, kernel_regularizer = reg.l1(self.weight_beta),
#                                  bias_regularizer = reg.l2(self.bias_beta),
#                                  name = 'sigma4_dense', reuse = True)
        A = self.A_NN(output_single)
        g = self.g_NN(output_single)
        sigma1 = self.sigma1_NN(output_single)
        sigma2 = self.sigma2_NN(output_single)
        sigma3 = self.sigma3_NN(output_single)
        sigma4 = self.sigma4_NN(output_single)
        A = tf.reshape(A, shape = (self.m,self.n))
        g = tf.reshape(g, shape = (self.m, 1))
        g = ((self.g_upper_bound-self.g_lower_bound)/(1+tf.exp(-g)))+self.g_lower_bound
        sigma1 = ((self.sigma_upper_bound-self.sigma_lower_bound)/(1+tf.exp(-sigma1)))+self.sigma_lower_bound
        sigma2 = ((self.sigma_upper_bound-self.sigma_lower_bound)/(1+tf.exp(-sigma2)))+self.sigma_lower_bound
        sigma3 = ((self.sigma_upper_bound-self.sigma_lower_bound)/(1+tf.exp(-sigma3)))+self.sigma_lower_bound
        sigma4 = ((self.sigma_upper_bound-self.sigma_lower_bound)/(1+tf.exp(-sigma4)))+self.sigma_lower_bound
        return A, g, sigma1, sigma2, sigma3, sigma4
    
    def control_step(self, output_single, u, A, observation):
        '''Calculate SSM parameter B from LSTM output, and
            calculate u'''
        B = tf.layers.Dense(output_single, self.m*self.r, kernel_regularizer = reg.l1(self.weight_beta),
                            bias_regularizer = reg.l2(self.bias_beta),
                            name = 'B_dense', reuse = True)
        B = tf.reshape(B, shape = (self.m,self.r))
        '''Use one of the below options for directly predicting u from LSTM'''
        if self.control_type == 'NN regularized':
            u = tf.layers.Dense(output_single, 1, kernel_regularizer = reg.l1(self.weight_beta),
                                bias_regularizer = reg.l2(self.bias_beta),
                                name = 'u_dense', reuse = True)
        elif self.control_type == 'NN':
            u = tf.layers.Dense(output_single, 1, name = 'u_dense', reuse = True)
            '''LQR'''
        elif self.control_type == 'LQR':
            K,S,E = ct.lqr(A.numpy(),B.numpy(),self.Q,self.R)
            u = -tf.matmul(K.astype(np.float64),
                           tf.expand_dims(tf.convert_to_tensor(observation,dtype=tf.float64),-1))
            u = tf.clip_by_value(u, -self.u_clip_value, self.u_clip_value)
            '''Random action sampling'''
        elif self.control_type == 'uniform random':
            u = u + tf.random.uniform(shape = [1,self.r], minval=-3.5, maxval=3.5, dtype = tf.float64)
        else:
            pass
        return B, u
    
    def look_ahead_prediction(self, prediction_horizon, observation, output_single, state_single,
                              l_a_post1,l_a_post2,l_a_post3,l_a_post4,
                              P_a_post1,P_a_post2,P_a_post3,P_a_post4):
        LA_output_single = output_single
        LA_state_single = state_single
        
        '''Set initial prediction states to current observation'''
        LA_pred1 = tf.convert_to_tensor(observation[0], dtype = tf.float64)
        LA_pred2 = tf.convert_to_tensor(observation[1], dtype = tf.float64)
        LA_pred3 = tf.convert_to_tensor(observation[2], dtype = tf.float64)
        LA_pred4 = tf.convert_to_tensor(observation[3], dtype = tf.float64)
        LA_preds = []
        LA_l_a_post1 = []
        LA_l_a_post2 = []
        LA_l_a_post3 = []
        LA_l_a_post4 = []
        LA_P_a_post1 = []
        LA_P_a_post2 = []
        LA_P_a_post3 = []
        LA_P_a_post4 = []
        LA_A = []
        LA_B = []
        LA_u = []
        LA_g = []
        LA_sigma1 = []
        LA_sigma2 = []
        LA_sigma3 = []
        LA_sigma4 = []
        
        for i in range(prediction_horizon):
            '''Get SSM parameters from LSTM'''
            A_pred,g_pred,sigma1_pred,sigma2_pred,sigma3_pred,sigma4_pred = self.step(LA_output_single)
            if self.control:
                B_pred, u_pred = self.control_step(LA_output_single, u, A, observation)
            else:
                B_pred = tf.zeros(shape = (self.m,self.r), dtype = tf.float64)
                u_pred = tf.zeros(shape = [1,self.r], dtype=tf.float64)
                
            LA_A.append(A_pred)
            LA_B.append(B_pred)
            LA_u.append(u_pred)
            LA_g.append(g_pred)
            LA_sigma1.append(sigma1_pred)
            LA_sigma2.append(sigma2_pred)
            LA_sigma3.append(sigma3_pred)
            LA_sigma4.append(sigma4_pred)
            '''Predict next states from Kalman Filter'''
            l_a_post1,P_a_post1,env_state1,LA_pred1,_ = split_forward_filter_fn(A_pred,B_pred,u_pred,g_pred,self.C_1,sigma1_pred,l_a_post1,P_a_post1,LA_pred1)
            l_a_post2,P_a_post2,env_state2,LA_pred2,_ = split_forward_filter_fn(A_pred,B_pred,u_pred,g_pred,self.C_2,sigma2_pred,l_a_post2,P_a_post2,LA_pred2)
            l_a_post3,P_a_post3,env_state3,LA_pred3,_ = split_forward_filter_fn(A_pred,B_pred,u_pred,g_pred,self.C_3,sigma3_pred,l_a_post3,P_a_post3,LA_pred3)
            l_a_post4,P_a_post4,env_state4,LA_pred4,_ = split_forward_filter_fn(A_pred,B_pred,u_pred,g_pred,self.C_4,sigma4_pred,l_a_post4,P_a_post4,LA_pred4)
            LA_l_a_post1.append(l_a_post1)
            LA_l_a_post2.append(l_a_post2)
            LA_l_a_post3.append(l_a_post3)
            LA_l_a_post4.append(l_a_post4)
            LA_P_a_post1.append(P_a_post1)
            LA_P_a_post2.append(P_a_post2)
            LA_P_a_post3.append(P_a_post3)
            LA_P_a_post4.append(P_a_post4)
            LA_preds.append(tf.squeeze(tf.concat((LA_pred1,LA_pred2,LA_pred3,LA_pred4),axis=-1)))
            LA_next_input = tf.concat((self.env_params,LA_pred1,LA_pred2,LA_pred3,LA_pred4),axis=1)
            LA_output_single,LA_state_single=self.cell(inputs=LA_next_input,state=LA_state_single)
#         print('LA_A',len(LA_A))
#         print('LA_g',len(LA_g))
#         print('LA_sigma1',len(LA_sigma1))
#         print('LA_lpost',len(LA_l_a_post1))
#         print('LA_Ppost',len(LA_P_a_post1))
        return (LA_preds,LA_A,LA_B,LA_u,LA_g,LA_sigma1,LA_sigma2,LA_sigma3,LA_sigma4,LA_l_a_post1,LA_l_a_post2,LA_l_a_post3,LA_l_a_post4,LA_P_a_post1,LA_P_a_post2,LA_P_a_post3,LA_P_a_post4)
    
    def __call__(self, prediction_horizon, view = False):
#         self.reset_variables()
        rewards = 0
        A_all = []
        B_all = []
        u_all = []
        g_all = []
        sigma1_all = []
        sigma2_all = []
        sigma3_all = []
        sigma4_all = []
        l_a_posteriori1 = []
        l_a_posteriori2 = []
        l_a_posteriori3 = []
        l_a_posteriori4 = []
        P_a_posteriori1 = []
        P_a_posteriori2 = []
        P_a_posteriori3 = []
        P_a_posteriori4 = []
        env_states1 = []
        env_states2 = []
        env_states3 = []
        env_states4 = []
        KF_preds1 = []
        KF_preds2 = []
        KF_preds3 = []
        KF_preds4 = []
        squared_error1 = []
        squared_error2 = []
        squared_error3 = []
        squared_error4 = []
        
        
        
        KF1_params = [l_a_posteriori1,P_a_posteriori1,env_states1, KF_preds1, squared_error1]
        KF2_params = [l_a_posteriori2,P_a_posteriori2,env_states2, KF_preds2, squared_error2]
        KF3_params = [l_a_posteriori3,P_a_posteriori3,env_states3, KF_preds3, squared_error3]
        KF4_params = [l_a_posteriori4,P_a_posteriori4,env_states4, KF_preds4, squared_error4]
        
        '''Prediction function variables'''
        look_ahead_preds = []
        LA_A_all = []
        LA_B_all = []
        LA_u_all = []
        LA_g_all = []
        LA_sigma1_all = []
        LA_sigma2_all = []
        LA_sigma3_all = []
        LA_sigma4_all = []
        LA_l_a_posteriori1 = []
        LA_l_a_posteriori2 = []
        LA_l_a_posteriori3 = []
        LA_l_a_posteriori4 = []
        LA_P_a_posteriori1 = []
        LA_P_a_posteriori2 = []
        LA_P_a_posteriori3 = []
        LA_P_a_posteriori4 = []
        look_ahead_vars = [look_ahead_preds,LA_A_all,LA_B_all,LA_u_all,LA_g_all,LA_sigma1_all,LA_sigma2_all,LA_sigma3_all,LA_sigma4_all,
                           LA_l_a_posteriori1,LA_l_a_posteriori2,LA_l_a_posteriori3,LA_l_a_posteriori4,
                           LA_P_a_posteriori1,LA_P_a_posteriori2,LA_P_a_posteriori3,LA_P_a_posteriori4]
        
        '''p-quantile loss'''
        Q50_numerator = np.zeros(4)
        Q90_numerator = np.zeros(4)
        
        '''Build LSTM'''
#         self.build_LSTM()
        
        '''Start gym environment'''
        observation=self.env.reset()

        '''Get initial lstm state and input, get first output/state'''
        initial_state = self.cell.get_initial_state(batch_size=1,dtype = tf.float64)
        initial_input = tf.concat((self.env_params,tf.expand_dims(tf.convert_to_tensor(observation,dtype=tf.float64),0)),
                                  axis=1)
        output_single, state_single = self.cell(inputs=initial_input, state=initial_state)
#         if not self.control or self.control_type =='uniform random':
#             self.variables.extend(self.cell.trainable_variables)

#         print('LSTM cell trainable',len(self.cell.trainable_variables))
#         print('Rewards', self.rewards)
#         print('VARIABLES',[x.name for x in self.cell.trainable_variables])
#         print('\n\n\nWEIGHTS',[x.name for x in self.cell.trainable_weights])

        '''Calculate mu_0,Sigma_0, distribution using initial LSTM output'''
        container = tf.contrib.eager.EagerVariableStore()
#         control_container = tf.contrib.eager.EagerVariableStore()
#         with container.as_default():
#             mu_0 = tf.layers.Dense(output_single, self.m, kernel_regularizer = reg.l2(self.weight_beta),
#                                        bias_regularizer = reg.l2(self.bias_beta),
#                                        name = 'mu_0dense', reuse = True)
#             Sigma_0 = tf.layers.Dense(output_single, self.m, kernel_regularizer = reg.l2(self.weight_beta),
#                                           bias_regularizer = reg.l2(self.bias_beta),
#                                           name = 'Sigma_0dense', reuse = True)
        mu_0 = self.mu_0_NN(output_single)
        Sigma_0 = self.Sigma_0_NN(output_single)
        mu_0 = tf.reshape(mu_0, shape = (self.m,1))
        mu_0 = ((self.mu_0_upper_bound-self.mu_0_lower_bound)/(1+tf.exp(-mu_0)))+self.mu_0_lower_bound

        Sigma_0 = tf.reshape(Sigma_0, shape = (self.m,1))
        Sigma_0 = tf.matmul(Sigma_0,Sigma_0,transpose_b=True)+tf.eye(4, dtype=tf.float64)*1e-8
        '''Calculate predicted initial distribution'''
        l_0_dist = tfd.MultivariateNormalFullCovariance(loc = tf.squeeze(mu_0),
                                                                covariance_matrix= Sigma_0,
                                                                validate_args=True)
        l_0 = tf.expand_dims(l_0_dist.sample(),1)
        l_a_posteriori1.append(l_0)
        l_a_posteriori2.append(l_0)
        l_a_posteriori3.append(l_0)
        l_a_posteriori4.append(l_0)
        P_a_posteriori1.append(tf.eye(4, dtype = tf.float64)*self.initial_variance_estimate)
        P_a_posteriori2.append(tf.eye(4, dtype = tf.float64)*self.initial_variance_estimate)
        P_a_posteriori3.append(tf.eye(4, dtype = tf.float64)*self.initial_variance_estimate)
        P_a_posteriori4.append(tf.eye(4, dtype = tf.float64)*self.initial_variance_estimate)
        LA_l_a_posteriori1.append(l_0)
        LA_l_a_posteriori2.append(l_0)
        LA_l_a_posteriori3.append(l_0)
        LA_l_a_posteriori4.append(l_0)
        LA_P_a_posteriori1.append(tf.eye(4, dtype = tf.float64)*self.initial_variance_estimate)
        LA_P_a_posteriori2.append(tf.eye(4, dtype = tf.float64)*self.initial_variance_estimate)
        LA_P_a_posteriori3.append(tf.eye(4, dtype = tf.float64)*self.initial_variance_estimate)
        LA_P_a_posteriori4.append(tf.eye(4, dtype = tf.float64)*self.initial_variance_estimate)


        '''set initial u for random uniform control'''
        u = tf.Variable([0.0], dtype = tf.float64, trainable = False)
        
        first_pass = True
        done = False
        while not done:
            if view and self.control:
                self.env.render()
                
#             with container.as_default():
            A, g, sigma1, sigma2, sigma3, sigma4 = self.step(output_single)
            if self.control:
                B, u = self.control_step(output_single, u, A, observation)
            else:
                B = tf.zeros(shape = (self.m,self.r), dtype = tf.float64)
                u = tf.zeros(shape = [1,self.r], dtype=tf.float64)
            '''If this is first pass in loop, add variables to graph'''
#             if first_pass:
#                 self.variables.extend(container.trainable_variables())
#                 first_pass = False
            observation, reward, done, info = self.env.step(tf.squeeze(u))
            '''Calculate:
                A,B,u,g,C,sigma,l_a_posteriori,P_a_posteriori,env_states'''
            KF1_update = split_forward_filter_fn(A,B,u,g,self.C_1,sigma1,l_a_posteriori1[-1],P_a_posteriori1[-1],
                                                tf.convert_to_tensor(observation[0],dtype=tf.float64))
            KF2_update = split_forward_filter_fn(A,B,u,g,self.C_2,sigma2,l_a_posteriori2[-1],P_a_posteriori2[-1],
                                                tf.convert_to_tensor(observation[1],dtype=tf.float64))
            KF3_update = split_forward_filter_fn(A,B,u,g,self.C_3,sigma3,l_a_posteriori3[-1],P_a_posteriori3[-1],
                                                tf.convert_to_tensor(observation[2],dtype=tf.float64))
            KF4_update = split_forward_filter_fn(A,B,u,g,self.C_4,sigma4,l_a_posteriori4[-1],P_a_posteriori4[-1],
                                                tf.convert_to_tensor(observation[3],dtype=tf.float64))
            '''Update lists:
                A_all,B_all,u_all,g_all,C_all,sigma_all,l_a_posteriori,P_a_posteriori,env_states'''
            A_all.append(A)
            B_all.append(B)
            u_all.append(u)
            g_all.append(g)
            sigma1_all.append(sigma1)
            sigma2_all.append(sigma2)
            sigma3_all.append(sigma3)
            sigma4_all.append(sigma4)
            for KF_single,KF_param in zip(KF1_update,KF1_params):
                KF_param.append(KF_single)
            for KF_single,KF_param in zip(KF2_update,KF2_params):
                KF_param.append(KF_single)
            for KF_single,KF_param in zip(KF3_update,KF3_params):
                KF_param.append(KF_single)
            for KF_single,KF_param in zip(KF4_update,KF4_params):
                KF_param.append(KF_single)
                
            next_input = tf.concat((self.env_params,env_states1[-1],env_states2[-1],
                                    env_states3[-1],env_states4[-1]),axis=1)
            output_single,state_single=self.cell(inputs=next_input,state=state_single)
            if rewards%prediction_horizon==0:
#                 LA_preds,LA_A,LA_B,LA_u,LA_g,LA_sigma1,LA_sigma2,LA_sigma3,LA_sigma4,LA_l_a_posteriori, LA_P_a_posteriori
                LA_update = self.look_ahead_prediction(prediction_horizon, observation, output_single, state_single,
                                                       l_a_posteriori1[-1],l_a_posteriori2[-1],l_a_posteriori3[-1],l_a_posteriori4[-1],
                                                       P_a_posteriori1[-1],P_a_posteriori2[-1],P_a_posteriori3[-1],P_a_posteriori4[-1])
                for var,param in zip(look_ahead_vars,LA_update):
                    var.extend(param)
#                 look_ahead_preds.extend(LA_preds)
#                 LA_A_all.extend(LA_A)
#                 LA_B_all.extend(LA_B)
#                 LA_u_all.extend(LA_u)
#                 LA_g_all.extend(LA_g)
#                 LA_sigma1_all.extend(LA_sigma1)
#                 LA_sigma2_all.extend(LA_sigma2)
#                 LA_sigma3_all.extend(LA_sigma3)
#                 LA_sigma4_all.extend(LA_sigma4)
#                 LA_l_a_posteriori1.extend(LA_l_a_posteriori[0])
#                 LA_l_a_posteriori2.extend(LA_l_a_posteriori[1])
#                 LA_l_a_posteriori3.extend(LA_l_a_posteriori[2])
#                 LA_l_a_posteriori4.extend(LA_l_a_posteriori[3])
#                 LA_P_a_posteriori1.extend(LA_P_a_posteriori[0])
#                 LA_P_a_posteriori2.extend(LA_P_a_posteriori[1])
#                 LA_P_a_posteriori3.extend(LA_P_a_posteriori[2])
#                 LA_P_a_posteriori4.extend(LA_P_a_posteriori[3])
            rewards+=1

        LA_A_all = LA_A_all[:rewards]
        LA_B_all = LA_B_all[:rewards]
        LA_u_all = LA_u_all[:rewards]
        LA_g_all = LA_g_all[:rewards]
        LA_sigma1_all = LA_sigma1_all[:rewards]
        LA_sigma2_all = LA_sigma2_all[:rewards]
        LA_sigma3_all = LA_sigma3_all[:rewards]
        LA_sigma4_all = LA_sigma4_all[:rewards]
        LA_l_a_posteriori1 = LA_l_a_posteriori1[:rewards+1]
        LA_l_a_posteriori2 = LA_l_a_posteriori2[:rewards+1]
        LA_l_a_posteriori3 = LA_l_a_posteriori3[:rewards+1]
        LA_l_a_posteriori4 = LA_l_a_posteriori4[:rewards+1]
        LA_P_a_posteriori1 = LA_P_a_posteriori1[:rewards+1]
        LA_P_a_posteriori2 = LA_P_a_posteriori2[:rewards+1]
        LA_P_a_posteriori3 = LA_P_a_posteriori3[:rewards+1]
        LA_P_a_posteriori4 = LA_P_a_posteriori4[:rewards+1]
        if view and self.control:
            self.env.close()

#         param_names = ['A_all','B_all','u_all','g_all','C_all','sigma_all',
#                        'l_a_posteriori','P_a_posteriori','env_states','preds']
#             for name,KF_param in zip(param_names,all_KF_params):
#                 print(name,len(KF_param), KF_param[0].shape)

#         print('mu_0',mu_0)
#         print('Sigma_0',Sigma_0)
#         print('A_all',A_all[0])
#         print('B_all',B_all[0])
#         print('u_all',u_all[0])
#         print('C_1',C_1)
#         print('sigma1_all',sigma1_all[0])
        '''Start Maximum a posteriori section'''
        mu_11 = tf.add(tf.matmul(tf.slice(A_all[0],(0,0),(1,4)), mu_0),tf.matmul(tf.slice(B_all[0],(0,0),(1,1)),u_all[0]))
        Sigma_11 = tf.add(tf.matmul(tf.matmul(self.C_1,Sigma_0),self.C_1, transpose_b=True),
                     tf.matmul(sigma1_all[0],sigma1_all[0],transpose_b=True))
        mu_12 = tf.add(tf.matmul(tf.slice(A_all[0],(1,0),(1,4)), mu_0),tf.matmul(tf.slice(B_all[0],(1,0),(1,1)),u_all[0]))
        Sigma_12 = tf.add(tf.matmul(tf.matmul(self.C_2,Sigma_0),self.C_2, transpose_b=True),
                     tf.matmul(sigma2_all[0],sigma2_all[0],transpose_b=True))
        mu_13 = tf.add(tf.matmul(tf.slice(A_all[0],(2,0),(1,4)), mu_0),tf.matmul(tf.slice(B_all[0],(2,0),(1,1)),u_all[0]))
        Sigma_13 = tf.add(tf.matmul(tf.matmul(self.C_3,Sigma_0),self.C_3, transpose_b=True),
                     tf.matmul(sigma3_all[0],sigma3_all[0],transpose_b=True))
        mu_14 = tf.add(tf.matmul(tf.slice(A_all[0],(3,0),(1,4)), mu_0),tf.matmul(tf.slice(B_all[0],(3,0),(1,1)),u_all[0]))
        Sigma_14 = tf.add(tf.matmul(tf.matmul(self.C_4,Sigma_0),self.C_4, transpose_b=True),
                     tf.matmul(sigma4_all[0],sigma4_all[0],transpose_b=True))
        
        LA_mu_11 = tf.add(tf.matmul(tf.slice(LA_A_all[0],(0,0),(1,4)), mu_0),tf.matmul(tf.slice(LA_B_all[0],(0,0),(1,1)),LA_u_all[0]))
        LA_Sigma_11 = tf.add(tf.matmul(tf.matmul(self.C_1,Sigma_0),self.C_1, transpose_b=True),
                     tf.matmul(LA_sigma1_all[0],LA_sigma1_all[0],transpose_b=True))
        LA_mu_12 = tf.add(tf.matmul(tf.slice(LA_A_all[0],(1,0),(1,4)), mu_0),tf.matmul(tf.slice(LA_B_all[0],(1,0),(1,1)),LA_u_all[0]))
        LA_Sigma_12 = tf.add(tf.matmul(tf.matmul(self.C_2,Sigma_0),self.C_2, transpose_b=True),
                     tf.matmul(LA_sigma2_all[0],LA_sigma2_all[0],transpose_b=True))
        LA_mu_13 = tf.add(tf.matmul(tf.slice(LA_A_all[0],(2,0),(1,4)), mu_0),tf.matmul(tf.slice(LA_B_all[0],(2,0),(1,1)),LA_u_all[0]))
        LA_Sigma_13 = tf.add(tf.matmul(tf.matmul(self.C_3,Sigma_0),self.C_3, transpose_b=True),
                     tf.matmul(LA_sigma3_all[0],LA_sigma3_all[0],transpose_b=True))
        LA_mu_14 = tf.add(tf.matmul(tf.slice(LA_A_all[0],(3,0),(1,4)), mu_0),tf.matmul(tf.slice(LA_B_all[0],(3,0),(1,1)),LA_u_all[0]))
        LA_Sigma_14 = tf.add(tf.matmul(tf.matmul(self.C_4,Sigma_0),self.C_4, transpose_b=True),
                     tf.matmul(LA_sigma4_all[0],LA_sigma4_all[0],transpose_b=True))

        if rewards > 1:
            mu1,Sigma1 = self.likelihood_fn((mu_11,Sigma_11),(A_all[1:],B_all[1:],u_all[1:],g_all[1:],
                                                     self.C_1,sigma1_all[1:],
                                                     l_a_posteriori1[1:-1],
                                                     P_a_posteriori1[1:-1]))
            mu2,Sigma2 = self.likelihood_fn((mu_12,Sigma_12),(A_all[1:],B_all[1:],u_all[1:],g_all[1:],
                                                     self.C_2,sigma2_all[1:],
                                                     l_a_posteriori2[1:-1],
                                                     P_a_posteriori2[1:-1]))
            mu3,Sigma3 = self.likelihood_fn((mu_13,Sigma_13),(A_all[1:],B_all[1:],u_all[1:],g_all[1:],
                                                     self.C_3,sigma3_all[1:],
                                                     l_a_posteriori3[1:-1],
                                                     P_a_posteriori3[1:-1]))
            mu4,Sigma4 = self.likelihood_fn((mu_14,Sigma_14),(A_all[1:],B_all[1:],u_all[1:],g_all[1:],
                                                     self.C_4,sigma4_all[1:],
                                                     l_a_posteriori4[1:-1],
                                                     P_a_posteriori4[1:-1]))

            LA_mu1,LA_Sigma1 = self.likelihood_fn((LA_mu_11,LA_Sigma_11),(LA_A_all[1:],LA_B_all[1:],LA_u_all[1:],LA_g_all[1:],
                                                     self.C_1,LA_sigma1_all[1:],
                                                     LA_l_a_posteriori1[1:-1],
                                                     LA_P_a_posteriori1[1:-1]))
            LA_mu2,LA_Sigma2 = self.likelihood_fn((LA_mu_12,LA_Sigma_12),(LA_A_all[1:],LA_B_all[1:],LA_u_all[1:],LA_g_all[1:],
                                                     self.C_2,LA_sigma2_all[1:],
                                                     LA_l_a_posteriori2[1:-1],
                                                     LA_P_a_posteriori2[1:-1]))
            LA_mu3,LA_Sigma3 = self.likelihood_fn((LA_mu_13,LA_Sigma_13),(LA_A_all[1:],LA_B_all[1:],LA_u_all[1:],LA_g_all[1:],
                                                     self.C_3,LA_sigma3_all[1:],
                                                     LA_l_a_posteriori3[1:-1],
                                                     LA_P_a_posteriori3[1:-1]))
            LA_mu4,LA_Sigma4 = self.likelihood_fn((LA_mu_14,LA_Sigma_14),(LA_A_all[1:],LA_B_all[1:],LA_u_all[1:],LA_g_all[1:],
                                                     self.C_4,LA_sigma4_all[1:],
                                                     LA_l_a_posteriori4[1:-1],
                                                     LA_P_a_posteriori4[1:-1]))            
        else:
            mu1,Sigma1 = LA_mu1,LA_Sigma1 = mu_11,Sigma_11
            mu2,Sigma2 = LA_mu2,LA_Sigma2 = mu_12,Sigma_12
            mu3,Sigma3 = LA_mu3,LA_Sigma3 = mu_13,Sigma_13
            mu4,Sigma4 = LA_mu4,LA_Sigma4 = mu_14,Sigma_14
        '''End Maximum A posteriori section'''
    
        '''p-quantile loss'''
        for j in range(rewards):
#             Q50_numerator[0] += QL(0.5,look_ahead_preds[j][0],env_states1[j])
#             Q90_numerator[0] += QL(0.9,look_ahead_preds[j][0],env_states1[j])
            Q50_numerator[0] += QL(0.5, KF_preds1[j], env_states1[j])
            Q90_numerator[0] += QL(0.9, KF_preds1[j], env_states1[j])
        for j in range(rewards):
#             Q50_numerator[1] += QL(0.5,look_ahead_preds[j][1],env_states2[j])
#             Q90_numerator[1] += QL(0.9,look_ahead_preds[j][1],env_states2[j])
            Q50_numerator[1] += QL(0.5, KF_preds2[j], env_states2[j])
            Q90_numerator[1] += QL(0.9, KF_preds2[j], env_states2[j])
        for j in range(rewards):
#             Q50_numerator[2] += QL(0.5,look_ahead_preds[j][2],env_states3[j])
#             Q90_numerator[2] += QL(0.9,look_ahead_preds[j][2],env_states3[j])
            Q50_numerator[2] += QL(0.5, KF_preds3[j], env_states3[j])
            Q90_numerator[2] += QL(0.9, KF_preds3[j], env_states3[j])
        for j in range(rewards):
#             Q50_numerator[3] += QL(0.5,look_ahead_preds[j][3],env_states4[j])
#             Q90_numerator[3] += QL(0.9,look_ahead_preds[j][3],env_states4[j])
            Q50_numerator[3] += QL(0.5, KF_preds4[j], env_states4[j])
            Q90_numerator[3] += QL(0.9, KF_preds4[j], env_states4[j])

        Q_denomenator1 = np.sum(np.abs(np.squeeze(np.array(env_states1))), axis = 0)
        Q_denomenator2 = np.sum(np.abs(np.squeeze(np.array(env_states2))), axis = 0)
        Q_denomenator3 = np.sum(np.abs(np.squeeze(np.array(env_states3))), axis = 0)
        Q_denomenator4 = np.sum(np.abs(np.squeeze(np.array(env_states4))), axis = 0)

        pq50_loss1 = 2*np.divide(Q50_numerator[0],Q_denomenator1)
        pq90_loss1 = 2*np.divide(Q90_numerator[0],Q_denomenator1)
        pq50_loss2 = 2*np.divide(Q50_numerator[1],Q_denomenator2)
        pq90_loss2 = 2*np.divide(Q90_numerator[1],Q_denomenator2)
        pq50_loss3 = 2*np.divide(Q50_numerator[2],Q_denomenator3)
        pq90_loss3 = 2*np.divide(Q90_numerator[2],Q_denomenator3)
        pq50_loss4 = 2*np.divide(Q50_numerator[3],Q_denomenator4)
        pq90_loss4 = 2*np.divide(Q90_numerator[3],Q_denomenator4)


        '''Compute Likelihood of observations given KF evaluation'''
        z1_distribution = tfd.Normal(loc = mu1, scale = Sigma1)
        z1_likelihood = z1_distribution.log_prob(env_states1)
        z2_distribution = tfd.Normal(loc = mu2, scale = Sigma2)
        z2_likelihood = z2_distribution.log_prob(env_states2)
        z3_distribution = tfd.Normal(loc = mu3, scale = Sigma3)
        z3_likelihood = z3_distribution.log_prob(env_states3)
        z4_distribution = tfd.Normal(loc = mu4, scale = Sigma4)
        z4_likelihood = z4_distribution.log_prob(env_states4)
        LA_z1_distribution = tfd.Normal(loc = LA_mu1, scale = LA_Sigma1)
        LA_z1_likelihood = LA_z1_distribution.log_prob(env_states1)
        LA_z2_distribution = tfd.Normal(loc = LA_mu2, scale = LA_Sigma2)
        LA_z2_likelihood = LA_z2_distribution.log_prob(env_states2)
        LA_z3_distribution = tfd.Normal(loc = LA_mu3, scale = LA_Sigma3)
        LA_z3_likelihood = LA_z3_distribution.log_prob(env_states3)
        LA_z4_distribution = tfd.Normal(loc = LA_mu4, scale = LA_Sigma4)
        LA_z4_likelihood = LA_z4_distribution.log_prob(env_states4)
        self.global_epoch += 1

#         print('container len', len(container.variables()))
#         for var in container.variables():
#             print(var.name)
        return((z1_likelihood,z2_likelihood,z3_likelihood,z4_likelihood),(LA_z1_likelihood,LA_z2_likelihood,LA_z3_likelihood,LA_z4_likelihood),
               rewards,tf.squeeze(tf.convert_to_tensor((KF_preds1,KF_preds2,KF_preds3,KF_preds4))),
               (env_states1,env_states2,env_states3,env_states4),(squared_error1,squared_error2,squared_error3,squared_error4),
               (pq50_loss1,pq90_loss1,pq50_loss2,pq90_loss2,pq50_loss3,pq90_loss3,pq50_loss4,pq90_loss4), look_ahead_preds)

def QL(rho, z, z_pred):
    '''Pinball (quantile) loss for a single point.

    Returns rho * (z - z_pred) when z is strictly greater than z_pred and
    (1 - rho) * (z_pred - z) otherwise, so for 0 <= rho <= 1 the result is
    non-negative.

    NOTE(review): the names suggest `z` is the observation and `z_pred` the
    prediction, but the p-quantile section of the model call passes the KF
    prediction as `z` and the environment state as `z_pred`. For rho = 0.5
    the loss is symmetric so this does not matter; for rho = 0.9 it swaps
    which side is penalised more heavily -- confirm the intended order.
    '''
    exceeds = z > z_pred
    return rho * (z - z_pred) if exceeds else (1 - rho) * (z_pred - z)
    
def look_ahead_loss(model,view,prediction_horizon):
    '''Squared look-ahead prediction error plus negative look-ahead log-likelihood.

    Calls `model(prediction_horizon, view)` once. For every step i < rewards
    and every state component k, adds
        (trajectory[k][i] - look_ahead_preds[i][k]) ** 2
    to the loss, then subtracts every term of every LA (look-ahead) likelihood.

    Returns:
        (loss, rewards, preds, trajectory, squared_error, pq_loss, look_ahead_preds)
    '''
    (likelihoods, LA_likelihoods, rewards, preds, trajectory,
     squared_error, pq_loss, look_ahead_preds) = model(prediction_horizon, view)
    loss = tf.Variable([0.0], trainable=False, dtype = tf.float64)
    for i in range(rewards):
        for k in range(len(trajectory)):
            # Square the residual (observation - prediction). Squaring only the
            # observation and then subtracting the prediction would make the
            # loss unbounded below in the prediction.
            # NOTE(review): element shapes are not visible here; the
            # subtraction broadcasts like any TF binary op.
            residual = tf.subtract(trajectory[k][i], look_ahead_preds[i][k])
            loss = tf.add(loss, tf.square(residual))
    for likelihood in LA_likelihoods:
        for loss_term in likelihood:
            loss = tf.add(loss, -loss_term)
    return loss, rewards, preds, trajectory, squared_error, pq_loss, look_ahead_preds
    
def standard_loss(model,view,prediction_horizon):
    '''Negative sum of every look-ahead (LA) log-likelihood term from the model.

    Calls `model(prediction_horizon, view)` once and returns
    (loss, rewards, preds, trajectory, squared_error, pq_loss, look_ahead_preds).

    NOTE(review): inverse_multiplicative_loss, multiplicative_loss and
    exponential_loss read the plain `likelihoods`; this one reads
    `LA_likelihoods` -- confirm that is intended.
    '''
    (_plain_likelihoods, la_likelihoods, rewards, preds, trajectory,
     squared_error, pq_loss, look_ahead_preds) = model(prediction_horizon, view)
    total = tf.Variable([0.0], trainable=False, dtype=tf.float64)
    for per_state in la_likelihoods:
        for log_prob in per_state:
            total = tf.add(total, -log_prob)
    return total, rewards, preds, trajectory, squared_error, pq_loss, look_ahead_preds

def inverse_multiplicative_loss(model,view,prediction_horizon):
    '''Negative log-likelihood with each term divided by its 1-based time step.

    For every likelihood sequence, the term at 0-based position t contributes
    -(term * (1 / (t + 1))), so earlier observations weigh more than later ones.
    Returns (loss, rewards, preds, trajectory, squared_error, pq_loss,
    look_ahead_preds).
    '''
    (likelihoods, _la_likelihoods, rewards, preds, trajectory,
     squared_error, pq_loss, look_ahead_preds) = model(prediction_horizon, view)
    total = tf.Variable([0.0], trainable=False, dtype=tf.float64)
    for sequence in likelihoods:
        for step, log_prob in enumerate(sequence, start=1):
            total = tf.add(total, -(log_prob * (1 / step)))
    return total, rewards, preds, trajectory, squared_error, pq_loss, look_ahead_preds

def multiplicative_loss(model,view,prediction_horizon):
    '''Negative log-likelihood with each term scaled by its 0-based time step.

    The term at position t contributes -(term * t). Because t starts at 0, the
    first term of every sequence gets zero weight.
    NOTE(review): inverse_multiplicative_loss uses t + 1; confirm that the
    0-based weight here is intended.
    Returns (loss, rewards, preds, trajectory, squared_error, pq_loss,
    look_ahead_preds).
    '''
    (likelihoods, _la_likelihoods, rewards, preds, trajectory,
     squared_error, pq_loss, look_ahead_preds) = model(prediction_horizon, view)
    total = tf.Variable([0.0], trainable=False, dtype=tf.float64)
    for sequence in likelihoods:
        step = 0
        for log_prob in sequence:
            total = tf.add(total, -(log_prob * step))
            step += 1
    return total, rewards, preds, trajectory, squared_error, pq_loss, look_ahead_preds

def exponential_loss(model, alpha,view,prediction_horizon):
    '''Negative log-likelihood with the term at 0-based step t weighted by alpha**t.

    alpha > 1 gives exponentially increasing weights; 0 < alpha < 1 discounts
    later steps. `alpha` goes through tf.math.pow; compute_gradient passes it
    as a float64 tensor.
    Returns (loss, rewards, preds, trajectory, squared_error, pq_loss,
    look_ahead_preds).
    '''
    (likelihoods, _la_likelihoods, rewards, preds, trajectory,
     squared_error, pq_loss, look_ahead_preds) = model(prediction_horizon, view)
    total = tf.Variable([0.0], trainable=False, dtype=tf.float64)
    for sequence in likelihoods:
        for step, log_prob in enumerate(sequence):
            weight = tf.math.pow(alpha, step)
            total = tf.add(total, -(weight * log_prob))
    return total, rewards, preds, trajectory, squared_error, pq_loss, look_ahead_preds

def control_loss(model, alpha, view,prediction_horizon):
    '''Discounted log-likelihood: sum over sequences and 0-based steps t of alpha**t * term.

    Unlike exponential_loss, the terms are NOT negated, so minimising this
    value lowers the likelihood of the observed states.
    NOTE(review): the sign flip presumably serves a controller objective but
    is not documented anywhere -- confirm.
    Returns (loss, rewards, preds, trajectory, squared_error, pq_loss,
    look_ahead_preds).
    '''
    (likelihoods, LA_likelihoods, rewards, preds, trajectory,
     squared_error, pq_loss, look_ahead_preds) = model(prediction_horizon, view)
    loss = tf.Variable([0.0], trainable=False, dtype = tf.float64)
    for likelihood in likelihoods:
        for t, loss_term in enumerate(likelihood):
            loss = tf.add(loss, tf.math.pow(alpha, t) * loss_term)
    return loss, rewards, preds, trajectory, squared_error, pq_loss, look_ahead_preds

def compute_gradient(model, loss_type, alpha,prediction_horizon,view = False):
    '''Evaluate the selected loss under a GradientTape and differentiate it.

    Args:
        model: callable forwarded to the chosen loss function.
        loss_type: one of 'standard', 'inverse_multiplicative',
            'multiplicative', 'exponential', 'control', 'look ahead'.
        alpha: weighting base for 'exponential' and 'control' (converted to a
            float64 tensor); unused by the other loss types.
        prediction_horizon: forwarded to the loss function.
        view: forwarded to the loss function.

    Returns:
        (watched_variables, gradients, loss_value.numpy(), rewards, preds,
         trajectory, squared_error, pq_loss, look_ahead_preds)

    Raises:
        ValueError: for an unknown loss_type, or alpha=None with a loss type
            that needs it.
    '''
    with tf.GradientTape() as tape:
        if loss_type == 'standard':
            outputs = standard_loss(model, view, prediction_horizon)
        elif loss_type == 'inverse_multiplicative':
            outputs = inverse_multiplicative_loss(model, view, prediction_horizon)
        elif loss_type == 'multiplicative':
            outputs = multiplicative_loss(model, view, prediction_horizon)
        elif loss_type in ('exponential', 'control'):
            if alpha is None:
                raise ValueError("loss_type '{}' requires a numeric alpha".format(loss_type))
            alpha_tensor = tf.convert_to_tensor(alpha, dtype=tf.float64)
            weighted_loss_fn = exponential_loss if loss_type == 'exponential' else control_loss
            outputs = weighted_loss_fn(model, alpha_tensor, view, prediction_horizon)
        elif loss_type == 'look ahead':
            outputs = look_ahead_loss(model, view, prediction_horizon)
        else:
            # Fail fast: without this the function would reach tape.gradient
            # with no loss computed.
            raise ValueError(
                "Unknown loss_type {!r}; expected one of 'standard', "
                "'inverse_multiplicative', 'multiplicative', 'exponential', "
                "'control', 'look ahead'".format(loss_type))
    (loss_value, rewards, preds, trajectory, squared_error,
     pq_loss, look_ahead_preds) = outputs
    watched = tape.watched_variables()
    return (watched, tape.gradient(loss_value, watched), loss_value.numpy(),
            rewards, preds, trajectory, squared_error, pq_loss, look_ahead_preds)

In [4]:
def _try_load_training_state(model):
    '''Best-effort restore of `model` (and its metrics) from model.model_name.

    Returns the 7 saved metric lists (losses, rewards, pq_losses, grad_norms,
    predicted_trajectories, actual_trajectories, look_ahead_predictions) on
    success, or None when nothing could be loaded.
    '''
    try:
        (losses, rewards, pq_losses, grad_norms, predicted_trajectories,
         actual_trajectories, look_ahead_predictions) = load_model(model, model.model_name)
    except (tf.errors.NotFoundError, FileNotFoundError):
        print('Model not found, continuing to train new model')
        return None
    except Exception as err:
        # Still best effort (train from scratch), but report what went wrong
        # instead of silently swallowing it; Ctrl-C is no longer caught.
        print('other error: {!r}'.format(err))
        return None
    print('Model loaded from /{}/'.format(model.model_name))
    return (losses, rewards, pq_losses, grad_norms, predicted_trajectories,
            actual_trajectories, look_ahead_predictions)


def train(model, num_epochs, optimizer, loss_type='standard', alpha = None, view_rate = False, clip_gradients = False, prediction_horizon=5):
    '''Run `num_epochs` gradient steps on `model`, resuming from a checkpoint if one exists.

    Args:
        model: model exposing model_name and global_epoch plus whatever
            compute_gradient / save_model / load_model need.
        num_epochs: number of optimisation steps performed by this call.
        optimizer: TF1 optimizer; apply_gradients is called once per epoch.
        loss_type, alpha, prediction_horizon: forwarded to compute_gradient.
        view_rate: if truthy, every global epoch divisible by view_rate is run
            with view=True, and every view_rate epochs of this call a summary
            is printed and model + metrics are saved. If falsy (the default),
            nothing is viewed, printed or saved.
        clip_gradients: if True, every gradient is clipped to L2 norm 1 before
            being applied.
    '''
    start = time.time()
    losses = []
    rewards = []
    pq_losses = []
    grad_norms = []
    predicted_trajectories = []
    actual_trajectories = []
    look_ahead_predictions = []

    # Try to resume model weights together with the previously recorded metrics.
    if num_epochs > 0:
        loaded = _try_load_training_state(model)
        if loaded is not None:
            (losses, rewards, pq_losses, grad_norms, predicted_trajectories,
             actual_trajectories, look_ahead_predictions) = loaded

    for i in range(num_epochs):
        view = bool(view_rate) and model.global_epoch % view_rate == 0
        (watched_vars, grads, loss_, reward_, pred, trajectory, squared_error,
         pq_loss, look_ahead_preds) = compute_gradient(model, loss_type, alpha,
                                                        view=view,
                                                        prediction_horizon=prediction_horizon)

        # Keep track of loss, rewards, etc. Each grad_norms entry is
        # [variable_name, norm_epoch_1, norm_epoch_2, ...]; it is created when
        # no history exists yet (fresh run or nothing could be loaded).
        if not grad_norms:
            grad_norms.extend([var.name] for var in watched_vars)
        losses.extend(loss_)
        rewards.append(reward_)
        predicted_trajectories.append(pred)
        actual_trajectories.append(trajectory)
        pq_losses.append(pq_loss)
        look_ahead_predictions.append(look_ahead_preds)
        for idx, grad in enumerate(grads):
            grad_norms[idx].append(np.linalg.norm(grad))

        # Clip if requested; in both cases pair each gradient with the
        # variable it was computed for (watched_vars, same order).
        if clip_gradients:
            grads = [None if grad is None else tf.clip_by_norm(grad, 1.) for grad in grads]
        optimizer.apply_gradients(zip(grads, watched_vars), tf.Variable(model.global_epoch))

        if view_rate and (i + 1) % view_rate == 0:
            recent_loss = np.mean(losses[-view_rate:])
            recent_reward = np.mean(rewards[-view_rate:])
            print('Epoch {}'.format(model.global_epoch))
            print('Minutes elapsed: {}'.format((time.time()-start)/60))
            print('Last {} averages: Loss: {}, reward: {}, loss/reward: {}'.format(
                view_rate, recent_loss, recent_reward, recent_loss / recent_reward))
            save_model(model, model.model_name, (losses, rewards, pq_losses,
                                                 grad_norms, predicted_trajectories,
                                                 actual_trajectories, look_ahead_predictions))

In [5]:
def save_model(model,savepath,metrics):
    '''Persist metrics, network weights, LSTM cell weights and the epoch counter.

    Writes into `savepath` (created, including missing parents, if absent):
      - metrics.pkl       : `metrics`, pickled as-is
      - <NN.name>.tf      : via NN.save_weights for each NN in model.NNs
      - lstmweights.pkl   : model.cell.get_weights()
      - globalepoch.pkl   : model.global_epoch
    '''
    if not os.path.exists(savepath):
        os.makedirs(savepath)
        print("Directory " , savepath ,  " Created ")
    # `with` guarantees each pickle is flushed and closed before returning.
    with open(savepath + '/metrics.pkl', 'wb') as fh:
        pkl.dump(metrics, fh)
    for NN in model.NNs:
        NN.save_weights('{}/{}.tf'.format(savepath, NN.name))
    # Fetch the weights before opening the file so a failure cannot leave a
    # truncated pickle behind.
    lstm_weights = model.cell.get_weights()
    with open(savepath + '/lstmweights.pkl', 'wb') as fh:
        pkl.dump(lstm_weights, fh)
    with open(savepath + '/globalepoch.pkl', 'wb') as fh:
        pkl.dump(model.global_epoch, fh)

    
def load_model(model,loadpath):
    '''Restore the files written by save_model under `loadpath`; return the saved metrics.

    Loads metrics.pkl, calls NN.load_weights('<loadpath>/<NN.name>.tf') for
    each NN in model.NNs, restores model.cell from lstmweights.pkl and
    model.global_epoch from globalepoch.pkl.

    Raises:
        FileNotFoundError: if one of the pickle files is missing. Errors from
            NN.load_weights / cell.set_weights propagate unchanged.
    '''
    # NOTE: unpickling can execute arbitrary code -- only load directories
    # written by save_model, never untrusted files.
    with open(loadpath + '/metrics.pkl', 'rb') as fh:
        metrics = pkl.load(fh)
    for NN in model.NNs:
        NN.load_weights('{}/{}.tf'.format(loadpath, NN.name))
    with open(loadpath + '/lstmweights.pkl', 'rb') as fh:
        lstm_weights = pkl.load(fh)
    with open(loadpath + '/globalepoch.pkl', 'rb') as fh:
        global_epoch = pkl.load(fh)
    model.cell.set_weights(lstm_weights)
    # Advance the epoch counter only once every weight has been restored.
    model.global_epoch = global_epoch
    return metrics

## Create empty model

In [6]:
# Fresh model; `model_name` is also the directory train() loads from and
# saves checkpoints/metrics to (presumably stored as split_model.model_name).
split_model = split_KF_Model(
    model_name='testing',
    env_params_variation=[0, 0, 0, 0],
    initial_state_variation=[1, 0.1, 0.1, 0.1],
)
optimizer = tf.train.AdamOptimizer()

Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
Instructions for updating:
This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.


## Train model

In [None]:
# 50 rounds of 100 epochs. At the start of each round train() tries to reload
# the checkpoint under ./<model_name>/, and it saves every view_rate epochs.
# (Earlier experiments also used split_model.set_control / set_u_clip /
# set_LQR_params and loss_type='exponential' with alpha=0.95.)
for i in range(50):
    train(split_model, 100, optimizer,
          loss_type='standard',
          view_rate=100,
          prediction_horizon=5)
    # Drop graph / Keras session state accumulated during the round.
    tf.reset_default_graph()
    tf.keras.backend.clear_session()

very top 0
Model not found, continuing to train new model
early top 0
Instructions for updating:
Colocations handled automatically by placer.
mid 1
near end 1
very top 1
early top 1
mid 2
near end 2
very top 2
early top 2
mid 3
near end 3
very top 3
early top 3
mid 4
near end 4
very top 4
early top 4
mid 5
near end 5
very top 5
early top 5
mid 6
near end 6
very top 6
early top 6
mid 7
near end 7
very top 7
early top 7
mid 8
near end 8
very top 8
early top 8
mid 9
near end 9
very top 9
early top 9
mid 10
near end 10
very top 10
early top 10
mid 11
near end 11
very top 11
early top 11
mid 12
near end 12
very top 12
early top 12
mid 13
near end 13
very top 13
early top 13
mid 14
near end 14
very top 14
early top 14
mid 15
near end 15
very top 15
early top 15
mid 16
near end 16
very top 16
early top 16
mid 17
near end 17
very top 17
early top 17
mid 18
near end 18
very top 18
early top 18
mid 19
near end 19
very top 19
early top 19
mid 20
near end 20
very top 20
early top 20
mid 21
near en

mid 162
near end 162
very top 162
early top 162
mid 163
near end 163
very top 163
early top 163
mid 164
near end 164
very top 164
early top 164
mid 165
near end 165
very top 165
early top 165
mid 166
near end 166
very top 166
early top 166
mid 167
near end 167
very top 167
early top 167
mid 168
near end 168
very top 168
early top 168
mid 169
near end 169
very top 169
early top 169
mid 170
near end 170
very top 170
early top 170
mid 171
near end 171
very top 171
early top 171
mid 172
near end 172
very top 172
early top 172
mid 173
near end 173
very top 173
early top 173
mid 174
near end 174
very top 174
early top 174
mid 175
near end 175
very top 175
early top 175
mid 176
near end 176
very top 176
early top 176
mid 177
near end 177
very top 177
early top 177
mid 178
near end 178
very top 178
early top 178
mid 179
near end 179
very top 179
early top 179
mid 180
near end 180
very top 180
early top 180
mid 181
near end 181
very top 181
early top 181
mid 182
near end 182
very top 182
early 

mid 326
near end 326
very top 326
early top 326
mid 327
near end 327
very top 327
early top 327
mid 328
near end 328
very top 328
early top 328
mid 329
near end 329
very top 329
early top 329
mid 330
near end 330
very top 330
early top 330
mid 331
near end 331
very top 331
early top 331
mid 332
near end 332
very top 332
early top 332
mid 333
near end 333
very top 333
early top 333
mid 334
near end 334
very top 334
early top 334
mid 335
near end 335
very top 335
early top 335
mid 336
near end 336
very top 336
early top 336
mid 337
near end 337
very top 337
early top 337
mid 338
near end 338
very top 338
early top 338
mid 339
near end 339
very top 339
early top 339
mid 340
near end 340
very top 340
early top 340
mid 341
near end 341
very top 341
early top 341
mid 342
near end 342
very top 342
early top 342
mid 343
near end 343
very top 343
early top 343
mid 344
near end 344
very top 344
early top 344
mid 345
near end 345
very top 345
early top 345
mid 346
near end 346
very top 346
early 

mid 493
near end 493
very top 493
early top 493
mid 494
near end 494
very top 494
early top 494
mid 495
near end 495
very top 495
early top 495
mid 496
near end 496
very top 496
early top 496
mid 497
near end 497
very top 497
early top 497
mid 498
near end 498
very top 498
early top 498
mid 499
near end 499
very top 499
early top 499
mid 500
near end 500
Epoch 500
Minutes elapsed: 4.399091037114461
Last 100 averages: Loss: -143.90032792767005, reward: 48.62, loss/reward: -2.9596941161594006
very end 500
very top 500
Model loaded from /testing/
early top 500
mid 501
near end 501
very top 501
early top 501
mid 502
near end 502
very top 502
early top 502
mid 503
near end 503
very top 503
early top 503
mid 504
near end 504
very top 504
early top 504
mid 505
near end 505
very top 505
early top 505
mid 506
near end 506
very top 506
early top 506
mid 507
near end 507
very top 507
early top 507
mid 508
near end 508
very top 508
early top 508
mid 509
near end 509
very top 509
early top 509
mid 

mid 657
near end 657
very top 657
early top 657
mid 658
near end 658
very top 658
early top 658
mid 659
near end 659
very top 659
early top 659
mid 660
near end 660
very top 660
early top 660
mid 661
near end 661
very top 661
early top 661
mid 662
near end 662
very top 662
early top 662
mid 663
near end 663
very top 663
early top 663
mid 664
near end 664
very top 664
early top 664
mid 665
near end 665
very top 665
early top 665
mid 666
near end 666
very top 666
early top 666
mid 667
near end 667
very top 667
early top 667
mid 668
near end 668
very top 668
early top 668
mid 669
near end 669
very top 669
early top 669
mid 670
near end 670
very top 670
early top 670
mid 671
near end 671
very top 671
early top 671
mid 672
near end 672
very top 672
early top 672
mid 673
near end 673
very top 673
early top 673
mid 674
near end 674
very top 674
early top 674
mid 675
near end 675
very top 675
early top 675
mid 676
near end 676
very top 676
early top 676
mid 677
near end 677
very top 677
early 

mid 821
near end 821
very top 821
early top 821
mid 822
near end 822
very top 822
early top 822
mid 823
near end 823
very top 823
early top 823
mid 824
near end 824
very top 824
early top 824
mid 825
near end 825
very top 825
early top 825
mid 826
near end 826
very top 826
early top 826
mid 827
near end 827
very top 827
early top 827
mid 828
near end 828
very top 828
early top 828
mid 829
near end 829
very top 829
early top 829
mid 830
near end 830
very top 830
early top 830
mid 831
near end 831
very top 831
early top 831
mid 832
near end 832
very top 832
early top 832
mid 833
near end 833
very top 833
early top 833
mid 834
near end 834
very top 834
early top 834
mid 835
near end 835
very top 835
early top 835
mid 836
near end 836
very top 836
early top 836
mid 837
near end 837
very top 837
early top 837
mid 838
near end 838
very top 838
early top 838
mid 839
near end 839
very top 839
early top 839
mid 840
near end 840
very top 840
early top 840
mid 841
near end 841
very top 841
early 

mid 988
near end 988
very top 988
early top 988
mid 989
near end 989
very top 989
early top 989
mid 990
near end 990
very top 990
early top 990
mid 991
near end 991
very top 991
early top 991
mid 992
near end 992
very top 992
early top 992
mid 993
near end 993
very top 993
early top 993
mid 994
near end 994
very top 994
early top 994
mid 995
near end 995
very top 995
early top 995
mid 996
near end 996
very top 996
early top 996
mid 997
near end 997
very top 997
early top 997
mid 998
near end 998
very top 998
early top 998
mid 999
near end 999
very top 999
early top 999
mid 1000
near end 1000
Epoch 1000
Minutes elapsed: 4.497777160008749
Last 100 averages: Loss: -265.10845191647354, reward: 49.16, loss/reward: -5.392767532881887
very end 1000
very top 1000
Model loaded from /testing/
early top 1000
mid 1001
near end 1001
very top 1001
early top 1001
mid 1002
near end 1002
very top 1002
early top 1002
mid 1003
near end 1003
very top 1003
early top 1003
mid 1004
near end 1004
very top 100

mid 1140
near end 1140
very top 1140
early top 1140
mid 1141
near end 1141
very top 1141
early top 1141
mid 1142
near end 1142
very top 1142
early top 1142
mid 1143
near end 1143
very top 1143
early top 1143
mid 1144
near end 1144
very top 1144
early top 1144
mid 1145
near end 1145
very top 1145
early top 1145
mid 1146
near end 1146
very top 1146
early top 1146
mid 1147
near end 1147
very top 1147
early top 1147
mid 1148
near end 1148
very top 1148
early top 1148
mid 1149
near end 1149
very top 1149
early top 1149
mid 1150
near end 1150
very top 1150
early top 1150
mid 1151
near end 1151
very top 1151
early top 1151
mid 1152
near end 1152
very top 1152
early top 1152
mid 1153
near end 1153
very top 1153
early top 1153
mid 1154
near end 1154
very top 1154
early top 1154
mid 1155
near end 1155
very top 1155
early top 1155
mid 1156
near end 1156
very top 1156
early top 1156
mid 1157
near end 1157
very top 1157
early top 1157
mid 1158
near end 1158
very top 1158
early top 1158
mid 1159
nea

mid 1295
near end 1295
very top 1295
early top 1295
mid 1296
near end 1296
very top 1296
early top 1296
mid 1297
near end 1297
very top 1297
early top 1297
mid 1298
near end 1298
very top 1298
early top 1298
mid 1299
near end 1299
very top 1299
early top 1299
mid 1300
near end 1300
Epoch 1300
Minutes elapsed: 4.842346159617106
Last 100 averages: Loss: -315.85754644644095, reward: 48.31, loss/reward: -6.538140063060255
very end 1300
very top 1300
Model loaded from /testing/
early top 1300
mid 1301
near end 1301
very top 1301
early top 1301
mid 1302
near end 1302
very top 1302
early top 1302
mid 1303
near end 1303
very top 1303
early top 1303
mid 1304
near end 1304
very top 1304
early top 1304
mid 1305
near end 1305
very top 1305
early top 1305
mid 1306
near end 1306
very top 1306
early top 1306
mid 1307
near end 1307
very top 1307
early top 1307
mid 1308
near end 1308
very top 1308
early top 1308
mid 1309
near end 1309
very top 1309
early top 1309
mid 1310
near end 1310
very top 1310
ea

mid 1446
near end 1446
very top 1446
early top 1446
mid 1447
near end 1447
very top 1447
early top 1447
mid 1448
near end 1448
very top 1448
early top 1448
mid 1449
near end 1449
very top 1449
early top 1449
mid 1450
near end 1450
very top 1450
early top 1450
mid 1451
near end 1451
very top 1451
early top 1451
mid 1452
near end 1452
very top 1452
early top 1452
mid 1453
near end 1453
very top 1453
early top 1453
mid 1454
near end 1454
very top 1454
early top 1454
mid 1455
near end 1455
very top 1455
early top 1455
mid 1456
near end 1456
very top 1456
early top 1456
mid 1457
near end 1457
very top 1457
early top 1457
mid 1458
near end 1458
very top 1458
early top 1458
mid 1459
near end 1459
very top 1459
early top 1459
mid 1460
near end 1460
very top 1460
early top 1460
mid 1461
near end 1461
very top 1461
early top 1461
mid 1462
near end 1462
very top 1462
early top 1462
mid 1463
near end 1463
very top 1463
early top 1463
mid 1464
near end 1464
very top 1464
early top 1464
mid 1465
nea

very top 1600
Model loaded from /testing/
early top 1600
mid 1601
near end 1601
very top 1601
early top 1601
mid 1602
near end 1602
very top 1602
early top 1602
mid 1603
near end 1603
very top 1603
early top 1603
mid 1604
near end 1604
very top 1604
early top 1604
mid 1605
near end 1605
very top 1605
early top 1605
mid 1606
near end 1606
very top 1606
early top 1606
mid 1607
near end 1607
very top 1607
early top 1607
mid 1608
near end 1608
very top 1608
early top 1608
mid 1609
near end 1609
very top 1609
early top 1609
mid 1610
near end 1610
very top 1610
early top 1610
mid 1611
near end 1611
very top 1611
early top 1611
mid 1612
near end 1612
very top 1612
early top 1612
mid 1613
near end 1613
very top 1613
early top 1613
mid 1614
near end 1614
very top 1614
early top 1614
mid 1615
near end 1615
very top 1615
early top 1615
mid 1616
near end 1616
very top 1616
early top 1616
mid 1617
near end 1617
very top 1617
early top 1617
mid 1618
near end 1618
very top 1618
early top 1618
mid 161

mid 1754
near end 1754
very top 1754
early top 1754
mid 1755
near end 1755
very top 1755
early top 1755
mid 1756
near end 1756
very top 1756
early top 1756
mid 1757
near end 1757
very top 1757
early top 1757
mid 1758
near end 1758
very top 1758
early top 1758
mid 1759
near end 1759
very top 1759
early top 1759
mid 1760
near end 1760
very top 1760
early top 1760
mid 1761
near end 1761
very top 1761
early top 1761
mid 1762
near end 1762
very top 1762
early top 1762
mid 1763
near end 1763
very top 1763
early top 1763
mid 1764
near end 1764
very top 1764
early top 1764
mid 1765
near end 1765
very top 1765
early top 1765
mid 1766
near end 1766
very top 1766
early top 1766
mid 1767
near end 1767
very top 1767
early top 1767
mid 1768
near end 1768
very top 1768
early top 1768
mid 1769
near end 1769
very top 1769
early top 1769
mid 1770
near end 1770
very top 1770
early top 1770
mid 1771
near end 1771
very top 1771
early top 1771
mid 1772
near end 1772
very top 1772
early top 1772
mid 1773
nea

mid 1905
near end 1905
very top 1905
early top 1905
mid 1906
near end 1906
very top 1906
early top 1906
mid 1907
near end 1907
very top 1907
early top 1907
mid 1908
near end 1908
very top 1908
early top 1908
mid 1909
near end 1909
very top 1909
early top 1909
mid 1910
near end 1910
very top 1910
early top 1910
mid 1911
near end 1911
very top 1911
early top 1911
mid 1912
near end 1912
very top 1912
early top 1912
mid 1913
near end 1913
very top 1913
early top 1913
mid 1914
near end 1914
very top 1914
early top 1914
mid 1915
near end 1915
very top 1915
early top 1915
mid 1916
near end 1916
very top 1916
early top 1916
mid 1917
near end 1917
very top 1917
early top 1917
mid 1918
near end 1918
very top 1918
early top 1918
mid 1919
near end 1919
very top 1919
early top 1919
mid 1920
near end 1920
very top 1920
early top 1920
mid 1921
near end 1921
very top 1921
early top 1921
mid 1922
near end 1922
very top 1922
early top 1922
mid 1923
near end 1923
very top 1923
early top 1923
mid 1924
nea

mid 2060
near end 2060
very top 2060
early top 2060
mid 2061
near end 2061
very top 2061
early top 2061
mid 2062
near end 2062
very top 2062
early top 2062
mid 2063
near end 2063
very top 2063
early top 2063
mid 2064
near end 2064
very top 2064
early top 2064
mid 2065
near end 2065
very top 2065
early top 2065
mid 2066
near end 2066
very top 2066
early top 2066
mid 2067
near end 2067
very top 2067
early top 2067
mid 2068
near end 2068
very top 2068
early top 2068
mid 2069
near end 2069
very top 2069
early top 2069
mid 2070
near end 2070
very top 2070
early top 2070
mid 2071
near end 2071
very top 2071
early top 2071
mid 2072
near end 2072
very top 2072
early top 2072
mid 2073
near end 2073
very top 2073
early top 2073
mid 2074
near end 2074
very top 2074
early top 2074
mid 2075
near end 2075
very top 2075
early top 2075
mid 2076
near end 2076
very top 2076
early top 2076
mid 2077
near end 2077
very top 2077
early top 2077
mid 2078
near end 2078
very top 2078
early top 2078
mid 2079
nea

mid 2211
near end 2211
very top 2211
early top 2211
mid 2212
near end 2212
very top 2212
early top 2212
mid 2213
near end 2213
very top 2213
early top 2213
mid 2214
near end 2214
very top 2214
early top 2214
mid 2215
near end 2215
very top 2215
early top 2215
mid 2216
near end 2216
very top 2216
early top 2216
mid 2217
near end 2217
very top 2217
early top 2217
mid 2218
near end 2218
very top 2218
early top 2218
mid 2219
near end 2219
very top 2219
early top 2219
mid 2220
near end 2220
very top 2220
early top 2220
mid 2221
near end 2221
very top 2221
early top 2221
mid 2222
near end 2222
very top 2222
early top 2222
mid 2223
near end 2223
very top 2223
early top 2223
mid 2224
near end 2224
very top 2224
early top 2224
mid 2225
near end 2225
very top 2225
early top 2225
mid 2226
near end 2226
very top 2226
early top 2226
mid 2227
near end 2227
very top 2227
early top 2227
mid 2228
near end 2228
very top 2228
early top 2228
mid 2229
near end 2229
very top 2229
early top 2229
mid 2230
nea

mid 2366
near end 2366
very top 2366
early top 2366
mid 2367
near end 2367
very top 2367
early top 2367
mid 2368
near end 2368
very top 2368
early top 2368
mid 2369
near end 2369
very top 2369
early top 2369
mid 2370
near end 2370
very top 2370
early top 2370
mid 2371
near end 2371
very top 2371
early top 2371
mid 2372
near end 2372
very top 2372
early top 2372
mid 2373
near end 2373
very top 2373
early top 2373
mid 2374
near end 2374
very top 2374
early top 2374
mid 2375
near end 2375
very top 2375
early top 2375
mid 2376
near end 2376
very top 2376
early top 2376
mid 2377
near end 2377
very top 2377
early top 2377
mid 2378
near end 2378
very top 2378
early top 2378
mid 2379
near end 2379
very top 2379
early top 2379
mid 2380
near end 2380
very top 2380
early top 2380
mid 2381
near end 2381
very top 2381
early top 2381
mid 2382
near end 2382
very top 2382
early top 2382
mid 2383
near end 2383
very top 2383
early top 2383
mid 2384
near end 2384
very top 2384
early top 2384
mid 2385
nea

mid 2517
near end 2517
very top 2517
early top 2517
mid 2518
near end 2518
very top 2518
early top 2518
mid 2519
near end 2519
very top 2519
early top 2519
mid 2520
near end 2520
very top 2520
early top 2520
mid 2521
near end 2521
very top 2521
early top 2521
mid 2522
near end 2522
very top 2522
early top 2522
mid 2523
near end 2523
very top 2523
early top 2523
mid 2524
near end 2524
very top 2524
early top 2524
mid 2525
near end 2525
very top 2525
early top 2525
mid 2526
near end 2526
very top 2526
early top 2526
mid 2527
near end 2527
very top 2527
early top 2527
mid 2528
near end 2528
very top 2528
early top 2528
mid 2529
near end 2529
very top 2529
early top 2529
mid 2530
near end 2530
very top 2530
early top 2530
mid 2531
near end 2531
very top 2531
early top 2531
mid 2532
near end 2532
very top 2532
early top 2532
mid 2533
near end 2533
very top 2533
early top 2533
mid 2534
near end 2534
very top 2534
early top 2534
mid 2535
near end 2535
very top 2535
early top 2535
mid 2536
nea

mid 2672
near end 2672
very top 2672
early top 2672
mid 2673
near end 2673
very top 2673
early top 2673
mid 2674
near end 2674
very top 2674
early top 2674
mid 2675
near end 2675
very top 2675
early top 2675
mid 2676
near end 2676
very top 2676
early top 2676
mid 2677
near end 2677
very top 2677
early top 2677
mid 2678
near end 2678
very top 2678
early top 2678
mid 2679
near end 2679
very top 2679
early top 2679
mid 2680
near end 2680
very top 2680
early top 2680
mid 2681
near end 2681
very top 2681
early top 2681
mid 2682
near end 2682
very top 2682
early top 2682
mid 2683
near end 2683
very top 2683
early top 2683
mid 2684
near end 2684
very top 2684
early top 2684
mid 2685
near end 2685
very top 2685
early top 2685
mid 2686
near end 2686
very top 2686
early top 2686
mid 2687
near end 2687
very top 2687
early top 2687
mid 2688
near end 2688
very top 2688
early top 2688
mid 2689
near end 2689
very top 2689
early top 2689
mid 2690
near end 2690
very top 2690
early top 2690
mid 2691
nea

mid 2823
near end 2823
very top 2823
early top 2823
mid 2824
near end 2824
very top 2824
early top 2824
mid 2825
near end 2825
very top 2825
early top 2825
mid 2826
near end 2826
very top 2826
early top 2826
mid 2827
near end 2827
very top 2827
early top 2827
mid 2828
near end 2828
very top 2828
early top 2828
mid 2829
near end 2829
very top 2829
early top 2829
mid 2830
near end 2830
very top 2830
early top 2830
mid 2831
near end 2831
very top 2831
early top 2831
mid 2832
near end 2832
very top 2832
early top 2832
mid 2833
near end 2833
very top 2833
early top 2833
mid 2834
near end 2834
very top 2834
early top 2834
mid 2835
near end 2835
very top 2835
early top 2835
mid 2836
near end 2836
very top 2836
early top 2836
mid 2837
near end 2837
very top 2837
early top 2837
mid 2838
near end 2838
very top 2838
early top 2838
mid 2839
near end 2839
very top 2839
early top 2839
mid 2840
near end 2840
very top 2840
early top 2840
mid 2841
near end 2841
very top 2841
early top 2841
mid 2842
nea

mid 2978
near end 2978
very top 2978
early top 2978
mid 2979
near end 2979
very top 2979
early top 2979
mid 2980
near end 2980
very top 2980
early top 2980
mid 2981
near end 2981
very top 2981
early top 2981
mid 2982
near end 2982
very top 2982
early top 2982
mid 2983
near end 2983
very top 2983
early top 2983
mid 2984
near end 2984
very top 2984
early top 2984
mid 2985
near end 2985
very top 2985
early top 2985
mid 2986
near end 2986
very top 2986
early top 2986
mid 2987
near end 2987
very top 2987
early top 2987
mid 2988
near end 2988
very top 2988
early top 2988
mid 2989
near end 2989
very top 2989
early top 2989
mid 2990
near end 2990
very top 2990
early top 2990
mid 2991
near end 2991
very top 2991
early top 2991
mid 2992
near end 2992
very top 2992
early top 2992
mid 2993
near end 2993
very top 2993
early top 2993
mid 2994
near end 2994
very top 2994
early top 2994
mid 2995
near end 2995
very top 2995
early top 2995
mid 2996
near end 2996
very top 2996
early top 2996
mid 2997
nea

mid 3129
near end 3129
very top 3129
early top 3129
mid 3130
near end 3130
very top 3130
early top 3130
mid 3131
near end 3131
very top 3131
early top 3131
mid 3132
near end 3132
very top 3132
early top 3132
mid 3133
near end 3133
very top 3133
early top 3133
mid 3134
near end 3134
very top 3134
early top 3134
mid 3135
near end 3135
very top 3135
early top 3135
mid 3136
near end 3136
very top 3136
early top 3136
mid 3137
near end 3137
very top 3137
early top 3137
mid 3138
near end 3138
very top 3138
early top 3138
mid 3139
near end 3139
very top 3139
early top 3139
mid 3140
near end 3140
very top 3140
early top 3140
mid 3141
near end 3141
very top 3141
early top 3141
mid 3142
near end 3142
very top 3142
early top 3142
mid 3143
near end 3143
very top 3143
early top 3143
mid 3144
near end 3144
very top 3144
early top 3144
mid 3145
near end 3145
very top 3145
early top 3145
mid 3146
near end 3146
very top 3146
early top 3146
mid 3147
near end 3147
very top 3147
early top 3147
mid 3148
nea

mid 3284
near end 3284
very top 3284
early top 3284
mid 3285
near end 3285
very top 3285
early top 3285
mid 3286
near end 3286
very top 3286
early top 3286
mid 3287
near end 3287
very top 3287
early top 3287
mid 3288
near end 3288
very top 3288
early top 3288
mid 3289
near end 3289
very top 3289
early top 3289
mid 3290
near end 3290
very top 3290
early top 3290
mid 3291
near end 3291
very top 3291
early top 3291
mid 3292
near end 3292
very top 3292
early top 3292
mid 3293
near end 3293
very top 3293
early top 3293
mid 3294
near end 3294
very top 3294
early top 3294
mid 3295
near end 3295
very top 3295
early top 3295
mid 3296
near end 3296
very top 3296
early top 3296
mid 3297
near end 3297
very top 3297
early top 3297
mid 3298
near end 3298
very top 3298
early top 3298
mid 3299
near end 3299
very top 3299
early top 3299
mid 3300
near end 3300
Epoch 3300
Minutes elapsed: 4.6644773244857785
Last 100 averages: Loss: -491.30867681751096, reward: 48.59, loss/reward: -10.111312550267769
very

mid 3435
near end 3435
very top 3435
early top 3435
mid 3436
near end 3436
very top 3436
early top 3436
mid 3437
near end 3437
very top 3437
early top 3437
mid 3438
near end 3438
very top 3438
early top 3438
mid 3439
near end 3439
very top 3439
early top 3439
mid 3440
near end 3440
very top 3440
early top 3440
mid 3441
near end 3441
very top 3441
early top 3441
mid 3442
near end 3442
very top 3442
early top 3442
mid 3443
near end 3443
very top 3443
early top 3443
mid 3444
near end 3444
very top 3444
early top 3444
mid 3445
near end 3445
very top 3445
early top 3445
mid 3446
near end 3446
very top 3446
early top 3446
mid 3447
near end 3447
very top 3447
early top 3447
mid 3448
near end 3448
very top 3448
early top 3448
mid 3449
near end 3449
very top 3449
early top 3449
mid 3450
near end 3450
very top 3450
early top 3450
mid 3451
near end 3451
very top 3451
early top 3451
mid 3452
near end 3452
very top 3452
early top 3452
mid 3453
near end 3453
very top 3453
early top 3453
mid 3454
nea

mid 3590
near end 3590
very top 3590
early top 3590
mid 3591
near end 3591
very top 3591
early top 3591
mid 3592
near end 3592
very top 3592
early top 3592
mid 3593
near end 3593
very top 3593
early top 3593
mid 3594
near end 3594
very top 3594
early top 3594
mid 3595
near end 3595
very top 3595
early top 3595
mid 3596
near end 3596
very top 3596
early top 3596
mid 3597
near end 3597
very top 3597
early top 3597
mid 3598
near end 3598
very top 3598
early top 3598
mid 3599
near end 3599
very top 3599
early top 3599
mid 3600
near end 3600
Epoch 3600
Minutes elapsed: 5.1437869985898335
Last 100 averages: Loss: -557.2381778798007, reward: 51.68, loss/reward: -10.782472482194285
very end 3600
very top 3600
Model loaded from /testing/
early top 3600
mid 3601
near end 3601
very top 3601
early top 3601
mid 3602
near end 3602
very top 3602
early top 3602
mid 3603
near end 3603
very top 3603
early top 3603
mid 3604
near end 3604
very top 3604
early top 3604
mid 3605
near end 3605
very top 3605
e

mid 3741
near end 3741
very top 3741
early top 3741
mid 3742
near end 3742
very top 3742
early top 3742
mid 3743
near end 3743
very top 3743
early top 3743
mid 3744
near end 3744
very top 3744
early top 3744
mid 3745
near end 3745
very top 3745
early top 3745
mid 3746
near end 3746
very top 3746
early top 3746
mid 3747
near end 3747
very top 3747
early top 3747
mid 3748
near end 3748
very top 3748
early top 3748
mid 3749
near end 3749
very top 3749
early top 3749
mid 3750
near end 3750
very top 3750
early top 3750
mid 3751
near end 3751
very top 3751
early top 3751
mid 3752
near end 3752
very top 3752
early top 3752
mid 3753
near end 3753
very top 3753
early top 3753
mid 3754
near end 3754
very top 3754
early top 3754
mid 3755
near end 3755
very top 3755
early top 3755
mid 3756
near end 3756
very top 3756
early top 3756
mid 3757
near end 3757
very top 3757
early top 3757
mid 3758
near end 3758
very top 3758
early top 3758
mid 3759
near end 3759
very top 3759
early top 3759
mid 3760
nea

## Recover saved metrics

In [None]:
# Reload the training metrics stored at `<split_model.model_name>/metrics.pkl`.
# The `with` block guarantees the file handle is closed once unpickling finishes
# (a bare `open(...)` passed straight to `pkl.load` leaks the handle).
# NOTE(review): pickle.load can execute arbitrary code -- only load metrics files
# you produced yourself.
with open(split_model.model_name + '/metrics.pkl', 'rb') as metrics_file:
    (losses, rewards, pq_losses, grad_norms, predicted_trajectories,
     actual_trajectories, look_ahead_predictions) = pkl.load(metrics_file)

In [None]:
# Plot every gradient-norm record and show it on its own.
# Each record is treated as [label, value_0, value_1, ...]: element 0 is the
# legend label and the remaining elements form the plotted series.
for grad_record in grad_norms:
    record_label = grad_record[0]
    record_values = grad_record[1:]
    plt.plot(record_values, label=record_label)
    plt.legend()
    plt.show()

In [None]:
# Quantile-loss histories, transposed so each row is one series over training.
# Rows alternate q50 / q90 (see `labels`): even rows go on the top axes,
# odd rows on the bottom axes.
array_pq_losses = np.array(pq_losses).T
labels = ['x q50','x q90','x dot q50','x dot q90','theta q50','theta q90','theta dot q50','theta dot q90']
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(25, 8))
for row_idx, loss_series in enumerate(array_pq_losses):
    target_ax = ax1 if row_idx % 2 == 0 else ax2
    target_ax.plot(loss_series, label=labels[row_idx])
for panel in (ax1, ax2):
    panel.legend()
plt.show()

In [None]:
def _show_curve(series, caption):
    """Draw `series` on a fresh 25x6 figure titled `caption` and display it."""
    plt.figure(figsize=(25, 6))
    plt.plot(series)
    plt.title(caption)
    plt.show()


# Training curves: loss, reward, and the per-epoch ratio loss / reward.
_show_curve(losses, 'losses')
_show_curve(rewards, 'rewards')
_show_curve([loss / reward for loss, reward in zip(losses, rewards)], 'loss per reward')

In [None]:
# Squared errors per state component, plotted every 25 epochs.
# Assumes squared_errors is indexed [epoch][component][timestep] -- TODO confirm.
# NOTE(review): `squared_errors` is NOT restored by the metrics.pkl cell above;
# this cell only works if it still exists in the kernel from an earlier run.
array_squared_errors = np.array(squared_errors)
squared_error_labels = ['x', 'x dot', 'theta', 'theta dot']
# Clamp to the recorded history: indexing past len(array_squared_errors) would
# raise IndexError whenever global_epoch exceeds the number of stored epochs.
n_recorded = min(split_model.global_epoch, len(array_squared_errors))
for i in range(0, n_recorded, 25):
    f, ax = plt.subplots(figsize=(25, 6))
    for component, component_label in enumerate(squared_error_labels):
        ax.plot(array_squared_errors[i][component], label=component_label)
    ax.set_title('Epoch {}'.format(i))
    ax.legend()
    ax.set_xlim([0.0, 200])
    plt.show()

In [None]:
# Every 50 epochs, overlay for each of the 4 state components the true trajectory
# (black), the predicted trajectory (yellow) and the look-ahead prediction (red).
lookaheadlabels = ['LA x', 'LA x dot', 'LA theta', 'LA theta dot']
predictedlabels = ['predicted x', 'predicted x dot', 'predicted theta', 'predicted theta dot']
truelabels = ['true x','true x dot', 'true theta','true theta dot']
array_actual_trajectories = np.array(actual_trajectories)
# Only visit epochs for which every metric list has an entry; indexing beyond the
# shortest list would raise IndexError when global_epoch exceeds the stored history.
n_recorded = min(split_model.global_epoch, len(array_actual_trajectories),
                 len(predicted_trajectories), len(look_ahead_predictions))
for k in range(0, n_recorded, 50):
    # Per-epoch conversions hoisted out of the component loop.
    predicted = predicted_trajectories[k].numpy()
    look_ahead = np.array(look_ahead_predictions[k]).T
    plt.figure(figsize=(25, 6))
    for j in range(4):
        true_series = np.squeeze(np.array(list(array_actual_trajectories[k][j])))
        plt.plot(true_series, label=truelabels[j], color='k', linewidth=3)
        plt.plot(predicted[j], label=predictedlabels[j], color='y')
        plt.plot(look_ahead[j], label=lookaheadlabels[j], color='r')
    plt.legend()
    plt.title('Epoch {}'.format(k))
    plt.xlim(0, 200)
    plt.show()