In [2]:
import gym
import math
import numpy as np
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch.distributions.relaxed_bernoulli import RelaxedBernoulli
from torch.distributions import Bernoulli
from torch.distributions.categorical import Categorical
from torch.distributions.kl import kl_divergence
from torch.distributions import Normal

import seaborn as sns

import time

# Utility functions

In [3]:
def padding_tensor(sequences):
    """
    Zero-pad variable-length sequences into one batch tensor and build a validity mask.

    :param sequences: non-empty list of tensors with shape [seq, state dim]; the feature
                      dim (last dim) of sequences[0] is used for all of them
    :return: (out_tensor, mask), both with shape [num, max_seq_length, state dim] and the
             dtype/device of sequences[0]. out_tensor holds each sequence followed by zeros;
             mask is 1 where out_tensor holds real data and 0 on padding.
    :raises ValueError: if sequences is empty
    """
    if len(sequences) == 0:
        raise ValueError('padding_tensor needs at least one sequence')

    num = len(sequences)
    max_len = max(s.size(0) for s in sequences)
    feature_dim = sequences[0].size(-1)
    out_dims = (num, max_len, feature_dim)

    # new_zeros inherits dtype and device from the first sequence.
    out_tensor = sequences[0].new_zeros(out_dims)
    mask = sequences[0].new_zeros(out_dims)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        out_tensor[i, :length, :] = tensor
        mask[i, :length, :] = 1
    return out_tensor, mask


def truncate_sequence(list_of_sequence, batch_first, min_len=None):
    """
    Cut every sequence to a common length (keeping the first steps) and stack them.

    list_of_sequence: non-empty list of tensors with shape [seq, feature dim]
    batch_first: if True return [batch, min_seq_length, fea_dim], otherwise
                 [min_seq_length, batch, fea_dim] (a permuted view)
    min_len: length to keep; defaults to the shortest sequence length. Every sequence
             must be at least this long.
    return: a freshly allocated float tensor (torch.zeros default dtype, CPU) holding the copies
    """
    n_seq = len(list_of_sequence)
    fea_dim = list_of_sequence[0].size(-1)
    if min_len is None:
        min_len = min(seq.size(0) for seq in list_of_sequence)

    stacked = torch.zeros(n_seq, min_len, fea_dim)
    for idx, seq in enumerate(list_of_sequence):
        stacked[idx] = seq[:min_len, :]

    return stacked if batch_first else stacked.permute(1, 0, 2)



def save_model_policy(model, model_optimiser, policy, policy_optimiser, save_model_path, save_policy_path):
    """
    Checkpoint a world model and a policy, each together with its optimiser state.

    Writes two files with torch.save:
      <save_model_path>/model.tar   -> {"model_dict": model.state_dict(),
                                        "trainer_dict": model_optimiser.state_dict()}
      <save_policy_path>/policy.tar -> {"model_dict": policy.state_dict(),
                                        "trainer_dict": policy_optimiser.state_dict()}
    Both directories must already exist. Prints 'Checkpointed' when done.
    """
    import os  # local import: the notebook's import cell does not import os

    # os.path.join avoids "dir//file" and, for an empty directory string, writing to "/".
    checkpoints = (
        (os.path.join(save_model_path, "model.tar"), model, model_optimiser),
        (os.path.join(save_policy_path, "policy.tar"), policy, policy_optimiser),
    )
    for path, net, optimiser in checkpoints:
        torch.save({
            "model_dict": net.state_dict(),
            "trainer_dict": optimiser.state_dict()
        }, path)

    print('Checkpointed')

def plot_LLB(true_obs, mean, std):
    """
    Plot true CartPole trajectories against the model's predictive mean +/- 1 std.

    One panel per state variable (position, velocity, angle, angle velocity).

    true_obs: tensor indexed as true_obs[t, batch, dim]; only batch element 0 is plotted and
              the first time step is dropped, so the curve starts at t = 1.
    mean, std: tensors indexed as mean[t, dim] -- presumably [seq-1, output_dim] with
               output_dim >= 4 (TODO confirm against callers; mc_predict returns an extra
               singleton batch axis that must be removed first).
    """
    titles = ['Position', 'Velocity', 'angle', 'angle velocity']

    lower = mean - std      #[seq-1, output_dim]
    upper = mean + std

    fig, ax = plt.subplots(1, 4, figsize=(20, 5))
    for dim, (axis, title) in enumerate(zip(ax, titles)):
        # .cpu() on every tensor so GPU inputs work (true_obs previously skipped it).
        true_d = true_obs[1:, 0, dim].data.cpu().numpy()
        mean_d = mean[:, dim].data.cpu().numpy()
        lower_d = lower[:, dim].data.cpu().numpy()
        upper_d = upper[:, dim].data.cpu().numpy()
        x = np.arange(0, true_d.shape[0])

        axis.plot(true_d, label='True')
        axis.plot(x, mean_d, label='mean')
        axis.fill_between(x, upper_d, lower_d, facecolor='grey',
                          color='grey', alpha=0.2)
        axis.set_title(title)
        axis.set_xlabel('time step')
        axis.legend()




# Environment

In [4]:
# -*- coding: utf-8 -*-
"""
Classic cart-pole system implemented by Rich Sutton et al.
Copied from http://incompleteideas.net/sutton/book/code/pole.c
permalink: https://perma.cc/C9ZM-652R
Modified by Aaditya Ravindran to include friction and random sensor & actuator noise
"""

import logging
import math
import random
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np

# Module-level logger; CartPoleModEnv.step uses it to warn when step() is called after done.
logger = logging.getLogger(__name__)

