In [2]:
import torch
import numpy as np
import gym
import math
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.distributions.relaxed_bernoulli import RelaxedBernoulli
from torch.distributions import Bernoulli
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.autograd import Variable

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.distributions import Normal
from torch.distributions import MultivariateNormal
from torch.distributions.kl import kl_divergence

from torch.autograd import Variable

import seaborn as sns

import time

#Some tools

In [3]:
def padding_tensor(sequences):
    """
    Zero-pad variable-length sequences into a single batch tensor.

    :param sequences: list of tensors with shape [seq, state dim]
    :return: tuple ``(padded, mask)``; both have shape
             [num, max_seq_length, state dim] and take their dtype and device
             from ``sequences[0]``. ``mask`` is 1 on real time steps and 0 on
             the padded tail.
    """
    longest = max(seq.size(0) for seq in sequences)
    template = sequences[0]
    shape = (len(sequences), longest, template.size(-1))

    padded = template.new_zeros(shape)
    mask = template.new_zeros(shape)
    for row, seq in enumerate(sequences):
        steps = seq.size(0)
        padded[row, :steps, :] = seq
        mask[row, :steps, :] = 1
    return padded, mask


def truncate_sequence(list_of_sequence, batch_first, min_len=None):
    """
    Cut every sequence to a common length and stack the results.

    Only the leading ``min_len`` steps of each sequence are kept (no random
    offset). The result is allocated with ``torch.zeros`` and therefore has
    the default dtype/device regardless of the inputs.

    list_of_sequence: list of tensor with shape [seq, feature dim]
    batch_first: selects the output layout (see return)
    min_len: number of steps to keep; defaults to the shortest sequence length
    return : tensor with shape [min_seq_length, batch, fea_dim] or
             [batch, min_seq_length, fea_dim] if batch_first
    """
    feature_dim = list_of_sequence[0].size(-1)
    if min_len is None:
        min_len = min(seq.size(0) for seq in list_of_sequence)

    batch = torch.zeros(len(list_of_sequence), min_len, feature_dim)
    for row, seq in enumerate(list_of_sequence):
        batch[row] = seq[:min_len, :]

    return batch if batch_first else batch.permute(1, 0, 2)



def plot_LLB(true_obs, mean, std):
    """
    Plot predicted means with a +/- one-std band against the true trajectory.

    One panel is drawn for each of the first four features, titled
    'Position', 'Velocity', 'angle' and 'angle velocity'.

    true_obs : indexed as [time, batch, feature] (presumably
               [seq, batch, output_dim]); only batch element 0 is drawn and
               the first time step is skipped.
    mean     : [seq-1, output_dim] predicted mean
    std      : [seq-1, output_dim] predicted standard deviation

    All three inputs are detached and moved to CPU before conversion to
    numpy, so CUDA tensors are accepted for ``true_obs`` as well (previously
    only ``mean``/``std`` were moved). Nothing is returned; the figure is
    left to the active matplotlib backend.
    """
    titles = ('Position', 'Velocity', 'angle', 'angle velocity')

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    lower = mean - std      #[seq-1, output_dim]
    upper = mean + std

    fig, ax = plt.subplots(1, 4, figsize = (20,5))
    for k, title in enumerate(titles):
        truth = to_numpy(true_obs[1:, 0, k])
        x = np.arange(0, truth.shape[0])

        ax[k].plot(truth, label = 'True')
        ax[k].plot(x, to_numpy(mean[:, k]), label = 'mean')
        ax[k].fill_between(x, to_numpy(upper[:, k]), to_numpy(lower[:, k]),
                           facecolor='grey', color = 'grey', alpha = 0.2)
        ax[k].set_title(title)
        ax[k].legend();













def save_model_policy(model, model_optimiser, policy, policy_optimiser, save_model_path, save_policy_path):
    """
    Checkpoint the dynamics model and the policy together with their optimisers.

    Two files are written with ``torch.save``:
      * ``<save_model_path>/model.tar``
      * ``<save_policy_path>/policy.tar``
    Each holds a dict with keys ``"model_dict"`` (the module's state_dict)
    and ``"trainer_dict"`` (the optimiser's state_dict).

    The target directories are created if they do not exist, so a missing
    folder no longer aborts the run at checkpoint time.
    """
    import os  # local import so the function does not depend on a file-level import

    for directory in (save_model_path, save_policy_path):
        if directory:  # an empty string has no directory to create
            os.makedirs(directory, exist_ok=True)

    model_file = save_model_path + "/model.tar"
    policy_file = save_policy_path + "/policy.tar"

    torch.save({
        "model_dict": model.state_dict(),
        "trainer_dict": model_optimiser.state_dict()
    }, model_file)

    torch.save({
        "model_dict": policy.state_dict(),
        "trainer_dict": policy_optimiser.state_dict()
    }, policy_file)

    print('Checkpointed')


#env

In [4]:
# -*- coding: utf-8 -*-
"""
Classic cart-pole system implemented by Rich Sutton et al.
Copied from http://incompleteideas.net/sutton/book/code/pole.c
permalink: https://perma.cc/C9ZM-652R
Modified by Aaditya Ravindran to include friction and random sensor & actuator noise
"""

import logging
import math
import random
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np

logger = logging.getLogger(__name__)