class CartPoleModEnv(gym.Env):
    """
    Continuous-action CartPole with optional actuator (model) and sensor (data) noise.

    `case` selects the noise setting (the distributions are listed in `addnoise`):
      * case < 6  : model noise only. The force magnitude is 30 * (1 + addnoise(case)),
                    re-drawn on every reset(); observations are exact (self.case = 1).
      * 6 <= case <= 9 : data noise only. The force magnitude is fixed at 30.0 and each
                    observation component is multiplied by (1 + addnoise(case)).
      * case > 9  : both. The force magnitude is perturbed by addnoise(case); observations
                    use the case-10 distribution (self.case = 10).
    """
    metadata = {
            'render.modes': ['human', 'rgb_array'],
            'video.frames_per_second' : 50
    }

    def __init__(self,case):
        self.__version__ = "0.2.0"
        print("CartPoleModEnv - Version {}, Noise case: {}".format(self.__version__,case))
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = (self.masspole + self.masscart)
        self.length = 0.5 # actually half the pole's length
        self.polemass_length = (self.masspole * self.length)
        self.seed()

        self.origin_case = case
        # self.case is the setting used for *observation* noise in step().
        if case < 6:            # only model noise -> exact observations
            self.case = 1
        elif case > 9:          # both model and data noise
            self.case = 10
        else:                   # only data noise
            self.case = case
        self.force_mag = self._sample_force_mag()

        self.tau = 0.02     # seconds between state updates

        self.min_action = -1.
        self.max_action = 1.0

        # Angle at which to fail the episode (12 degrees, in radians)
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # Angle limit set to 2 * theta_threshold_radians so failing observation is still within bounds
        high = np.array([
            self.x_threshold * 2,
            np.finfo(np.float32).max,
            self.theta_threshold_radians * 2,
            np.finfo(np.float32).max])

        # One continuous action in [min_action, max_action]; step() scales it by force_mag.
        self.action_space = spaces.Box(
                low = self.min_action,
                high = self.max_action,
                shape = (1,)
        )

        self.observation_space = spaces.Box(-high, high)

        self.viewer = None
        self.state = None

        self.steps_beyond_done = None

    def _sample_force_mag(self):
        """
        Return the actuator force magnitude for the current episode.

        Settings outside 6..9 carry model noise: 30.0 * (1 + addnoise(origin_case)).
        Settings 6..9 (data noise only) use the nominal 30.0.
        """
        if self.origin_case < 6 or self.origin_case > 9:
            return 30.0 * (1 + self.addnoise(self.origin_case))
        return 30.0

    def addnoise(self,x):
        """
        Draw one multiplicative noise value for noise setting `x`.

        Only the distribution for `x` is sampled, so each call advances the RNG by at most
        one draw. The value is returned as a Python float so noisy states keep the same
        scalar shape as noise-free ones. Setting 1 returns 0.0; unknown settings return 1.
        """
        rng = self.np_random
        samplers = {
            1 : lambda: 0.0,                                        #  no noise
            2 : lambda: rng.uniform(low=-0.05, high=0.05),          #  5% actuator noise ,  small model uniform noise
            3 : lambda: rng.uniform(low=-0.10, high=0.10),          # 10% actuator noise ,  large model uniform noise
            4 : lambda: rng.normal(loc=0, scale=np.sqrt(0.10)),     #  small model gaussian noise
            5 : lambda: rng.normal(loc=0, scale=np.sqrt(0.50)),     #  large model gaussian noise
            6 : lambda: rng.uniform(low=-0.05, high=0.05),          #  5% sensor noise ,    small data uniform noise
            7 : lambda: rng.uniform(low=-0.10, high=0.10),          # 10% sensor noise ,    large data uniform noise
            8 : lambda: rng.normal(loc=0, scale=np.sqrt(0.10)),     #  small data gaussian noise
            9 : lambda: rng.normal(loc=0, scale=np.sqrt(0.20)),     #  large data gaussian noise
            10: lambda: rng.normal(loc=0, scale=np.sqrt(0.10)),     #  small both gaussian noise
            11: lambda: rng.normal(loc=0, scale=np.sqrt(0.50)),     #  large both gaussian noise
        }
        sampler = samplers.get(x)
        if sampler is None:
            return 1
        return float(sampler())

    def seed(self, seed=None): # Set appropriate seed value
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def stepPhysics(self, force):
        """
        Advance the true state by one explicit Euler step of length tau under `force`.

        Returns the new (x, x_dot, theta, theta_dot) tuple; does not assign self.state.
        """
        x, x_dot, theta, theta_dot = self.state
        costheta = math.cos(theta)
        sintheta = math.sin(theta)
        temp = (force + self.polemass_length * theta_dot * theta_dot * sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / \
                    (self.length * (4.0/3.0 - self.masspole * costheta * costheta / self.total_mass))
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass
        x  = (x + self.tau * x_dot)
        x_dot = (x_dot + self.tau * xacc)
        theta = (theta + self.tau * theta_dot)
        theta_dot = (theta_dot + self.tau * thetaacc)
        return (x, x_dot, theta, theta_dot)

    def step(self, action):
        """
        Apply `action` (a length-1 array in [-1, 1]) and return (observation, reward, done, {}).

        The true state is stored in self.state. The returned observation is a length-4 array
        with multiplicative sensor noise applied. Termination is judged on that noisy
        observation, not on the true state.
        """
        assert self.action_space.contains(action), "%r (%s) invalid"%(action, type(action))
        # Take the single element explicitly; float() on a 1-element array is deprecated.
        force = self.force_mag * float(np.asarray(action).reshape(-1)[0])
        self.state = self.stepPhysics(force)
        x, x_dot, theta, theta_dot = self.state         #true state

        # Measurement noise, drawn in the order theta, x, x_dot, theta_dot.
        theta = theta * (1 + self.addnoise(self.case))
        x = x * (1 + self.addnoise(self.case))
        x_dot = x_dot * (1 + self.addnoise(self.case))
        theta_dot = theta_dot * (1 + self.addnoise(self.case))

        output_state = np.array((x, x_dot, theta, theta_dot))

        done = bool(x < -self.x_threshold
                    or x > self.x_threshold
                    or theta < -self.theta_threshold_radians
                    or theta > self.theta_threshold_radians)

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            if self.steps_beyond_done == 0:
                logger.warning("You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.")
            self.steps_beyond_done += 1
            reward = 0.0

        return output_state, reward, done, {}

    def reset(self):
        """Start a new episode near the upright position and re-draw the force magnitude."""
        self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
        self.steps_beyond_done = None

        # also reset the force (noisy only for settings with model noise)
        self.force_mag = self._sample_force_mag()

        return np.array(self.state)

    def render(self, mode='human', close=False):
        """Draw the cart and pole from the true state (self.state), not the noisy observation."""
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
            return

        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold*2
        scale = screen_width/world_width
        carty = 100 # TOP OF CART
        polewidth = 10.0
        polelen = scale * 1.0
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)
            l,r,t,b = -cartwidth/2, cartwidth/2, cartheight/2, -cartheight/2
            axleoffset =cartheight/4.0
            cart = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)
            l,r,t,b = -polewidth/2,polewidth/2,polelen-polewidth/2,-polewidth/2
            pole = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            pole.set_color(.8,.6,.4)
            self.poletrans = rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)
            self.axle = rendering.make_circle(polewidth/2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(.5,.5,.8)
            self.viewer.add_geom(self.axle)
            self.track = rendering.Line((0,carty), (screen_width,carty))
            self.track.set_color(0,0,0)
            self.viewer.add_geom(self.track)

        if self.state is None: return None

        x = self.state
        cartx = x[0]*scale+screen_width/2.0 # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-x[2])
        return self.viewer.render(return_rgb_array = mode=='rgb_array')

# RSSM (recurrent state-space model)

In [5]:


class RSSM(nn.Module):
    """
    Recurrent state-space model (RSSM) used as a learnt dynamics model.

    At each step the model keeps a deterministic recurrent hidden state h_t and a stochastic
    latent state s_t:
        h_{t+1} = RNNCell(ReLU(Linear([s_t, a_t])), h_t)   (deterministic transition)
        prior      p(s_{t+1} | h_{t+1})                      (diagonal Gaussian)
        posterior  q(s_t | h_t, o_t)                         (diagonal Gaussian)
        decoder    p(o_t | s_t, h_t)                         (diagonal Gaussian)
    Every Gaussian scale is exp(linear output), so it is strictly positive but unbounded.
    The first hidden state is computed from the first observation by `init_h`.

    :param input_size: action dimension
    :param hidden_size: recurrent hidden dimension
    :param output_size: observation dimension
    :param state_size: stochastic latent-state dimension
    :param device: device that inputs are moved to (the module itself is not moved here)
    :param mode: recurrent cell type: 'RNN', 'LSTM' or 'GRU'
    :raises ValueError: for any other mode
    """
    def __init__(self, input_size=1, hidden_size=32, output_size=4, state_size=32, device = 'cpu', mode = 'LSTM'):            #action, hidden and observation dim
        super(RSSM, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.state_size = state_size

        self.mode = mode
        self.device = device

        # Recurrent transition cell; it consumes the embedded (state, action) of size hidden_size.
        if mode == 'RNN':
            self.transition_RNN = nn.RNNCell(input_size = hidden_size, hidden_size = hidden_size)
        elif mode == 'LSTM':
            self.transition_RNN = nn.LSTMCell(input_size = hidden_size, hidden_size = hidden_size)
        elif mode == 'GRU':
            self.transition_RNN = nn.GRUCell(input_size = hidden_size, hidden_size = hidden_size)
        else:
            raise ValueError("mode must be 'RNN', 'LSTM' or 'GRU', got {!r}".format(mode))

        #linear layer converting the state + action
        self.state_action_layer = nn.Sequential(
            nn.Linear(self.state_size+self.input_size, self.hidden_size),
            nn.ReLU()
        )
        #prior (read directly off the recurrent hidden state)
        self.prior_mean = nn.Linear(self.hidden_size, self.state_size)
        self.prior_sigma = nn.Linear(self.hidden_size, self.state_size)
        # Currently unused: scales come from exp(), not softplus(.) + _min_stddev.
        self._min_stddev = 0.1

        #posterior
        self.hidden_obs = nn.Sequential(
            nn.Linear(self.hidden_size+self.output_size, self.hidden_size),
            nn.ReLU()
        )
        self.poster = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU()
        )
        self.post_mean = nn.Linear(self.hidden_size, self.state_size)
        self.post_sigma = nn.Linear(self.hidden_size, self.state_size)

        #decoder
        self.state_hidden = nn.Sequential(
            nn.Linear(self.state_size+self.hidden_size, self.hidden_size),
            nn.ReLU()
        )
        self.obs_mean = nn.Linear(self.hidden_size, self.output_size)
        self.obs_sigma = nn.Linear(self.hidden_size, self.output_size)

        #initial hidden encoder (takes the first observation as input and outputs an initial hidden state)
        self.init_h = nn.Sequential(
            nn.Linear(self.output_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Tanh()
        )

        self.loss_list = []


    def prior(self, state, action, rnn_hidden, rnn_hidden_c=None):
        """
        One deterministic transition followed by the prior over the next latent state.

            h_t+1 = f(h_t, s_t, a_t)
            prior : p(s_t+1 | h_t+1)

        state : [batch, state_dim]
        action : [batch, action_dim]
        rnn_hidden : [batch, rnn hidden dim]
        rnn_hidden_c : LSTM cell state [batch, rnn hidden dim]; only used when mode == 'LSTM'

        Returns (prior_mean, prior_sigma, rnn_hidden), plus rnn_hidden_c in LSTM mode.
        """
        state_action = self.state_action_layer(torch.cat([state, action], dim = -1))         #[batch, hidden]
        if self.mode == 'LSTM':
            rnn_hidden, rnn_hidden_c = self.transition_RNN(state_action, (rnn_hidden, rnn_hidden_c))
        else:
            rnn_hidden = self.transition_RNN(state_action, rnn_hidden)      #[batch, hidden]
        prior_mean = self.prior_mean(rnn_hidden)                    #[batch, state]
        prior_sigma = torch.exp(self.prior_sigma(rnn_hidden))       #[batch, state]
        if self.mode == 'LSTM':
            return prior_mean, prior_sigma, rnn_hidden, rnn_hidden_c
        return prior_mean, prior_sigma, rnn_hidden

    def posterior(self, rnn_hidden, obs):
        """
        Posterior q(s_t | h_t, o_t) as (mean, sigma), each [batch, state].

        rnn_hidden: [batch, hidden]
        obs : [batch, output_dim]
        """
        hidden_obs = self.hidden_obs(torch.cat([rnn_hidden, obs], dim = -1))        #[batch, hidden]
        poster = self.poster(hidden_obs)            #[batch, hidden]
        poster_mean = self.post_mean(poster)        #[batch, state]
        poster_sigma = torch.exp(self.post_sigma(poster))

        return poster_mean, poster_sigma

    def obs_model(self, state, rnn_hidden):
        """
        Decoder p(o_t | s_t, h_t) as (mean, sigma), each [batch, output_size].
        """
        obs = self.state_hidden(torch.cat([state, rnn_hidden], dim = -1))      #[batch, hidden]
        obs_mean = self.obs_mean(obs)               #[batch, output_size]
        obs_sigma = torch.exp(self.obs_sigma(obs))

        return obs_mean, obs_sigma

    def _initial_latent(self, init_obs):
        """Encode the first observation into (state, rnn_hidden, rnn_hidden_c); c is None unless LSTM."""
        rnn_hidden = self.init_h(init_obs)          #[batch, hidden]
        rnn_hidden_c = torch.zeros_like(rnn_hidden) if self.mode == 'LSTM' else None
        prior_mean = self.prior_mean(rnn_hidden)
        prior_sigma = torch.exp(self.prior_sigma(rnn_hidden))
        state = self.reparametrise(prior_mean, prior_sigma)     #[batch, state]
        return state, rnn_hidden, rnn_hidden_c

    def _transition(self, state, action, rnn_hidden, rnn_hidden_c):
        """Call `prior` for any cell type; always returns (mean, sigma, h, c) with c None unless LSTM."""
        if self.mode == 'LSTM':
            return self.prior(state, action, rnn_hidden, rnn_hidden_c)
        prior_mean, prior_sigma, rnn_hidden = self.prior(state, action, rnn_hidden)
        return prior_mean, prior_sigma, rnn_hidden, None


    def forward(self, X, A, beta=1,print_output=False):
        """
        Evidence lower bound of a batch of trajectories under the model.

        X : observations [seq_length, batch, output_size]
        A : actions [seq_length-1, batch]; each is unsqueezed to [batch, 1] before the
            transition, so input_size is expected to be 1
        beta : currently unused (kept for interface compatibility)
        print_output : if True, print the three returned quantities

        Returns (FE, nll, kl_loss):
          nll     : mean Gaussian log-likelihood of X[1:] under the decoder. Despite the name
                    this is a log-likelihood (higher is better), from batched_gaussian_ll.
          kl_loss : KL(q(s_t | h_t, o_t) || p(s_t | h_t)) summed over latent dims, averaged
                    over the batch and over the seq_length-1 steps.
          FE      : nll - kl_loss, i.e. the ELBO (higher is better).
        """
        if len(X.size()) != 3:
            print('The input data matrix should be the shape of [seq_length, batch_size, input_dim]')
        assert X.size(0) == A.size(0) + 1, 'the seq length of X and A are wrong'

        X = X.to(self.device)
        A = A.to(self.device)
        n_steps, batch = A.size(0), A.size(1)

        # Posterior samples and hidden states for steps 1..seq-1, decoded in one batch below.
        states = torch.zeros(n_steps, batch, self.state_size, device=self.device)         #[seq-1, batch, state]
        rnn_hiddens = torch.zeros(n_steps, batch, self.hidden_size, device=self.device)   #[seq-1, batch, hidden]

        state, rnn_hidden, rnn_hidden_c = self._initial_latent(X[0])

        kl_loss = 0
        for t in range(1, X.size(0)):
            prior_m, prior_sigma, rnn_hidden, rnn_hidden_c = self._transition(
                state, A[t-1].unsqueeze(-1), rnn_hidden, rnn_hidden_c)

            post_m, post_sigma = self.posterior(rnn_hidden, X[t])
            state = self.reparametrise(post_m, post_sigma)        #[batch, state_size]
            states[t-1] = state
            rnn_hiddens[t-1] = rnn_hidden

            kl = kl_divergence(Normal(post_m, post_sigma), Normal(prior_m, prior_sigma)).sum(dim=1)        #[batch]
            kl_loss += kl.mean()
        kl_loss /= n_steps

        flatten_x_mean, flatten_x_sigma = self.obs_model(states.view(-1, self.state_size),
                                                         rnn_hiddens.view(-1, self.hidden_size))
        nll = self.batched_gaussian_ll(flatten_x_mean, flatten_x_sigma,
                                       X[1:, :, :].reshape(-1, self.output_size)).mean()

        FE = nll - kl_loss

        if print_output:
            print('Free energy of this batch = {}. Nll loss = {}. KL div = {}.'.format(float(FE.data)
                                                                        , float(nll.data),
                                                                        float(kl_loss.data)))

        return FE, nll, kl_loss


    def mc_predict(self, initial_obs, actions, mean_obs = False, n_samples=200):
        """
        Monte-Carlo open-loop prediction from the prior (no observations after the first).

        initial_obs : [1, output_dim]
        actions: [seq-1, 1, action_dim]
        mean_obs : if True use the decoder mean at each step instead of sampling it
        n_samples : number of independent rollouts (default 200)

        Returns (mean, std) over the rollouts, each [seq-1, 1, output_dim]. std is the
        unbiased estimate, so it is NaN when n_samples == 1.
        """
        initial_obs = initial_obs.to(self.device)
        actions = actions.to(self.device)

        rollouts = []
        for _ in range(n_samples):
            state, rnn_hidden, rnn_hidden_c = self._initial_latent(initial_obs)
            preds = []
            for t in range(actions.size(0)):
                prior_m, prior_sigma, rnn_hidden, rnn_hidden_c = self._transition(
                    state, actions[t], rnn_hidden, rnn_hidden_c)
                state = self.reparametrise(prior_m, prior_sigma)

                x_mean, x_sigma = self.obs_model(state, rnn_hidden)
                preds.append(x_mean if mean_obs else self.reparametrise(x_mean, x_sigma))   #[1, output]
            rollouts.append(torch.stack(preds, dim=0))          #[seq-1, 1, output]

        samples = torch.stack(rollouts, dim=-1)                 #[seq-1, 1, output, n_samples]
        return samples.mean(dim=-1), samples.std(dim=-1)

    def imagine(self, init_x, control_f, horizon, plan, mean_obs = False):
        """
        Roll the learnt model forward for `horizon` steps under the policy `control_f`.

        init_x : [batch, output]
        control_f : callable obs -> (actions [batch, action_dim], log_prob)
        plan : 'pg' collects the policy log-probs; 'rp' ignores them
        mean_obs : if True feed the decoder mean back to the policy instead of a sample

        Returns (predicted observations [horizon, batch, output], log-probs): the log-probs are
        concatenated to [horizon, batch, ...] for 'pg' and are 0 for 'rp'.
        :raises NotImplementedError: for any other plan
        """
        if plan not in ('pg', 'rp'):
            raise NotImplementedError

        init_x = init_x.to(self.device)
        state, rnn_hidden, rnn_hidden_c = self._initial_latent(init_x)

        x_sample = init_x
        pred = []
        log_probs = []
        for _ in range(horizon):
            action_samples, action_log_prob = control_f(x_sample)          #[batch, 1]
            if plan == 'pg':
                log_probs.append(action_log_prob.unsqueeze(0))              #[1, batch, 1]

            prior_m, prior_sigma, rnn_hidden, rnn_hidden_c = self._transition(
                state, action_samples, rnn_hidden, rnn_hidden_c)
            state = self.reparametrise(prior_m, prior_sigma)

            x_mean, x_sigma = self.obs_model(state, rnn_hidden)
            x_sample = x_mean if mean_obs else self.reparametrise(x_mean, x_sigma)
            pred.append(x_sample.unsqueeze(0))

        action_log_prob_list = torch.cat(log_probs) if plan == 'pg' else 0

        return torch.cat(pred), action_log_prob_list

    def validate_by_imagination(self, init_x, control_f, plan, mean_obs = False, max_steps=200):
        """
        Evaluate `control_f` on the learnt model instead of the real dynamics.

        Steps the model until the predicted cart position (dim 0) leaves [-2.4, 2.4], the
        predicted angle (dim 2) leaves +/-12 degrees, or `max_steps` steps have been taken,
        and returns the number of steps taken (reward 1 per step). At least one step is
        always taken. init_x must have batch size 1, because the bounds check calls bool()
        on the comparisons. `plan` is currently unused.
        """
        theta_threshold = 12 * 2 * math.pi / 360
        x_threshold = 2.4

        init_x = init_x.to(self.device)
        state, rnn_hidden, rnn_hidden_c = self._initial_latent(init_x)

        x_sample = init_x
        reward = 0
        while True:
            action_samples, _ = control_f(x_sample)          #[batch, 1]

            prior_m, prior_sigma, rnn_hidden, rnn_hidden_c = self._transition(
                state, action_samples, rnn_hidden, rnn_hidden_c)
            state = self.reparametrise(prior_m, prior_sigma)

            x_mean, x_sigma = self.obs_model(state, rnn_hidden)
            x_sample = x_mean if mean_obs else self.reparametrise(x_mean, x_sigma)

            reward += 1

            out_of_bounds = bool(x_sample[:, 0] < -x_threshold) \
                or bool(x_sample[:, 0] > x_threshold) \
                or bool(x_sample[:, 2] < -theta_threshold) \
                or bool(x_sample[:, 2] > theta_threshold)
            if out_of_bounds or reward >= max_steps:
                break

        return reward


    def reparametrise(self, mean, sigma):
        """
        Sample mean + sigma * eps with eps ~ N(0, I) (diagonal Gaussian, no correlation).

        sigma must be broadcast-compatible with mean; eps is drawn with sigma's shape,
        dtype and device.
        """
        eps = torch.randn_like(sigma)
        return mean + sigma * eps


    def batched_gaussian_ll(self, mean, sigma, x):
        """
        Log-density of each row of x under a diagonal Gaussian.

        mean  : shape [batch, output_size]
        sigma : shape [batch, output_size]   (standard deviations of the diagonal covariance)
        x     : shape [batch, output_size]
        Returns shape [batch]:
            -n/2 log(2 pi) - sum(log sigma) - 0.5 * sum(((x - mean) / sigma)^2)
        """
        if (sigma == 0).any():
            print('Zero occurs in diagonal sigma matrix. (batched gaussian ll)')
        if (sigma ** 2 == 0).any():
            print('Zero occurs after squaring sigma matrix. (batched gaussian ll)')
        if (torch.prod(sigma ** 2, dim = -1) == 0).any():
            print('Zero occurs when calculating determinant of diagonal covariance. (batched gaussian ll)')

        # For a diagonal covariance the quadratic form reduces to an elementwise sum;
        # no explicit diagonal matrices are needed.
        mahalanobis = (((x - mean) / sigma) ** 2).sum(dim = -1)       #[batch]
        logdet = torch.sum(2 * torch.log(sigma), dim = -1)           #[batch]
        n = mean.size(-1)

        return -(n / 2) * math.log(2 * math.pi) - 0.5 * logdet - 0.5 * mahalanobis


    def diagonalise(self, input, batch):
        """
        Build diagonal matrices from the entries of `input`.

        1-d vector                -> diagonal matrix
        2-d, batch=False          -> one diagonal matrix of the row-major flattening,
                                     eg: [[1,2],[3,4]] ---> diag([1,2,3,4])
        2-d, batch=True           -> batched diagonal matrices, one per row
        3-d, batch=True           -> batched diagonal matrices, one per flattened [i] slice
        Anything else prints a message and returns None.
        """
        if input.dim() == 1:
            return torch.diag(input)
        if input.dim() == 2:
            if not batch:
                return torch.diag(input.reshape(-1))
            return torch.diag_embed(input)
        if input.dim() == 3 and batch:
            return torch.diag_embed(input.reshape(input.size(0), -1))
        print('Dimension of inpout tensor should only be 1,2,3.')


    def print_loss(self):
        return self.loss_list


    def print_params(self):

        pass



# Deterministic controller

In [6]:
class controller(nn.Module):
    """
    Deterministic MLP policy mapping states to actions in [-1, 1].

    Network: state_dim -> 16 -> 16 -> action_dim with ReLU activations and a final tanh.
    The module owns an Adam optimiser (lr = 1e-3) and provides two ways of training
    through a learned model's imagination function:
      - pg_train: score-function (REINFORCE-style) surrogate loss
      - rp_train: pathwise loss, differentiating the imagined cost directly
    """

    def __init__(self, action_dim=1, state_dim=4, deterministic = True, device = 'cuda'):
        """
        action_dim : size of the action output
        state_dim : size of the state input
        deterministic : accepted for interface compatibility; currently unused
        device : device that input states are moved to in forward()
        """
        super(controller, self).__init__()
        self.action_dim = action_dim
        self.state_dim = state_dim
        controller_hid = 16

        self.linear = nn.Sequential(
            nn.Linear(self.state_dim, controller_hid),
            nn.ReLU(),
            nn.Linear(controller_hid, controller_hid),
            nn.ReLU(),
            nn.Linear(controller_hid, action_dim)
        )

        self.device = device
        self.optimiser = torch.optim.Adam(self.parameters(), lr = 0.001)

    def forward(self, state):
        """
        Given states input [batch, state_dim], return (action, log_pi).

        action : tanh-squashed network output, [batch, action_dim], values in [-1, 1]
        log_pi : always 0 — the policy is deterministic; the placeholder keeps the
                 (action, log-prob) pair that imagination functions expect.
        """
        state = state.to(self.device)
        action = torch.tanh(self.linear(state))
        return action, 0

    def make_decision(self, state, behaviour_uncertainty):
        """
        Given a state [batch, state_dim] (or [state_dim]), return an action without autograd history.

        behaviour_uncertainty is accepted for interface compatibility but unused,
        since the policy is deterministic.
        """
        with torch.no_grad():
            action, _ = self.forward(state)
        return action

    @staticmethod
    def _discount(cost, gamma):
        """Weight cost[t] by gamma**(t+1) along dim 0 (cost is expected as [seq, batch, 1])."""
        weights = cost.new_tensor([float(gamma) ** (t + 1) for t in range(cost.size(0))])
        return cost * weights.view(-1, 1, 1)

    def pg_train(self, num_epoch, initial_state, horizon, cost_f, model_imagine_f, w_uncertainty=None,
                 e_uncertainty=None, gamma = 0.95, num_particle=100):
        """
        Score-function policy training on model-imagined rollouts.

        initial_state : [1, state_dim], expanded to num_particle identical particles
        cost_f : maps imagined states [seq, batch, state_dim] to costs [seq, batch, 1]
        model_imagine_f : called as model_imagine_f(states, policy, horizon, plan='pg', ...);
                          must return (imagined states, action log-prob matrix)
        w_uncertainty, e_uncertainty : forwarded as W_uncertainty / e_uncertainty only when not None
        gamma : per-step discount; step t is weighted by gamma**(t+1)
        Returns the list of per-epoch losses.
        """
        loss_list = []
        initial_state = initial_state.expand(num_particle, self.state_dim)

        imagine_kwargs = {}
        if w_uncertainty is not None:
            imagine_kwargs['W_uncertainty'] = w_uncertainty
        if e_uncertainty is not None:
            imagine_kwargs['e_uncertainty'] = e_uncertainty

        for e in range(num_epoch):
            # Clear the previous epoch's gradients; otherwise they accumulate across epochs.
            self.optimiser.zero_grad()
            output_matrix, action_log_prob_matrix = model_imagine_f(initial_state, self.forward, horizon, plan = 'pg',
                                                                    **imagine_kwargs)
            # Costs are treated as constants; the gradient flows only through the log-probs.
            cost = self._discount(cost_f(output_matrix).detach(), gamma)      #[seq-1, batch, 1]

            loss = (cost.sum(0) * action_log_prob_matrix.sum(0)).sum()
            loss.backward()
            nn.utils.clip_grad_norm_(self.parameters(), 5)

            self.optimiser.step()
            loss_list.append(loss.item())
            if e % 50 == 0:
                print('Epoch = {}; Policy gradient training loss = {}'.format(e, loss.item()))
        return loss_list

    def rp_train(self, num_epoch, num_particle, initial_state, horizon, cost_f, model_imagine_f, mean_obs, gamma = 0.9):
        """
        Pathwise policy training on model-imagined rollouts.

        From an initial state, the model imagination function predicts the next states
        according to the actions proposed by this controller over a fixed horizon. The
        discounted total cost of the imagined trajectories is backpropagated directly
        into the policy parameters.

        initial_state : [1, state_dim], expanded to num_particle identical particles
        Returns (loss_list, per-epoch mean total cost, per-epoch std of total cost). The
        last two are tensors concatenated over epochs (empty when num_epoch == 0).
        Side effect: stores the last epoch's action_matrix and
        [initial_state; imagined states] on self for inspection.
        """
        loss_list = []
        initial_state = initial_state.expand(num_particle, self.state_dim)

        cost_mean_list = []
        cost_std_list = []

        for _ in range(num_epoch):
            self.optimiser.zero_grad()
            output_matrix, action_matrix = model_imagine_f(initial_state, self.forward, horizon, plan = 'rp', mean_obs = mean_obs)

            self.action_matrix = action_matrix
            self.temp_output_matrix = torch.cat([initial_state.unsqueeze(0).to(self.device), output_matrix], dim = 0)

            cost = cost_f(output_matrix)                 #[seq-1, batch, 1]

            # Monitoring statistics of the undiscounted total cost per particle (no gradient).
            total_cost = cost.detach().sum(0)            #[batch, 1]
            cost_mean_list.append(total_cost.mean(0))
            cost_std_list.append(total_cost.std(0))

            loss = self._discount(cost, gamma).sum()     # scalar
            loss.backward()

            nn.utils.clip_grad_norm_(self.parameters(), 1)
            self.optimiser.step()
            loss_list.append(loss.item())

        cost_means = torch.cat(cost_mean_list) if cost_mean_list else torch.empty(0)
        cost_stds = torch.cat(cost_std_list) if cost_std_list else torch.empty(0)
        return loss_list, cost_means, cost_stds

    def rp_validate(self, num_particle, initial_state, horizon, cost_f, model_imagine_f, mean_obs):
        """
        Evaluate the policy in imagination without training.

        Returns the mean (over particles) of the undiscounted total imagined cost as a float.
        """
        initial_state = initial_state.expand(num_particle, self.state_dim)
        # Validation only: no autograd graph is needed.
        with torch.no_grad():
            output_matrix, _ = model_imagine_f(initial_state, self.forward, horizon, plan = 'rp', mean_obs = mean_obs)
            total_cost = cost_f(output_matrix).sum(0)
        return total_cost.mean(0).item()




# Agent

In [7]:
class Agent:
    """
    Model-based RL agent for the modified cart-pole environment.

    Collects real trajectories, fits an RSSM world model on them (model_learning),
    and trains a `controller` policy on trajectories imagined by that model
    (policy_learning).
    """

    def __init__(self, env_case, state_dim = 4, action_dim = 1, mode = 'LSTM', device='cuda', rand_seed = 1):
        """
        env_case : noise case forwarded to CartPoleModEnv
        mode : forwarded to RSSM
        device : device for both the world model and the policy
        rand_seed : currently unused (seeding is left to the caller)
        """
        self.env = CartPoleModEnv(case = env_case)
        self.env_case = env_case

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device
        self.observations_list = []     # list of [seq, state_dim] tensors
        self.actions_list = []          # list of [seq-1, action_dim] tensors

        # The world model predicts observations, so its output size is the state dimension.
        self.model = RSSM(input_size=action_dim, hidden_size=32, output_size=state_dim, state_size= 32, device=device,
                          mode = mode).to(device)

        self.model_optimiser = torch.optim.Adam(self.model.parameters(), lr = 0.001)
        self.model_training_loss_list = []

        # device must be passed by keyword: the third positional parameter of controller is `deterministic`.
        self.policy = controller(action_dim, state_dim, device=device).to(device)

    def env_reset(self):
        """Recreate the real environment with the same type and noise case used in __init__."""
        self.env = CartPoleModEnv(case = self.env_case)

    def env_rollout(self, if_remember, plan, behaviour_uncertainty, max_steps=200):
        """
        Interact with the real environment for one episode and return its length (1 reward per step).

        if_remember : if True, append the episode's observations [seq, state_dim] and
                      actions [seq-1, action_dim] to self.observations_list / self.actions_list
        plan : 'random' samples env.action_space; 'pg' or 'rp' queries self.policy
        behaviour_uncertainty : forwarded to self.policy.make_decision
        max_steps : step cap for the episode

        The episode ends when |x| > 2.4, |theta| > 12 degrees, or after max_steps steps.
        The environment's own `done` flag is ignored.
        Raises NotImplementedError for an unknown plan.
        """
        if plan not in ('random', 'pg', 'rp'):
            raise NotImplementedError('Unknown plan: {}'.format(plan))

        angle_limit = 12 * 2 * math.pi / 360
        done = False
        state = self.env.reset()
        total_reward = 0
        step = 0
        temp_obs_list = []
        temp_actions_list = []

        if if_remember:
            temp_obs_list.append(torch.tensor(state))

        while not done:
            step += 1

            if plan == 'random':
                action = self.env.action_space.sample()
            else:
                # np.vstack followed by squeeze turns the state into a flat 1-D tensor.
                state_tensor = torch.tensor(np.vstack(state)).float().squeeze()
                action = self.policy.make_decision(state_tensor.to(self.device), behaviour_uncertainty)
                action = action.detach().cpu().numpy()

            next_state, _, _, _ = self.env.step(action)

            done = next_state[0] < -2.4 \
                or next_state[0] > 2.4 \
                or next_state[2] < -angle_limit \
                or next_state[2] > angle_limit \
                or step >= max_steps

            if if_remember:
                temp_obs_list.append(torch.tensor(np.vstack(next_state)).squeeze())
                temp_actions_list.append(torch.tensor(action).float())

            state = next_state
            total_reward += 1

        if if_remember:
            self.observations_list.append(torch.stack(temp_obs_list).float())       #[seq, output]
            self.actions_list.append(torch.stack(temp_actions_list).float())        #[seq-1, action_dim]

        return total_reward

    def model_learning(self, num_epoch, num_batch):
        """
        Fit the world model on the stored real trajectories.

        Each epoch samples num_batch trajectories (with replacement) and truncates them
        to the shortest sampled length. It then takes one optimiser step on the negative
        free energy returned by the model.

        num_epoch : number of optimiser steps
        num_batch : number of trajectories sampled per step
        """
        for e in range(num_epoch):
            self.model_optimiser.zero_grad()

            idx = np.random.choice(len(self.observations_list), num_batch)
            trun_obs = truncate_sequence([self.observations_list[i] for i in idx], batch_first = False)
            trun_actions = truncate_sequence([self.actions_list[i] for i in idx], batch_first = False)

            # Squeeze only the trailing action axis so seq/batch axes survive when either has size 1.
            # NOTE(review): for action_dim == 1 this yields [seq, batch]; confirm the model expects that.
            b_FE, b_LL, b_KL = self.model(trun_obs, trun_actions.squeeze(-1))
            loss = -b_FE
            loss.backward()

            nn.utils.clip_grad_norm_(self.model.parameters(), 1)

            self.model_optimiser.step()
            self.model_training_loss_list.append(loss.item())

            if e % 1000 == 0:
                print('Epoch{}; FE = {}; LL = {}; KL = {}.'.format(e, b_FE.item(), b_LL.item(), b_KL.item()))

    def cost(self, state):
        """
        cost = 5*angle^2 + position^2
        state : [seq, batch, output]
        return [seq, batch, 1]
        """
        return (5*state[:,:,2]**2 + state[:,:,0]**2).unsqueeze(-1)      #[seq, batch, 1]

    def policy_learning(self, imagine_num, num_particle, num_epoch, batch_size, horizon, plan, mean_obs,
                        plot=False):
        """
        Train the policy on data imagined by the current world model.

        For each of imagine_num rounds, one initial state is drawn from env.reset() and the
        controller is trained from it for num_epoch epochs over the given horizon.

        num_particle : particles per rollout (rp only; pg_train uses its own default)
        batch_size : currently unused
        plot : if True, every round starts from the all-zero state instead
        Returns the result of the last controller training call (None if imagine_num == 0).
        Raises NotImplementedError for an unknown plan.
        """
        if plan not in ('pg', 'rp'):
            raise NotImplementedError('Unknown plan: {}'.format(plan))

        model_f = self.model.imagine
        policy_train_result = None

        for _ in range(imagine_num):
            initial_state = torch.tensor(self.env.reset()).unsqueeze(0).float()         #[1, output]
            if plot:
                initial_state = torch.zeros_like(initial_state)

            if plan == 'pg':
                # NOTE(review): w_uncertainty / e_uncertainty are not supplied here; confirm
                # controller.pg_train accepts this call.
                policy_train_result = self.policy.pg_train(num_epoch, initial_state, horizon, self.cost, model_f,
                                                           gamma=1)
            else:
                policy_train_result = self.policy.rp_train(num_epoch, num_particle, initial_state, horizon, self.cost,
                                                           model_f, mean_obs, gamma = 1)
        return policy_train_result


# Learning-Planning iterations

In [8]:
time1 = time.time()

# ---- Experiment configuration ------------------------------------------------
SEED = 1
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
N_ITERATIONS = 50            # collect-data / learn-model / learn-policy iterations
N_TEST_EPISODES = 20         # real-environment episodes per evaluation
MAX_EPISODE_STEPS = 200      # mirrors the step cap used by Agent.env_rollout
MODEL_EPOCHS = 1000
MODEL_BATCH = 10
POLICY_KWARGS = dict(imagine_num=50, num_particle=1000, num_epoch=1, batch_size=10, horizon=10, mean_obs=False)

testing_reward_list = []
behaviour_uncertainty = False
deterministic = False        # NOTE(review): not used in this cell
plan = 'rp'
num_data = 10                # random-policy episodes used to bootstrap the world model
mean_training_reward_list = []
std_training_reward_list = []


def evaluate_policy(agent, plan, behaviour_uncertainty, n_episodes=N_TEST_EPISODES):
    """Run the current policy for n_episodes real episodes, print each reward, and return the mean reward."""
    print('\n ------------------TESTING-------------------')
    episode_rewards = []
    for j in range(n_episodes):
        rewards = agent.env_rollout(if_remember=False, behaviour_uncertainty=behaviour_uncertainty, plan=plan)
        print(j, rewards)
        episode_rewards.append(rewards)
    avg_rewards = sum(episode_rewards) / n_episodes
    print('Total trajs:', n_episodes, avg_rewards)
    # Episodes are capped at MAX_EPISODE_STEPS, so the average can reach it but never exceed it.
    if avg_rewards >= MAX_EPISODE_STEPS:
        print('success')
    return avg_rewards


torch.manual_seed(SEED)
np.random.seed(SEED)         # model_learning samples minibatch indices with np.random
agent = Agent(env_case = 1, device = DEVICE, mode = 'LSTM')
for _ in range(num_data):
    agent.env_rollout(True, behaviour_uncertainty = behaviour_uncertainty, plan = 'random')

agent.model_learning(num_epoch=MODEL_EPOCHS, num_batch=MODEL_BATCH)
agent.policy_learning(plan=plan, **POLICY_KWARGS)
print('\n Finish policy learning...')

testing_reward_list.append(evaluate_policy(agent, plan, behaviour_uncertainty))

avg_data_length_list = []
mean_cost10_list = []
std_cost10_list = []
mean_cost100_list = []
std_cost100_list = []

test_std_list = []

for iteration in range(N_ITERATIONS):
    print('Epoch = ', iteration + 1)
    agent.env_rollout(True, behaviour_uncertainty = behaviour_uncertainty, plan = plan)

    trajectory_lengths = [len(obs) for obs in agent.observations_list]
    avg_data_length = sum(trajectory_lengths) / len(trajectory_lengths)
    print('average training data length = ', avg_data_length)
    avg_data_length_list.append(avg_data_length)

    agent.model_learning(num_epoch=MODEL_EPOCHS, num_batch=MODEL_BATCH)
    agent.policy_learning(plan=plan, **POLICY_KWARGS)
    print('\n Finish policy learning...')

    testing_reward_list.append(evaluate_policy(agent, plan, behaviour_uncertainty))

time2 = time.time()
print(time2 - time1)

CartPoleModEnv - Version 0.2.0, Noise case: 1




Epoch0; FE = -5.973501205444336; LL = -4.559955596923828; KL = 1.4135457277297974.

 Finish policy learning...

 ------------------TESTING-------------------
0 18
1 20
2 24
3 23
4 17
5 19
6 20
7 17
8 23
9 18
10 25
11 21
12 24
13 17
14 22
15 17
16 22
17 25
18 26
19 18
Total trajs: 19 20.8
Epoch =  1
average training data length =  13.363636363636363
Epoch0; FE = 6.617893695831299; LL = 6.893275260925293; KL = 0.2753816545009613.

 Finish policy learning...

 ------------------TESTING-------------------
0 21
1 32
2 25
3 23
4 29
5 32
6 26
7 23
8 20
9 24
10 28
11 31
12 33
13 26
14 28
15 28
16 34
17 23
18 25
19 30
Total trajs: 19 27.05
Epoch =  2
average training data length =  14.333333333333334
Epoch0; FE = 8.081825256347656; LL = 8.250612258911133; KL = 0.16878710687160492.


KeyboardInterrupt: ignored