class CartPoleModEnv(gym.Env):
    """
    Cart-pole with a continuous action in [-1, 1] and optional noise.

    The ``case`` constructor argument selects the noise setting:
      * ``case < 6``       : "model" noise only. The force magnitude is
                             ``30 * (1 + addnoise(case))`` and is re-drawn on
                             every ``reset()``; observations are not perturbed.
      * ``6 <= case <= 9`` : "data" noise only. The force magnitude is fixed
                             at 30 and each observed state component is
                             multiplied by ``1 + addnoise(case)``.
      * ``case > 9``       : both. The force magnitude is perturbed with
                             ``case`` and observations with the case-10
                             distribution.

    NOTE(review): when the noise is a size-(1,) array, ``step`` builds the
    observation from size-(1,) components, while ``reset`` returns a flat
    length-4 array — confirm that callers handle both layouts.
    """
    metadata = {
            'render.modes': ['human', 'rgb_array'],
            'video.frames_per_second' : 50
    }

    # Noise case -> (distribution, parameter). The parameter is the half-width
    # for 'uniform' and the variance for 'normal' (the std is its square root).
    # Case 1 (no noise) and unknown cases are handled in addnoise().
    _NOISE_SPECS = {
        2: ('uniform', 0.05),   #  5% actuator noise, small model uniform noise
        3: ('uniform', 0.10),   # 10% actuator noise, large model uniform noise
        4: ('normal', 0.10),    # small model gaussian noise
        5: ('normal', 0.50),    # large model gaussian noise
        6: ('uniform', 0.05),   #  5% sensor noise, small data uniform noise
        7: ('uniform', 0.10),   # 10% sensor noise, large data uniform noise
        8: ('normal', 0.10),    # small data gaussian noise
        9: ('normal', 0.20),    # large data gaussian noise
        10: ('normal', 0.10),   # small both gaussian noise
        11: ('normal', 0.50),   # large both gaussian noise
    }

    def __init__(self,case):
        self.__version__ = "0.2.0"
        print("CartPoleModEnv - Version {}, Noise case: {}".format(self.__version__,case))
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = (self.masspole + self.masscart)
        self.length = 0.5 # actually half the pole's length
        self.polemass_length = (self.masspole * self.length)
        self.seed()     # must precede any addnoise() call: it creates self.np_random

        self.origin_case = case
        self.force_mag = self._sample_force_mag()
        # self.case selects the observation-noise distribution used by step()
        if case<6:          #only model  noise
            self.case = 1
        elif case>9:    #both model and data noise
            self.case = 10
        else:               #only data noise
            self.case = case

        self.tau = 0.02     # seconds between state updates

        self.min_action = -1.
        self.max_action = 1.0

        # Angle at which to fail the episode
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # Angle limit set to 2 * theta_threshold_radians so failing observation is still within bounds
        high = np.array([
            self.x_threshold * 2,
            np.finfo(np.float32).max,
            self.theta_threshold_radians * 2,
            np.finfo(np.float32).max])

        self.action_space = spaces.Box(
                low = self.min_action,
                high = self.max_action,
                shape = (1,)
        )

        self.observation_space = spaces.Box(-high, high)

        self.viewer = None
        self.state = None

        self.steps_beyond_done = None

    def _sample_force_mag(self):
        """Return the force magnitude for ``self.origin_case`` (noisy unless 6 <= case <= 9)."""
        if self.origin_case < 6 or self.origin_case > 9:
            return 30.0*(1+self.addnoise(self.origin_case))
        return 30.0

    def addnoise(self,x):
        """
        Draw one multiplicative noise term for noise case ``x``.

        Returns 0 for case 1, a size-(1,) array for cases 2-11, and 1 for any
        other value (legacy fallback, kept unchanged).

        Only the requested distribution is sampled. The previous version drew
        a sample for every case on each call and discarded the rest, so
        trajectories for a given seed differ from that version.
        """
        if x == 1:
            return 0
        spec = self._NOISE_SPECS.get(x)
        if spec is None:
            return 1
        kind, param = spec
        if kind == 'uniform':
            return self.np_random.uniform(low=-param, high=param, size=(1,))
        return self.np_random.normal(loc=0, scale=np.sqrt(param), size=(1,))

    def seed(self, seed=None): # Set appropriate seed value
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def stepPhysics(self, force):
        """Advance ``self.state`` by one explicit Euler step of length ``tau`` and return the new state tuple."""
        x, x_dot, theta, theta_dot = self.state
        costheta = math.cos(theta)
        sintheta = math.sin(theta)
        temp = (force + self.polemass_length * theta_dot * theta_dot * sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / \
                    (self.length * (4.0/3.0 - self.masspole * costheta * costheta / self.total_mass))
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass
        x  = (x + self.tau * x_dot)
        x_dot = (x_dot + self.tau * xacc)
        theta = (theta + self.tau * theta_dot)
        theta_dot = (theta_dot + self.tau * thetaacc)
        return (x, x_dot, theta, theta_dot)


    def step(self, action):
        """
        Apply ``action`` (scaled by ``force_mag``) and return
        ``(observation, reward, done, info)``.

        The internal state stays noise free; only the returned observation is
        perturbed. Termination is evaluated on the perturbed ``x`` and
        ``theta``.
        """
        assert self.action_space.contains(action), "%r (%s) invalid"%(action, type(action))
        # .item() extracts the single element explicitly; float() on an array
        # with ndim > 0 is deprecated in recent NumPy releases.
        force = self.force_mag * np.asarray(action, dtype=float).item()
        self.state = self.stepPhysics(force)
        x, x_dot, theta, theta_dot = self.state

        # measurement noise: an independent draw per component, in this order
        theta = theta * (1+self.addnoise(self.case))
        x = x*(1+self.addnoise(self.case))
        x_dot = x_dot*(1+self.addnoise(self.case))
        theta_dot = theta_dot*(1+self.addnoise(self.case))

        output_state = np.array((x, x_dot, theta, theta_dot))

        done = x < -self.x_threshold \
            or x > self.x_threshold \
            or theta < -self.theta_threshold_radians \
            or theta > self.theta_threshold_radians
        done = bool(done)

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            if self.steps_beyond_done == 0:
                logger.warning("You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.")
            self.steps_beyond_done += 1
            reward = 0.0

        return output_state, reward, done, {}


    def reset(self):
        """Draw a fresh initial state in [-0.05, 0.05]^4, re-draw the force magnitude and return the state."""
        self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
        self.steps_beyond_done = None
        self.force_mag = self._sample_force_mag()

        return np.array(self.state)

    def render(self, mode='human', close=False):
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
            return

        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold*2
        scale = screen_width/world_width
        carty = 100 # TOP OF CART
        polewidth = 10.0
        polelen = scale * 1.0
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)
            l,r,t,b = -cartwidth/2, cartwidth/2, cartheight/2, -cartheight/2
            axleoffset =cartheight/4.0
            cart = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)
            l,r,t,b = -polewidth/2,polewidth/2,polelen-polewidth/2,-polewidth/2
            pole = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            pole.set_color(.8,.6,.4)
            self.poletrans = rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)
            self.axle = rendering.make_circle(polewidth/2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(.5,.5,.8)
            self.viewer.add_geom(self.axle)
            self.track = rendering.Line((0,carty), (screen_width,carty))
            self.track.set_color(0,0,0)
            self.viewer.add_geom(self.track)

        if self.state is None: return None

        x = self.state
        cartx = x[0]*scale+screen_width/2.0 # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-x[2])
        return self.viewer.render(return_rgb_array = mode=='rgb_array')

#LLB

In [5]:
class BRNN(nn.Module):
    def __init__(self, action_dim, hidden_dim, output_dim, device, mode):
        super(BRNN, self).__init__()
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.device = device
        self.mode = mode

        #self.W_added_noise = np.log(2)
        self.W_min = np.log(0.1)

        init_dim = hidden_dim
        emission_dim = hidden_dim

        self.initial_encoder = nn.Sequential(
            nn.Linear(output_dim, init_dim),
            nn.ReLU(),
            nn.Linear(init_dim, hidden_dim),
            nn.Tanh()
        )
        if mode == 'RNN':
            self.transition = nn.RNN(action_dim, hidden_dim)
        elif mode == 'LSTM':
            self.transition = nn.LSTM(action_dim, hidden_dim)
        elif mode == 'GRU':
            self.transition = nn.GRU(action_dim, hidden_dim)
        else:
            raise NotImplementedError
        
        #self.W_mu = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim + 1).to(self.device), requires_grad=True)
        #self.W_logvar = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim + 1).to(self.device), requires_grad=True)
        


        self.prior_W_mu = torch.zeros_like(torch.Tensor(hidden_dim, hidden_dim+1)).to(self.device)
        self.prior_W_mu = nn.init.kaiming_uniform_(self.prior_W_mu, a=math.sqrt(5))

        self.prior_W_logvar = torch.ones_like(self.prior_W_mu).to(self.device)           #log var
        self.prior_W_logvar = (np.log(0.1)*self.prior_W_logvar).requires_grad_(False)           

        self.prior_W_logvar.uniform_(np.log(0.05), np.log(0.1))
        #self.prior_W_logvar.uniform_(np.log(0.2), np.log(0.5))

        self.W_mu = nn.Parameter(self.prior_W_mu.detach().clone().requires_grad_(True))
        self.W_logvar = nn.Parameter(self.prior_W_logvar.detach().clone().requires_grad_(True))  # log(sigma)
        #self.W_logvar.data.fill_(np.log(0.5))
        #self.W_mu.data = trained_W_mu.clone()
        #self.W_logvar.data = trained_W_logvar.clone()

        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.emission_mean = nn.Sequential(
            nn.Linear(hidden_dim, output_dim)
        )

        #self.emission_mean = nn.Sequential(
        #    nn.Linear(emission_dim, output_dim),
        #    nn.Tanh(),
        #    nn.Linear(output_dim, output_dim)
        #)

        #self.x_var =   torch.tensor([1e-2, 1e-2, 1e-2, 1]).to(self.device).requires_grad_(False)
        #self.x_var = torch.tensor([0.0025, 1e-1, 0.0025, 1]).to(self.device).requires_grad_(False)

        #self.x_var =  torch.tensor([0.0025, 0.0025, 0.0025, 0.25]).to(self.device).requires_grad_(False)


        #x_logvar =  torch.log(torch.tensor([1, 1, 1, 1.])).to(self.device)
        #self.x_logvar = nn.Parameter(x_logvar.clone().requires_grad_(True))

        self.emission_logvar = nn.Linear(hidden_dim, output_dim)


        #self.x_var = torch.tensor([0.0025, 2, 0.0025, 2]).to(self.device).requires_grad_(False)


        #self.emission_logvar = nn.Sequential(
        #    nn.Linear(emission_dim, output_dim))
        #    nn.Tanh(),
        #    nn.Linear(output_dim, output_dim),
        #    nn.Sigmoid()
        #)

       # self.scale = nn.Parameter(torch.Tensor(output_dim).requires_grad_(True))
        #self.scale.data.fill_(1e-1)

        #self.x_var = torch.tensor([1e-4, 1e-6]).to(self.device).requires_grad_(False)



        #self.decoder_mean = nn.Sequential(
        #    nn.Linear(hidden_dim, output_dim),
        #    nn.ReLU(),
        #    nn.Linear(output_dim, output_dim)
        #)
        #self.x_sigma = torch.tensor([1e-4, 1e-6]).to(self.device).requires_grad_(False)

        #self.decoder_sigma = nn.Sequential(
        #    nn.Linear(hidden_dim, hidden_dim),
        #    nn.Tanh(),
        #    nn.Linear(hidden_dim, output_dim)
        #)



        #self.reset()

    def reset(self):
        stdv = 1./math.sqrt(self.W_mu.size(1))
        logvar_init = math.log(stdv)*2
        self.W_mu.data.uniform_(-stdv, stdv)
        self.W_logvar.data.fill_(math.log(0.05))

        #nn.init.kaiming_uniform_(self.W_mu, a=math.sqrt(5))
        #self.W_sigma.data.fill_(np.log(0.5))


    def rollout(self, init_x, A, W_eps = None, track_sigmoid = False, W_uncertainty = True, epsilon_uncertainty = True):
        """
        take initial obs and sequence of actions, output predicted observation mean and variance, 
        if grad = None, using mean, if not , use posterior sharpening

        init_x : [batch, output_dim]
        A : [seq-1, batch, 1]
        return : observational mean and sigma
        """
        if track_sigmoid:
            sigmoid_tracker = []

        #forward propagation
        init_x = init_x.to(self.device)
        A = A.to(self.device)


        batch_size = init_x.size(0)
        
        previous_h = self.initial_encoder(init_x).unsqueeze(0).to(self.device)      #[1, batch ,hidden]
        if self.mode == 'LSTM':
            previous_c = torch.zeros_like(previous_h).to(self.device)
        
        #here we follow the BBB paper where they fix the weight for each mini-batch instead of conditioning on single example
        #if W_eps is not None:
        #    std = torch.exp(0.5 * self.W_logvar)
        #    if len(W_eps.shape) > 2:  # Multiple specified samples from W
        #        W = self.W_mu.unsqueeze(0).expand(W_eps.shape) + W_eps*std.unsqueeze(0).expand(W_eps.shape)  #  [batch, hid, hid+1]

        #    else:
        #        W = self.W_mu + W_eps*std

        #else:
        #    W_eps = torch.rand_like(self.W_logvar).normal_().to(self.device)
        #    W = self.W_mu + W_eps*torch.exp(0.5*self.W_logvar)


        #eps = torch.FloatTensor(batch_size, self.W_mu.size(0), self.W_mu.size(1)).normal_().to(self.device)
        #W =  self.W_mu.unsqueeze(0).expand(eps.shape) + eps*torch.exp(0.5*self.W_logvar).unsqueeze(0).expand(eps.shape)  # [batch, hid, hid+1]


        if W_uncertainty:
            #clamp_W = torch.clamp(self.W_logvar, self.W_min, np.log(2.0))

            W = self.stack_W(batch_size, self.W_mu, torch.exp(0.5 * (self.W_logvar)) )        #[batch, hid, hid+1]
        else:
            W = self.W_mu.unsqueeze(0).expand(batch_size, self.W_mu.size(0), self.W_mu.size(1))

        #W = self.stack_W(batch_size, self.W_mu, torch.exp(0.5*self.W_logvar))        #[batch, hid, hid+1]
        #W_sample = self.reparametrise(self.W_mu, torch.exp(self.W_sigma))
        #W = W_sample.unsqueeze(0).expand(batch_size, W_sample.size(0), W_sample.size(1))        #[batch, hid, hid+1]


        output_mean_list = []
        output_cov_list = []
        preds = []
        for t in range(A.size(0)):
            previous_a = A[t].unsqueeze(0)      #[1, batch, 1]
            if self.mode == 'LSTM':
                current_h, current_c = self.transition(previous_a, (previous_h, previous_c))[-1]   #[1,batch, hidden]
            else:
                current_h = self.transition(previous_a, previous_h)[-1]     #[1, batch, hidden]


            f_t = torch.cat([current_h, torch.ones(current_h.shape[0], current_h.shape[1], 1).to(self.device)], dim = -1)           
            f_t = f_t.permute(1,2,0)                #[batch, hid+1, 1]


            #one = torch.ones(current_h.squeeze(0).size()[0]).to(self.device)     
            #f_t = torch.cat((current_h.squeeze(0), one.unsqueeze(1)), dim = 1)           #[batch, hidden_size + 1]
            #next_h = torch.bmm(W, f_t.unsqueeze(-1)).squeeze(-1)         #[batch,hid, hid+1] *[batch, hid+1, 1] ----> [batch, hid]
            next_h = torch.bmm(W, f_t).squeeze(-1)
            if self.mode == 'GRU':
                next_h = torch.tanh(next_h)

            temp_emission = self.decoder(next_h)                        #[batch, output]
            emission_mean = self.emission_mean(temp_emission)           #[batch, output]
            #emission_sigma = torch.sqrt(self.x_var.expand(emission_mean.shape))

            #emission_var = self.x_var.expand(emission_mean.shape)       #[batch, output]
            #emission_sigma = torch.sqrt(emission_var)                   #[batch, output]
            #emission_var = torch.exp(self.x_logvar).expand(emission_mean.size())       #[batch, output]
            #emission_sigma = torch.sqrt(emission_var) + 1e-4
            emission_var = torch.exp(self.emission_logvar(temp_emission))
            emission_sigma = torch.sqrt(emission_var)


            #emission_var = torch.exp(self.emission_logvar(temp_emission))   #[batch, output]
            #emission_sigma = torch.sqrt(emission_var)

            #emission_var= self.emission_logvar(temp_emission)   #* self.scale         #[batch, output]
            #emission_sigma = torch.sqrt(emission_var)                  #[batch, output]

            if track_sigmoid:
                sigmoid_tracker.append(emission_var) 


            if epsilon_uncertainty:
                emission = self.reparametrise(emission_mean, emission_sigma)        #[batch, output]
            else:
                emission = emission_mean        #[batch, output]


                        
            preds.append(emission.unsqueeze(0))


            #emission_mean = self.decoder_mean(emission)         #[batch, output]
            #emission_sigma = torch.exp(self.decoder_sigma(emission))       #[batch, output]
            if self.mode == 'LSTM':
                previous_c = current_c
            previous_h = next_h.unsqueeze(0)

            output_mean_list.append(emission_mean.unsqueeze(0))
            output_cov_list.append(emission_var.unsqueeze(0))             
        output_mean_list = torch.cat(output_mean_list, dim = 0)     #[seq-1, batch, output]
        output_cov_list = torch.cat(output_cov_list, dim = 0)       #[seq-1, batch, output]

        if track_sigmoid:
            sigmoid_list = torch.cat(sigmoid_tracker, dim = 0)      #[seq-1, output]
            return torch.cat(preds, dim = 0), output_mean_list, output_cov_list, sigmoid_list

        return torch.cat(preds, dim = 0), output_mean_list, output_cov_list



    def forward(self, X, A, N):
        """
        X : [seq, batch, output_dim]
        A : [seq-1, batch, 1]
        """
        X = X.to(self.device)
        A = A.to(self.device)
        #compute nll
        #output_mean, output_std = self.rollout(X[0], A)
        #NLL = self.get_nll(output_mean, output_std, X[1:, :, :])        #[batch]
        #The gradient of nll
        #fix_eps = torch.FloatTensor(X.size(1), self.W_logvar.shape[0], self.W_logvar.shape[1]).normal_().to(self.device)

        #forward pass with sharpening posterior weight
        _, output_mean, output_cov = self.rollout(X[0], A)

        return self.get_loss(output_mean, output_cov, X[1:,...], N)
    
    def get_loss(self, output_mean, output_sigma, target, N):
        """
        calculate free energy, NLL - KL
        output_mean : [seq-1, batch, output]
        output_sigma : [seq-1, batch, output]
        target : [seq-1, batch, output]

        """
        batch_size = target.size(1)
        seq_length = target.size(0)
        T = batch_size * seq_length
        #LL
        flatten_x_mean = output_mean.view(-1, self.output_dim)
        flatten_x_std = torch.sqrt(output_sigma).view(-1, self.output_dim)

        flatten_target = target.reshape(-1, self.output_dim)

        LL = self.batched_gaussian_ll(flatten_x_mean, flatten_x_std, flatten_target)
        LL = LL.sum()

        #LL = 0.5 * (- T*np.log(2*np.pi) - torch.log(output_sigma).sum()  
        #        -  torch.pow(output_sigma, -1).mul(torch.pow(target - output_mean, 2)).sum())        
        #print('LL = {}; LL2 = {}'.format(LL.item(), LL2.item()))
        #KL
        Wprior = Normal(self.prior_W_mu, torch.sqrt(torch.exp(self.prior_W_logvar)))
        Wpost = Normal(self.W_mu, torch.sqrt(torch.exp(self.W_logvar)))
        KL = kl_divergence(Wpost, Wprior).sum()

        #KL = 0.5 * ( self.prior_W_logvar - (self.W_logvar) - 1 \
        #            + torch.exp((self.W_logvar) - self.prior_W_logvar) \
        #            + torch.exp(-self.prior_W_logvar) * torch.pow(self.W_mu - self.prior_W_mu, 2) ).sum()
        
        #print('KL = {}; KL2 = {}'.format(KL.item(), KL2.item()))

        FE = (1/batch_size)*LL - (1/(seq_length*batch_size))*KL
        #FE = LL - (1/batch_size)*KL
        #FE = LL - (1/batch_size)*KL
        return FE, (1/batch_size)*LL, (1/(seq_length*batch_size))*KL
        #return FE, LL, (1/batch_size)*KL


        '''
        #NLL = self.get_nll(output_mean, output_sigma, target)
        T = target.size(0)*target.size(1)

        LL = 0.5 * (- T*np.log(2*np.pi) - torch.log(output_sigma).sum()  
                -  torch.pow(output_sigma, -1).mul(torch.pow(target - output_mean, 2)).sum())
        #KL
        #KL = self.kl_divergence(self.prior_W_mu, torch.exp(self.prior_W_sigma), self.W_mu, torch.exp(self.W_sigma))
        #clamp_W = torch.clamp(self.W_logvar, self.W_min, np.log(2.0))
        clamp_W = self.W_logvar

        KL = 0.5 * ( self.prior_W_logvar - (clamp_W) - 1 \
                    + torch.exp((clamp_W) - self.prior_W_logvar) \
                    + torch.exp(-self.prior_W_logvar) * torch.pow(self.W_mu - self.prior_W_mu, 2) ).sum()


        FE = (1/T)*LL - (1/N)*KL

        #KL_sharp = torch.sum((self.sharp_W - self.kl_W).pow(2)/ (2 * 0.002**2))
        #KL_sharp = self.kl_divergence(self.phi_container, self.condition_prior_W_sigma, self.sharp_W_mean, self.sharp_W_sigma)

        return FE.flatten(), (1/T)*LL.flatten(), (1/N)*KL.flatten()

        '''
    
    def MSE_forward(self, X, A):
        X = X.to(self.device)
        A = A.to(self.device)

        pred, _, _ = self.rollout(X[0], A)      #[seq-1, batch, output]
        return pred



    def mc_prediction(self, init_X, A, track_sigmoid = False):
        """

        init_X : [1, output]
        A : [seq-1, 1, action_dim]

        """
        init_X = init_X.to(self.device)
        A = A.to(self.device)
        total_list = []
        if track_sigmoid:
            total_sigmoid_list = []
        for i in range(500):
            W_eps = 0 #torch.FloatTensor(init_X.size(0), self.W_logvar.shape[0], self.W_logvar.shape[1]).normal_().to(self.device)

            if track_sigmoid:
                pred, _, _, sigmoid_list= self.rollout(init_X, A, W_eps, track_sigmoid = True)     #[seq-1, 1, output]
                total_sigmoid_list.append(sigmoid_list.unsqueeze(-1))     #“seq-1, output, 1]

            else:
                pred, _, _  = self.rollout(init_X, A, W_eps)

            #x_sample = self.reparametrise(output_mean.squeeze(), torch.sqrt(output_sigma).squeeze())
            #total_list.append(x_sample.unsqueeze(-1))           #[seq-1, output, 1]
            total_list.append(pred.permute(0,2,1))              #[seq-1, output, 1]
        total_list = torch.cat(total_list, dim = -1)            #[seq-1, output, 300]
        mean = torch.mean(total_list, dim = -1)      #[seq-1, output]
        std = torch.std(total_list, dim = -1)       #[seq-1, output]
        if track_sigmoid:
            total_sigmoid_list = torch.cat(total_sigmoid_list, dim = -1)        #[seq-1, output, 500]
            sigmoid_mean = torch.mean(total_sigmoid_list, dim = -1)     #[seq-1, output]
            sigmoid_std = torch.std(total_sigmoid_list, dim = -1)         #[seq-1, output]

            return mean, std, sigmoid_mean, sigmoid_std
        else:
            return mean, std

    def uncertainty(self, init_X, A, object):
        """
        init_X : [1, output]
        A : [seq-1, 1, action_dim]
        """
        init_X = init_X.to(self.device)
        A = A.to(self.device)

        total_list = []
        for i in range(500):
            if object == 'W':
                pred, _, _= self.rollout(init_X, A, W_uncertainty = True, epsilon_uncertainty = False)
            elif object == 'e':
                pred, _, _ = self.rollout(init_X, A, W_uncertainty = False, epsilon_uncertainty = True)
            elif object == 'both':
                pred, _, _ = self.rollout(init_X, A, W_uncertainty = False, epsilon_uncertainty = False)
            else:
                print('Either W or e')

            total_list.append(pred.permute(0,2,1))
        total_list = torch.cat(total_list, dim = -1)
        mean = torch.mean(total_list, dim = -1)
        std = torch.std(total_list, dim = -1)

        return mean, std

    def imagine(self, init_x, control_f, horizon, plan, W_uncertainty,e_uncertainty):
        """
        Roll the learned model forward for `horizon` steps under a policy.

        At every step the policy proposes an action from the previously
        predicted observation, the recurrent transition followed by the linear
        map W gives the next hidden state, and the decoder emits the next
        observation, which is fed back to the policy at the following step.

        init_x        : [batch, output_dim] initial observations
        control_f     : callable mapping a state [batch, output_dim] to
                        (action, action_log_prob)
        horizon       : number of imagined steps
        plan          : 'pg' (collect the action log-probs) or 'rp'
        W_uncertainty : True -> draw one W per batch element via stack_W;
                        False -> share W_mu across the batch
        e_uncertainty : True -> sample the emission with the predicted variance;
                        False -> use the emission mean

        return : (predictions [horizon, batch, output_dim],
                  stacked log-probs [horizon, batch, 1] for 'pg', the int 0 for 'rp')
        raise  : ValueError if plan is neither 'pg' nor 'rp'
        """
        if plan not in ('pg', 'rp'):
            # Fail before doing any work; previously an unknown plan only
            # surfaced as an unbound `action_samples` inside the loop.
            raise ValueError("plan must be 'pg' or 'rp', got {!r}".format(plan))

        batch_size = init_x.size(0)
        init_x = init_x.to(self.device)
        previous_x = init_x
        previous_h = self.initial_encoder(init_x).unsqueeze(0)      #[1, batch, hidden]
        if self.mode == 'LSTM':
            previous_c = torch.zeros_like(previous_h).to(self.device)       #[1, batch, hidden]

        #different W for each initialisation
        if W_uncertainty:
            W = self.stack_W(batch_size, self.W_mu, torch.exp(0.5*self.W_logvar))        #[batch, hid, hid+1]
        else:
            W = self.W_mu.unsqueeze(0).expand(batch_size, self.W_mu.size(0), self.W_mu.size(1))

        # Column of ones appended to the hidden state so that the extra column
        # of W acts as a bias. It is constant, so build it once on the device.
        bias_column = torch.ones(1, batch_size, 1, device=self.device)      #[1, batch, 1]

        preds = []
        # 'pg' collects one log-prob tensor per step; 'rp' reports the placeholder 0.
        action_log_prob_list = [] if plan == 'pg' else 0
        for t in range(horizon):
            action_samples, action_log_prob = control_f(previous_x)
            if plan == 'pg':
                action_log_prob_list.append(action_log_prob.unsqueeze(0))           #[1, batch, 1]

            # [-1] selects the final hidden state (and cell state for LSTM).
            if self.mode == 'LSTM':
                current_h, previous_c = self.transition(action_samples.unsqueeze(0), (previous_h, previous_c))[-1]   #[1,batch, hidden]
            else:
                current_h = self.transition(action_samples.unsqueeze(0), previous_h)[-1]     #[1, batch, hidden]

            f_t = torch.cat([current_h, bias_column], dim = -1)     #[1, batch, hid+1]
            f_t = f_t.permute(1,2,0)                #[batch, hid+1, 1]

            next_h = torch.bmm(W, f_t).squeeze(-1)          #[batch, hidden]
            if self.mode == 'GRU':
                next_h = torch.tanh(next_h)

            temp_emission = self.decoder(next_h)                        #[batch, output]
            emission_mean = self.emission_mean(temp_emission)           #[batch, output]
            if e_uncertainty:
                # The network predicts a log-variance; convert to a std-dev.
                emission_var = torch.exp(self.emission_logvar(temp_emission))
                emission_sigma = torch.sqrt(emission_var)
                emission = self.reparametrise(emission_mean, emission_sigma)        #[batch, output]
            else:
                emission = emission_mean

            preds.append(emission.unsqueeze(0))             #[1, batch, output]

            previous_h = next_h.unsqueeze(0)
            previous_x = emission

        output_list = torch.cat(preds)                  #[horizon, batch, output]
        if plan == 'pg':
            action_log_prob_list = torch.cat(action_log_prob_list)  #[horizon, batch, 1]

        return output_list, action_log_prob_list


    def validate_by_imagination(self, init_x, control_f, plan, W_uncertainty, e_uncertainty, max_steps=200):
        """
        Roll the model forward under a policy until the imagined observation
        leaves the allowed range, and count the steps survived.

        Instead of planning over a fixed horizon, the rollout stops as soon as
        component 0 of the predicted observation leaves [-2.4, 2.4], component 2
        leaves +/- 12 degrees (in radians), or `max_steps` steps have elapsed.

        init_x        : [1, state] single initial observation
        control_f     : callable mapping a state to (action, log_prob); only the
                        action is used
        plan          : unused, kept for interface compatibility
        W_uncertainty : True -> sample W via stack_W, False -> use W_mu
        e_uncertainty : True -> sample the emission, False -> use its mean
        max_steps     : cap on the number of imagined steps (default 200)

        return : number of steps taken, as a float (one reward unit per step)
        raise  : ValueError if init_x does not hold exactly one state
        """
        batch_size = init_x.size(0)
        if batch_size != 1:
            # The termination test is a scalar decision, so it is only defined
            # for a single trajectory.
            raise ValueError('validate_by_imagination expects init_x of shape [1, state], '
                             'got batch size {}'.format(batch_size))

        angle_limit = 12 * 2 * math.pi / 360
        steps = 0

        # Pure evaluation: only a step count is returned, so no autograd graph
        # needs to be kept across the rollout.
        with torch.no_grad():
            init_x = init_x.to(self.device)
            previous_x = init_x
            previous_h = self.initial_encoder(init_x).unsqueeze(0)      #[1, batch, hidden]
            if self.mode == 'LSTM':
                previous_c = torch.zeros_like(previous_h).to(self.device)       #[1, batch, hidden]

            #different W for each initialisation
            if W_uncertainty:
                W = self.stack_W(batch_size, self.W_mu, torch.exp(0.5*self.W_logvar))        #[batch, hid, hid+1]
            else:
                W = self.W_mu.unsqueeze(0).expand(batch_size, self.W_mu.size(0), self.W_mu.size(1))

            while True:
                action_samples, _ = control_f(previous_x)

                if self.mode == 'LSTM':
                    current_h, previous_c = self.transition(action_samples.unsqueeze(0), (previous_h, previous_c))[-1]   #[1,batch, hidden]
                else:
                    current_h = self.transition(action_samples.unsqueeze(0), previous_h)[-1]     #[1, batch, hidden]

                # Append a ones column so the last column of W acts as a bias.
                f_t = torch.cat([current_h, torch.ones(current_h.shape[0], current_h.shape[1], 1).to(self.device)], dim = -1)
                f_t = f_t.permute(1,2,0)                #[batch, hid+1, 1]

                next_h = torch.bmm(W, f_t).squeeze(-1)          #[batch, hidden]
                if self.mode == 'GRU':
                    next_h = torch.tanh(next_h)

                temp_emission = self.decoder(next_h)                        #[batch, output]
                emission_mean = self.emission_mean(temp_emission)           #[batch, output]
                if e_uncertainty:
                    emission_var = torch.exp(self.emission_logvar(temp_emission))
                    emission_sigma = torch.sqrt(emission_var)
                    emission = self.reparametrise(emission_mean, emission_sigma)        #[batch, output]
                else:
                    emission = emission_mean

                steps += 1

                position = emission[0, 0].item()
                angle = emission[0, 2].item()
                done = position < -2.4 \
                    or position > 2.4 \
                    or angle < -angle_limit \
                    or angle > angle_limit \
                    or steps >= max_steps
                if done:
                    break

                previous_h = next_h.unsqueeze(0)
                previous_x = emission

        return float(steps)






    """
    def predict(self, init_X, A):

        
        init_X = init_X.to(self.device)
        A = A.to(self.device)
        total_list = []
        previous_h = self.initial_encoder(init_X).unsqueeze(0).to(self.device)
        if self.mode == 'LSTM':
            previous_c = torch.zeros_like(previous_h).to(self.device)
        for i in range(500):
            temp_pred = []
            W = self.reparametrise(self.W_mu, torch.exp(self.W_sigma))         #[hid, hid+1]
            for t in range(A.size(0)):
                previous_a = A[t].unsqueeze(0)
                if self.mode == 'LSTM':
                    current_h, current_c = self.transition(previous_a, (previous_h, previous_c))[-1]   #[1,batch, hidden]
                else:
                    current_h = self.transition(previous_a, previous_h)[1]
                one = torch.ones(current_h.squeeze(0).size()[0]).to(self.device)     
                f_t = torch.cat((current_h.squeeze(0), one.unsqueeze(1)), dim = 1)           #[batch, hidden_size + 1]

                #W = self.reparametrise(self.W_mu, torch.exp(self.W_sigma))         #[hid, hid+1]                
                next_h = torch.bmm(W.unsqueeze(0), f_t.unsqueeze(-1)).squeeze(-1)             #[batch, hid, hid+1] * [batch, hid+1, 1] --> [batch, hid]         
                
                emission = self.decoder(next_h)
                emission_mean = self.decoder_mean(emission)     #[1, output]
                emission_sigma = torch.exp(self.decoder_sigma(emission))   #[1, output]
                #emission_sigma = self.x_sigma.expand(emission_mean.size(0), emission_mean.size(1))

                x_sample = self.reparametrise(emission_mean, emission_sigma)       #[1, output]
                
                #x_sample = self.decoder(next_h)         #[1, output_dim]
                temp_pred.append(x_sample)

                previous_h = next_h.unsqueeze(0)
                if self.mode == 'LSTM':
                    previous_c = current_c

            temp_pred_vec = torch.cat(temp_pred)    #[seq-1, output_dim]
            total_list.append(temp_pred_vec.unsqueeze(-1))          #list of tensor of shape [seq-1, output, 1]
        total_list = torch.cat(total_list, dim = -1)     #[seq-1, output_dim, 1000]
        mean = torch.mean(total_list, dim = -1)      #[seq-1, output_dim]
        std = torch.std(total_list, dim = -1)        #[seq-1, output_dim]

        return mean, std
    """

    def batched_gaussian_ll(self, mean, sigma, x):
        """
        Log-likelihood of a batch of observations under diagonal Gaussians.

        mean  : shape [batch, output_size]
        sigma : shape [batch, output_size]   (standard deviations, diagonal covariance)
        x     : shape [batch, output_size]
        return: shape [batch]

        Because the covariance is diagonal, the quadratic form
        (x-mean)^T Sigma^-1 (x-mean) reduces to sum((x-mean)^2 / sigma^2), so no
        [batch, n, n] matrix is materialised.
        """
        # Diagnostics only: a zero std-dev gives inf/nan in the terms below.
        if bool((sigma == 0).any()):
            print('Zero occurs in diagonal sigma matrix. (batched gaussian ll)')
        if bool((sigma**2 == 0).any()):
            print('Zero occurs after squaring the sigma matrix. (batched gaussian ll)')

        diff = x - mean
        quad_form = (diff**2 / sigma**2).sum(dim = -1)          #[batch]

        # log|Sigma| = sum_i log(sigma_i^2) = 2 * sum_i log(sigma_i)
        logdet = torch.sum(2 * torch.log(sigma), dim = -1)      #[batch]
        n = mean.size()[-1]

        return -(n/2) * math.log(2*math.pi) - 0.5*logdet - 0.5 * quad_form


    def kl_divergence(self, prior_m, prior_sigma, post_m ,post_sigma):
        """
        KL( q || p ) between two fully factorised Gaussians, where q is the
        posterior N(post_m, post_sigma^2) and p the prior N(prior_m, prior_sigma^2).

        All four arguments share one shape ([hidden, hidden+1]); the sigmas are
        standard deviations. Returns a scalar tensor: the per-element KL summed
        over every entry.

        For diagonal Gaussians the KL has the closed form
            sum_i [ log(s_p/s_q) + (s_q^2 + (m_q - m_p)^2) / (2 s_p^2) - 1/2 ],
        which avoids building and factorising dense d x d covariance matrices
        (d = hidden*(hidden+1)).
        """
        prior_var = prior_sigma**2
        post_var = post_sigma**2

        # Diagnostics only: a zero variance gives inf/nan below.
        if bool((prior_var == 0).any()):
            print('Zero occurs in squaring prior sigma')
        if bool((post_var == 0).any()):
            print('Zero occurs in squaring posterior sigma')

        log_ratio = torch.log(prior_sigma) - torch.log(post_sigma)
        scaled = (post_var + (post_m - prior_m)**2) / (2 * prior_var)

        return (log_ratio + scaled - 0.5).sum()

    """
    def kl_divergence(self, prior_m, prior_sigma, post_m, post_sigma):

        d = prior_m.size(0) * prior_m.size(1)       #hidden*hidden+1
        if 0 in prior_sigma**2:
            print('Zero occurs in squaring prior sigma')
        if 0 in post_sigma**2:
            print('Zero occurs in squaring posterior sigma')
        vec_prior_m = vec(prior_m)
        vec_prior_sigma = vec(prior_sigma)
        vec_post_m = vec(post_m)
        vec_post_sigma = vec(post_sigma)

        trace = ((vec_prior_sigma/vec_post_sigma)**2).sum()

        inv_post_diag_cov = self.diagonalise(1/(post_sigma**2), batch=False)        #[hid*hid+1, hid*hid+1]
        exp = (vec_post_m - vec_prior_m) @ inv_post_diag_cov @ (vec_post_m - vec_prior_m)  
        logdet_prior_cov = (2*torch.log(prior_sigma)).sum()
        logdet_post_cov = (2* torch.log(post_sigma)).sum()
        logdet = logdet_post_cov - logdet_prior_cov  
        return 0.5 * (logdet - d + trace + exp)  

    """

    def reparametrise(self, mean, sigma):
        """
        Draw mean + sigma * eps with eps ~ N(0, I) (reparametrisation trick).

        sigma should have the same shape as mean (no correlation); it holds
        standard deviations. The noise has sigma's shape and dtype and is
        created directly on self.device.
        """
        # One standard-normal draw; the former rand_like(...).normal_() drew
        # uniform noise first only to overwrite it.
        eps = torch.randn_like(sigma, device=self.device)
        return mean + sigma*eps
    
    def stack_W(self, batch_size, mean, sigma):
        """
        Draw `batch_size` independent samples W_i = mean + sigma * eps_i,
        eps_i ~ N(0, I), and stack them along a new leading dimension.

        mean, sigma : [hid, hid+1] (sigma holds standard deviations)
        return      : [batch_size, hid, hid+1]

        All samples are drawn in a single call rather than one Python-level
        sample and concatenation per batch element.
        """
        eps = torch.randn((batch_size,) + tuple(sigma.shape),
                          dtype=sigma.dtype, device=self.device)
        return mean.unsqueeze(0) + sigma.unsqueeze(0) * eps     #[batch_size, hid, hid+1]
    
    def diagonalise(self, input, batch):
        """
        Place the entries of `input` on the diagonal of a (batched) matrix.

        - 1-D vector                  -> diagonal matrix
        - 2-D matrix, batch=False     -> diagonal matrix of its flattened entries,
                                         eg: [[1,2],[3,4]] ---> diag([1,2,3,4])
        - 2-D matrix, batch=True      -> one diagonal matrix per row, [batch, n, n]
        - 3-D tensor, batch=True      -> one diagonal matrix per flattened item

        Flattening is delegated to the module-level helper `vec`.
        Raises ValueError for any other combination (previously a message was
        printed and None returned).
        """
        ndim = input.dim()
        if ndim == 1:
            return torch.diag(input)
        if ndim == 2:
            if not batch:
                return torch.diag(vec(input))
            # Built-in batched diagonal; the result stays on input's device.
            return torch.diag_embed(input)
        if ndim == 3 and batch:
            if input.size(0) == 0:
                flat = input.size(1) * input.size(2)
                return input.new_zeros((0, flat, flat))
            # Stack once instead of growing a tensor with cat in a loop; this
            # also keeps the result on input's device (the old accumulator
            # always started on the CPU).
            return torch.stack([torch.diag(vec(item)) for item in input])
        raise ValueError('diagonalise supports 1-D input, 2-D input, or 3-D input with batch=True; '
                         'got {} dimension(s) with batch={}'.format(ndim, batch))



#Deterministic controller

In [6]:
class controller(nn.Module):
    """
    Deterministic state-feedback policy trained on imagined rollouts.

    A small MLP (state_dim -> 16 -> 16 -> action_dim) followed by tanh, so every
    action component lies in [-1, 1]. The module owns its Adam optimiser and
    offers two training routines that both consume a model "imagine" function:
    `pg_train` (cost-weighted log-probability loss) and `rp_train`
    (back-propagation through the imagined trajectory).
    """
    def __init__(self, action_dim=1, state_dim=4, deterministic = True,device = 'cuda'):
        """
        action_dim    : size of the action vector
        state_dim     : size of the state vector
        deterministic : accepted for interface compatibility; currently unused
        device        : device that inputs are moved to before the forward pass
                        (the module itself is not moved here)
        """
        super(controller, self).__init__()
        self.action_dim = action_dim
        self.state_dim  = state_dim
        controller_hid = 16

        self.linear = nn.Sequential(
            nn.Linear(self.state_dim,controller_hid),
            nn.ReLU(),
            nn.Linear(controller_hid, controller_hid),
            nn.ReLU(),
            nn.Linear(controller_hid, action_dim)
        )

        self.device = device
        self.optimiser = torch.optim.Adam(self.parameters(), lr = 0.001)

    def forward(self, state):
        """
        Given states [batch, state_dim], return (action, log_pi).

        action : tanh-squashed network output, [batch, action_dim]
        log_pi : the constant 0 (placeholder; the policy is deterministic)

        NOTE(review): log_pi is a plain int, so a caller that needs a tensor
        log-probability cannot use this policy as-is - confirm before using
        it with pg_train.
        """
        state = state.to(self.device)
        a = torch.tanh(self.linear(state))
        log_pi = 0

        return a, log_pi

    def make_decision(self, state, behaviour_uncertainty):
        """
        Given a state [batch, state_dim] (or a single [state_dim] vector),
        return the tanh-squashed action detached from the autograd graph.

        behaviour_uncertainty : unused, kept for interface compatibility
        """
        state = state.to(self.device)
        a = torch.tanh(self.linear(state))

        return a.detach()

    def _discount(self, cost, gamma):
        """Scale cost[t] ([seq, batch, 1]) by gamma**(t+1)."""
        weights = torch.tensor([gamma**(t+1) for t in range(cost.size(0))])
        return cost * weights.unsqueeze(-1).unsqueeze(-1).to(self.device)

    def pg_train(self, num_epoch, initial_state, horizon, cost_f, model_imagine_f, w_uncertainty, e_uncertainty,gamma = 0.95, num_particle = 100):
        """
        Train on imagined rollouts with the loss
        sum over particles of (total discounted cost) * (total action log-prob).

        initial_state   : [1, state_dim], expanded to `num_particle` copies
        horizon         : length of each imagined rollout
        cost_f          : maps predictions [seq, batch, output] to cost [seq, batch, 1]
        model_imagine_f : model rollout function, called with plan='pg'
        gamma           : discount factor, applied as gamma**(t+1)
        num_particle    : number of imagined trajectories (default 100)

        return : list with the loss of every epoch
        """
        loss_list = []
        initial_state = initial_state.expand(num_particle, self.state_dim)

        for e in range(num_epoch):
            # Clear the previous epoch's gradients; without this they would
            # accumulate across epochs.
            self.optimiser.zero_grad()

            output_matrix, action_log_prob_matrix = model_imagine_f(initial_state, self.forward, horizon, plan = 'pg',
                                                                    W_uncertainty = w_uncertainty, e_uncertainty = e_uncertainty)
            # The cost is a fixed weight here: gradients flow only through the log-probs.
            cost = cost_f(output_matrix).detach()               #[seq, batch, 1]
            cost = self._discount(cost, gamma)

            loss = (cost.sum(0) * action_log_prob_matrix.sum(0)).sum()
            loss.backward()
            nn.utils.clip_grad_norm_(self.parameters(), 5)

            self.optimiser.step()
            loss_list.append(loss.item())
            if e%50 == 0:
                print('Epoch = {}; Policy gradient training loss = {}'.format(e, loss.item()))
        return loss_list

    def rp_train(self, num_epoch, num_particle,initial_state, horizon , cost_f, model_imagine_f, w_uncertainty, e_uncertainty,gamma = 0.9):
        """
        From an initial state, use the model imagination function to predict
        the next states under the actions proposed by the controller. The
        horizon is fixed, the discounted cost of the trajectories is summed,
        and its gradient w.r.t. the policy parameters is taken by
        back-propagating through the rollout.

        initial_state : [1, output_dim], expanded to `num_particle` copies

        Side effects: stores the last rollout in self.action_matrix and
        self.temp_output_matrix (initial state prepended to the predictions).

        return : (loss per epoch as a list,
                  per-epoch mean of the undiscounted total cost, tensor [num_epoch],
                  per-epoch std of the undiscounted total cost, tensor [num_epoch])
        """
        loss_list = []
        initial_state = initial_state.expand(num_particle, self.state_dim)

        cost_mean_list = []
        cost_std_list = []

        for e in range(num_epoch):
            self.optimiser.zero_grad()
            output_matrix, action_matrix= model_imagine_f(initial_state, self.forward, horizon, plan = 'rp',
                                              W_uncertainty = w_uncertainty,e_uncertainty = e_uncertainty)

            self.action_matrix = action_matrix
            self.temp_output_matrix = torch.cat([initial_state.unsqueeze(0).to(self.device), output_matrix], dim = 0)

            cost = cost_f(output_matrix)                 #[seq-1, batch, 1]

            # Statistics of the undiscounted trajectory cost across particles.
            total_cost = cost.detach().sum(0)            #[batch, 1]
            cost_mean_list.append(total_cost.mean(0))    #[1]
            cost_std_list.append(total_cost.std(0))      #[1]

            cost = self._discount(cost, gamma)

            loss = cost.sum()
            loss.backward()

            nn.utils.clip_grad_norm_(self.parameters(), 1)
            self.optimiser.step()
            loss_list.append(loss.item())

        return loss_list, torch.cat(cost_mean_list), torch.cat(cost_std_list)

    def rp_validate(self, num_particle, initial_state, horizon, cost_f, model_imagine_f, w_uncertainty, e_uncertainty, gamma = 1):
        """
        Mean undiscounted total cost of `num_particle` imagined rollouts from
        `initial_state` ([1, state_dim]). No parameters are updated.

        gamma : unused, kept for interface compatibility
        return: Python float
        """
        initial_state = initial_state.expand(num_particle, self.state_dim)
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            output_matrix, action_matrix= model_imagine_f(initial_state, self.forward, horizon, plan = 'rp',
                                                  W_uncertainty = w_uncertainty,e_uncertainty = e_uncertainty)
            cost = cost_f(output_matrix)
            mean_cost = cost.sum(0).mean(0)

        return mean_cost.item()


#Agent

In [7]:
class Agent:
    """
    Model-based RL agent: collects environment rollouts, fits a recurrent
    dynamics model to them, and trains a controller on trajectories imagined
    by that model.
    """

    # Termination thresholds applied by env_rollout to the raw observation.
    CART_LIMIT = 2.4                            # bound on |observation[0]|
    ANGLE_LIMIT = 12 * 2 * math.pi / 360        # bound on |observation[2]|, 12 degrees in radians
    MAX_EPISODE_STEPS = 200

    def __init__(self, env_case, state_dim = 4, action_dim = 1, model = 'LLB',deterministic = True,device='cuda', rand_seed = 1):
        """
        env_case      : noise case forwarded to CartPoleModEnv
        model         : 'DRNN', 'SRNN' or 'LLB' (instantiates DRNN / SRNN / BRNN)
        deterministic : forwarded to the controller and stored on the agent
        device        : device for the model and the policy
        rand_seed     : unused, kept for interface compatibility

        raise : ValueError for an unknown model name
        """
        self.env = CartPoleModEnv(case = env_case)
        self.env_case = env_case

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device
        self.observations_list = []
        self.actions_list= []
        self.MSEloss = nn.MSELoss()
        self.model_name = model

        if model == 'DRNN':
            self.model = DRNN(action_dim, 32, state_dim, device, 'LSTM').to(device)
        elif model == 'SRNN':
            self.model = SRNN(action_dim, 32, state_dim, device, 'LSTM', noise = 0.5 * torch.tensor([0.1,0.1,0.1,1])).to(device)
        elif model == 'LLB':
            self.model = BRNN(action_dim, 32, state_dim, device, 'LSTM').to(device)
        else:
            raise ValueError("model must be 'DRNN', 'SRNN' or 'LLB', got {!r}".format(model))

        self.model_optimiser = torch.optim.Adam(self.model.parameters(), lr = 0.001)
        self.mseloss = nn.MSELoss()
        self.model_training_loss_list = []

        # controller's signature is (action_dim, state_dim, deterministic, device);
        # passing `device` third would land it in the `deterministic` slot and
        # leave the controller on its default device.
        self.policy = controller(action_dim, state_dim, deterministic, device).to(device)
        self.deterministic_policy = deterministic

    def env_rollout(self, if_remember, plan, behaviour_uncertainty):
        """
        Interact with the environment for one episode.

        if_remember           : if True, append the observation and action
                                sequences to self.observations_list / self.actions_list
        plan                  : 'random' samples from the action space;
                                'pg' or 'rp' queries the current policy
        behaviour_uncertainty : forwarded to policy.make_decision

        The episode ends when observation[0] or observation[2] leaves its
        limit or after MAX_EPISODE_STEPS steps; the environment's own `done`
        flag is ignored.

        return : number of steps taken (one reward unit per step)
        raise  : NotImplementedError for an unknown plan
        """
        if plan not in ('random', 'pg', 'rp'):
            raise NotImplementedError('unknown plan {!r}'.format(plan))

        done = False
        state = self.env.reset()
        total_reward = 0
        temp_obs_list = []
        temp_actions_list = []

        if if_remember:
            temp_obs_list.append(torch.tensor(state))

        while not done:
            if plan == 'random':
                action = self.env.action_space.sample()
            else:
                state_tensor = torch.tensor(np.vstack(state)).float().squeeze()
                action = self.policy.make_decision(state_tensor.to(self.device), behaviour_uncertainty)
                action = action.detach().cpu().numpy()

            next_state, reward, _, _  = self.env.step(action)
            total_reward += 1

            done = bool(next_state[0] < -self.CART_LIMIT
                        or next_state[0] > self.CART_LIMIT
                        or next_state[2] < -self.ANGLE_LIMIT
                        or next_state[2] > self.ANGLE_LIMIT
                        or total_reward >= self.MAX_EPISODE_STEPS)

            if if_remember:
                temp_obs_list.append(torch.tensor(np.vstack(next_state)).squeeze())
                temp_actions_list.append(torch.tensor(action).float())

            state = next_state

        if if_remember:
            self.observations_list.append(torch.stack(temp_obs_list).float())       #list of shape [seq, output]
            self.actions_list.append(torch.stack(temp_actions_list).float())        #list of shape [seq-1, 1]

        return total_reward

    def model_learning(self, num_epoch, num_batch):
        """
        Fit the dynamics model to self.observations_list / self.actions_list.

        The stored episodes have variable length, so each epoch draws
        `num_batch` episodes (with replacement), truncates them to a common
        length with truncate_sequence, and takes a single optimiser step.

        num_epoch : number of training epochs (one parameter update each)
        num_batch : number of episodes sampled per epoch

        'LLB' maximises the first value returned by the model (loss = -b_FE);
        'DRNN' and 'SRNN' minimise the MSE between predictions and the
        observations from step 1 onwards.
        """
        for e in range(num_epoch):
            self.model_optimiser.zero_grad()

            idx = np.random.choice(len(self.observations_list), num_batch)
            trun_obs = truncate_sequence([self.observations_list[i] for i in idx], batch_first = False)
            trun_actions = truncate_sequence([self.actions_list[j] for j in idx], batch_first = False)

            if self.model_name == 'LLB':
                N = int(trun_actions.numel())
                b_FE, b_LL, b_KL = self.model(trun_obs, trun_actions, N)
                loss = -b_FE
            else:
                if self.model_name == 'DRNN':
                    pred = self.model(trun_obs[0,:,:], trun_actions)
                elif self.model_name == 'SRNN':
                    pred, _, _ = self.model(trun_obs[0, :, :], trun_actions)
                else:
                    raise ValueError('unknown model name {!r}'.format(self.model_name))
                # Targets follow the agent's device rather than a hard-coded 'cuda'.
                loss = self.MSEloss(torch.cat(pred), trun_obs[1:,:,:].to(self.device))
            loss.backward()

            self.model_optimiser.step()
            self.model_training_loss_list.append(loss.item())

            if e%1000 == 0:
                if self.model_name == 'LLB':
                    print('Epoch{}; FE = {}; LL = {}; KL = {}.'.format(e, b_FE.item(), b_LL.item(), b_KL.item()))
                else:
                    print('Epoch:{}; loss = {}.'.format(e, loss.item()))

    def cost(self, state):
        """
        cost = 5*state[..., 2]^2 + state[..., 0]^2
        (per the termination test: component 2 is the angle, component 0 the position)

        state : [seq, batch, output]
        return [seq, batch, 1]
        """
        return (5*state[:,:,2]**2 + state[:,:,0]**2).unsqueeze(-1)      #[seq, batch, 1]

    def policy_learning(self, imagine_num, num_particle,num_epoch, batch_size, horizon, plan, w_uncertainty, e_uncertainty, plot=False):
        """
        Train the policy on data imagined by the current model.

        For each of `imagine_num` initial states (from env.reset(), or the
        all-zero state when `plot` is True) the model rolls out for a fixed
        `horizon`, the cost of the imagined trajectories is computed, and the
        policy parameters are updated for `num_epoch` epochs.

        num_particle : number of imagined trajectories per initial state for 'rp'
                       (pg_train's signature has no positional particle count,
                       so it is not forwarded for 'pg')
        batch_size   : unused, kept for interface compatibility
        plan         : 'pg' or 'rp'

        The value returned by the last training call is stored in
        self.policy_loss (None when imagine_num is 0).

        raise : ValueError for an unknown plan
        """
        if plan not in ('pg', 'rp'):
            raise ValueError("plan must be 'pg' or 'rp', got {!r}".format(plan))

        model_f = self.model.imagine
        policy_train_loss = None

        for _ in range(imagine_num):
            initial_state = torch.tensor(self.env.reset()).unsqueeze(0).float()         #[1, output]
            if plot:
                initial_state = torch.zeros_like(initial_state)

            #learn the policy parameter using current model
            if plan == 'pg':
                policy_train_loss = self.policy.pg_train(num_epoch, initial_state, horizon, self.cost, model_f,
                                                         w_uncertainty, e_uncertainty, gamma = 1)
            else:
                policy_train_loss = self.policy.rp_train(num_epoch, num_particle,initial_state, horizon, self.cost, model_f,
                                                         w_uncertainty, e_uncertainty,gamma = 1)
        self.policy_loss = policy_train_loss


#Learning-Planning iterations

In [8]:
# Learning-planning loop: collect data with the current behaviour, fit the
# model, train the policy on imagined rollouts, then test in the environment.
SEED = 1
NUM_ITERATIONS = 50
NUM_TEST_TRIALS = 20
SUCCESS_REWARD = 200        # env_rollout caps an episode at 200 steps


def evaluate_policy(agent, plan, behaviour_uncertainty, num_trials = NUM_TEST_TRIALS):
    """
    Run `num_trials` environment episodes without storing them, print the
    reward of each one, and return the average reward.
    """
    print('\n ------------------TESTING-------------------')
    reward_sum = 0
    for trial in range(num_trials):
        rewards = agent.env_rollout(if_remember=False, behaviour_uncertainty = behaviour_uncertainty,plan = plan)
        print(trial, rewards)
        reward_sum += rewards
    average = reward_sum/num_trials
    print('Total trajs:', num_trials, average)
    # Episodes are capped at SUCCESS_REWARD steps, so the check must include equality.
    if average >= SUCCESS_REWARD:
        print('success')
    return average


time1 = time.time()
testing_reward_list = []
behaviour_uncertainty = True
deterministic = False
plan = 'rp'
num_data = 10

# Seed both generators: torch drives the model/policy sampling, numpy drives
# the batch selection in model_learning.
torch.manual_seed(SEED)
np.random.seed(SEED)

agent = Agent(env_case = 1, deterministic = deterministic,device = 'cuda', model = 'LLB')

# Bootstrap the data set with random-action episodes.
for _ in range(num_data):
    agent.env_rollout(True, behaviour_uncertainty = behaviour_uncertainty,plan = 'random')

agent.model_learning(num_epoch=1000, num_batch = 10)

agent.policy_learning(imagine_num=50, num_particle = 1000, num_epoch = 1,
                      batch_size = 10, horizon = 10, plan = plan, w_uncertainty = True, e_uncertainty = True)

avg_rewards = evaluate_policy(agent, plan, behaviour_uncertainty)
testing_reward_list.append(avg_rewards)

avg_data_length_list = []

for i in range(NUM_ITERATIONS):
    print('Epoch = ',i+1)

    # Add one on-policy episode to the data set.
    agent.env_rollout(True, behaviour_uncertainty = behaviour_uncertainty,plan = plan)

    total = sum(len(obs) for obs in agent.observations_list)
    avg_data_length = total/len(agent.observations_list)
    print('average training data length = ', avg_data_length)
    avg_data_length_list.append(avg_data_length)

    print('\n Begin model learning...')
    agent.model_learning(num_epoch =  1000, num_batch = 10)
    print('\n Finish model learning...')

    agent.policy_learning(imagine_num=50, num_particle = 1000, num_epoch = 1,
                          batch_size = 10, horizon = 10, plan = plan, w_uncertainty = True, e_uncertainty = True )
    print('\n Finish policy learning...')

    avg_rewards = evaluate_policy(agent, plan, behaviour_uncertainty)
    testing_reward_list.append(avg_rewards)

time2 = time.time()
print(time2 - time1)

CartPoleModEnv - Version 0.2.0, Noise case: 1




Epoch0; FE = -43.0851936340332; LL = -43.0851936340332; KL = 0.0.

 ------------------TESTING-------------------
0 25
1 22
2 18
3 21
4 25
5 16
6 19
7 22
8 23
9 18
10 21
11 17
12 24
13 20
14 27
15 20
16 17
17 17
18 21
19 23
Total trajs: 19 20.8
Epoch =  1
average training data length =  25.181818181818183

 Begin model learning...
Epoch0; FE = 25.531444549560547; LL = 27.623550415039062; KL = 2.092106342315674.

 Finish model learning...

 Finish policy learning...

 ------------------TESTING-------------------
0 33
1 31
2 37
3 38
4 49
5 29
6 48
7 53
8 34
9 45
10 38
11 69
12 30
13 42
14 41
15 59
16 53
17 59
18 49
19 41
Total trajs: 19 43.9
Epoch =  2
average training data length =  25.75

 Begin model learning...
Epoch0; FE = 34.811622619628906; LL = 38.47646713256836; KL = 3.6648433208465576.


KeyboardInterrupt: ignored