In [1]:
import gym
import torch
import torch.nn as nn
import numpy as np
from scipy.special import softmax

In [2]:
class ValueNet(nn.Module):
    def __init__(self, input_dim, hidden, output_dim):
        super(ValueNet, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, output_dim),
        )
    
    def forward(self, x):
        return self.model(x)

In [3]:
class POLO(object):
    def __init__(self, K, T, lambda_, noise_mu, noise_sigma, U, u_init, memory_size, 
                 observation_space, action_space, state_space, net_hidden_layers, num_nets, gamma, 
                 state_samples, gradient_steps, noise_gaussian=True):
        self.memory_size = memory_size
        self.obs_mem = np.zeros((self.memory_size, observation_space))
        self.state_mem = np.zeros((self.memory_size, state_space))
        
        self.K = K  # N_SAMPLES
        self.T = T  # TIMESTEPS
        self.lambda_ = lambda_
        self.noise_mu = noise_mu
        self.noise_sigma = noise_sigma
        self.U = U
        self.u_init = u_init
        self.reward_total = np.zeros(shape=(self.K))
        self.gamma = gamma
        self.state_samples = state_samples
        self.gradient_steps = gradient_steps


        if noise_gaussian:
            self.noise = np.random.normal(loc=self.noise_mu, scale=self.noise_sigma, size=(self.K, self.T))
        else:
            self.noise = np.full(shape=(self.K, self.T), fill_value=0.9)
        
        self.num_nets = num_nets
        
        self._build_value_nets(observation_space, net_hidden_layers, action_space)

    
    def _build_value_nets(self, input_dim, hidden, output_dim):
        self.value_nets = []
        self.loss_funcs = []
        self.optimizers = []
        
        for i in range(self.num_nets):
            self.value_nets.append(ValueNet(input_dim, hidden, output_dim))
            self.loss_funcs.append(nn.MSELoss())
            self.optimizers.append(torch.optim.Adam(self.value_nets[-1].parameters(), lr=0.01))
            
        
    def get_aggregated_value(self, s):
        values = []
        for net in self.value_nets:
            values.append(net(torch.tensor(s, dtype=torch.float)).item())
            
        values = np.array(values)
        weights = softmax(values)
        weighted_values = values * weights

        return sum(weighted_values)
        
    def learn(self, env):
        init_state = env.env.state
        for _ in range(self.gradient_steps):
            idx = np.random.choice(np.min([self.memory_counter, self.memory_size]), size=self.state_samples, replace=False)

            sampled_states = self.state_mem[idx,:]

            sampled_obs = self.obs_mem[idx,:]



            targets = x = [None for i in range(self.num_nets)]


            for s_state, o in zip(sampled_states, sampled_obs):
                discount = 1
                total_reward = 0

                max_rewards = [float('-inf') for _ in range(self.num_nets)]

                for k in range(self.K):
                    env.env.state = s_state
                    for t in range(self.T):
                        perturbed_action_t = self.U[t] + self.noise[k, t]

                        s, reward, _, _ = env.step([perturbed_action_t])

                        total_reward += discount * reward
                        discount *= self.gamma

                    for i in range(self.num_nets):
                        net = self.value_nets[i]
                        reward_for_net = torch.tensor(total_reward, dtype=torch.float) + net(torch.tensor(s, dtype=torch.float))
                        if reward_for_net > max_rewards[i]:
                            max_rewards[i] = reward_for_net



                for i in range(self.num_nets):
                    net = self.value_nets[i]
                    loss_func = self.loss_funcs[i]
                    optimizer = self.optimizers[i]

                    target = max_rewards[i]

                    if targets[i] is None:
                        targets[i] = torch.tensor([[target]], dtype=torch.float)
                    else:
                        targets[i] = torch.cat((targets[i], torch.tensor([[target]], dtype=torch.float)))

        
            for i in range(self.num_nets):
                net = self.value_nets[i]
                loss_func = self.loss_funcs[i]
                optimizer = self.optimizers[i]

                optimizer.zero_grad()

                preds = net(torch.tensor(sampled_obs, dtype=torch.float))

                loss = loss_func(preds, targets[i])

                loss.backward()
                optimizer.step()
                
        env.env.state = init_state

    
    def _compute_total_reward(self, k, env):
        discount = 1
        for t in range(self.T):
            perturbed_action_t = self.U[t] + self.noise[k, t]
            s, reward, _, _ = env.step([perturbed_action_t])
            self.reward_total[k] += discount * reward
            discount *= self.gamma
        self.reward_total[k] += discount * self.get_aggregated_value(s)
    
    def _ensure_non_zero(self, reward, beta, factor):
        return np.exp(-factor * (beta - reward))
    
    def choose_action(self, env):
        init = env.env.state
        for k in range(self.K):
            self._compute_total_reward(k, env)
            env.env.state = init
        
        beta = np.max(self.reward_total)  # maximize reward of all trajectories
        reward_total_non_zero = self._ensure_non_zero(reward=self.reward_total, beta=beta, factor=1/self.lambda_)

        eta = np.sum(reward_total_non_zero)
        omega = 1/eta * reward_total_non_zero

        self.U += [np.sum(omega * self.noise[:, t]) for t in range(self.T)]
        
        
        env.env.state = init
        action = self.U[0]
        
        self.U = np.roll(self.U, -1)  # shift all elements to the left
        self.U[-1] = self.u_init  #
        self.reward_total[:] = 0
        
        
        return action
    
    
    def store_state(self, obs, state):
        if not hasattr(self, 'memory_counter'):
            self.memory_counter = 0

        # replace the old memory with new memory
        index = self.memory_counter % self.memory_size
        self.obs_mem[index] = np.array(obs)
        self.state_mem[index] = np.array(state)

        self.memory_counter += 1
    

In [4]:
ENV_NAME = "Pendulum-v0"
TIMESTEPS = 14 # T
N_SAMPLES = 120  # K

env = gym.make(ENV_NAME)
ACTION_LOW = env.action_space.low[0]
ACTION_HIGH = env.action_space.high[0]

noise_mu = 0
noise_sigma = 0.3
lambda_ = 1

Z = 16

U = np.random.uniform(low=ACTION_LOW, high=ACTION_HIGH, size=TIMESTEPS)

s = env.reset()

polo = POLO(K=N_SAMPLES, T=TIMESTEPS, U=U, lambda_=lambda_, noise_mu=noise_mu, 
            noise_sigma=noise_sigma, u_init=0, memory_size=512, 
            observation_space=env.observation_space.shape[0], action_space=env.action_space.shape[0], 
            state_space=len(env.env.state),
            net_hidden_layers=16, num_nets=6, gamma=0.99, state_samples=8, gradient_steps=16, noise_gaussian=True)


env.render()
polo.store_state(s, env.env.state)

for t in range(10000):
    a = polo.choose_action(env)
    s, r, _, _ = env.step([a])
    print("action taken: {:.2f} reward received: {:.2f}".format(a, r))
    env.render()
    polo.store_state(s, env.env.state)
    
    if t != 0 and t % Z == 0:
        polo.learn(env)


action taken: 1.44 reward received: -2.05
action taken: -1.18 reward received: -2.33
action taken: 0.27 reward received: -2.97
action taken: -0.78 reward received: -3.80
action taken: 1.57 reward received: -5.05
action taken: -0.94 reward received: -6.26
action taken: -0.28 reward received: -8.10
action taken: 0.32 reward received: -10.08
action taken: -1.15 reward received: -12.07
action taken: -0.91 reward received: -13.02
action taken: 1.85 reward received: -11.30
action taken: 0.16 reward received: -9.15
action taken: -0.98 reward received: -7.41
action taken: -2.05 reward received: -5.99
action taken: -0.74 reward received: -4.84
action taken: -0.88 reward received: -3.75
action taken: -1.06 reward received: -2.89
action taken: -1.19 reward received: -2.24
action taken: -1.30 reward received: -1.77
action taken: -1.43 reward received: -1.44
action taken: -1.39 reward received: -1.23
action taken: -1.41 reward received: -1.11
action taken: -1.38 reward received: -1.08
action taken:

action taken: -1.35 reward received: -1.50
action taken: -1.42 reward received: -1.24
action taken: -1.46 reward received: -1.08
action taken: -1.44 reward received: -1.00
action taken: -1.32 reward received: -1.00
action taken: -1.16 reward received: -1.07
action taken: -0.97 reward received: -1.25
action taken: -0.79 reward received: -1.54
action taken: -0.60 reward received: -1.98
action taken: -0.45 reward received: -2.63
action taken: -0.32 reward received: -3.52
action taken: -0.19 reward received: -4.71
action taken: -0.10 reward received: -6.23
action taken: -0.04 reward received: -8.06
action taken: 0.00 reward received: -10.17
action taken: 0.02 reward received: -12.42
action taken: 0.01 reward received: -13.79
action taken: 0.06 reward received: -11.82
action taken: 0.16 reward received: -9.83
action taken: 0.30 reward received: -7.94
action taken: 0.45 reward received: -6.25
action taken: 0.62 reward received: -4.81
action taken: 0.80 reward received: -3.65
action taken: 0.

action taken: 0.99 reward received: -2.74
action taken: 1.16 reward received: -2.07
action taken: 1.31 reward received: -1.58
action taken: 1.44 reward received: -1.22
action taken: 1.56 reward received: -0.98
action taken: 1.67 reward received: -0.81
action taken: 1.79 reward received: -0.71
action taken: 1.91 reward received: -0.65
action taken: 2.02 reward received: -0.62
action taken: 2.06 reward received: -0.62
action taken: 2.07 reward received: -0.66
action taken: 2.03 reward received: -0.72
action taken: 1.89 reward received: -0.83
action taken: 1.73 reward received: -0.98
action taken: 1.57 reward received: -1.20
action taken: 1.46 reward received: -1.52
action taken: 1.32 reward received: -1.96
action taken: 1.17 reward received: -2.56
action taken: 1.17 reward received: -3.37
action taken: 1.13 reward received: -4.41
action taken: 0.86 reward received: -5.71
action taken: 0.62 reward received: -7.29
action taken: 0.40 reward received: -9.13
action taken: 0.24 reward received

action taken: 0.32 reward received: -10.58
action taken: 0.12 reward received: -12.69
action taken: 0.05 reward received: -12.96
action taken: -0.07 reward received: -11.05
action taken: -0.22 reward received: -9.17
action taken: -0.39 reward received: -7.41
action taken: -0.57 reward received: -5.86
action taken: -0.75 reward received: -4.55
action taken: -0.92 reward received: -3.49
action taken: -1.07 reward received: -2.68
action taken: -1.20 reward received: -2.07
action taken: -1.31 reward received: -1.63
action taken: -1.40 reward received: -1.33
action taken: -1.45 reward received: -1.13
action taken: -1.44 reward received: -1.02
action taken: -1.33 reward received: -0.99
action taken: -1.21 reward received: -1.03
action taken: -1.08 reward received: -1.16
action taken: -0.92 reward received: -1.40
action taken: -0.74 reward received: -1.78
action taken: -0.61 reward received: -2.33
action taken: -0.48 reward received: -3.10
action taken: -0.34 reward received: -4.14
action tak

action taken: -0.10 reward received: -6.26
action taken: -0.04 reward received: -8.10
action taken: -0.00 reward received: -10.22
action taken: 0.01 reward received: -12.48
action taken: 0.00 reward received: -13.79
action taken: 0.06 reward received: -11.81
action taken: 0.16 reward received: -9.82
action taken: 0.29 reward received: -7.92
action taken: 0.45 reward received: -6.22
action taken: 0.62 reward received: -4.79
action taken: 0.79 reward received: -3.62
action taken: 0.97 reward received: -2.73
action taken: 1.14 reward received: -2.05
action taken: 1.29 reward received: -1.56
action taken: 1.41 reward received: -1.21
action taken: 1.54 reward received: -0.96
action taken: 1.65 reward received: -0.79
action taken: 1.76 reward received: -0.68
action taken: 1.88 reward received: -0.62
action taken: 2.00 reward received: -0.59
action taken: 2.06 reward received: -0.59
action taken: 2.07 reward received: -0.61
action taken: 2.05 reward received: -0.66
action taken: 1.99 reward r

action taken: 2.07 reward received: -0.66
action taken: 2.03 reward received: -0.73
action taken: 1.89 reward received: -0.83
action taken: 1.72 reward received: -0.98
action taken: 1.56 reward received: -1.21
action taken: 1.45 reward received: -1.53
action taken: 1.32 reward received: -1.98
action taken: 1.16 reward received: -2.59
action taken: 1.15 reward received: -3.41
action taken: 1.12 reward received: -4.45
action taken: 0.87 reward received: -5.76
action taken: 0.62 reward received: -7.34
action taken: 0.39 reward received: -9.20
action taken: 0.23 reward received: -11.24
action taken: 0.10 reward received: -13.36
action taken: 0.02 reward received: -12.30
action taken: -0.11 reward received: -10.40
action taken: -0.27 reward received: -8.56
action taken: -0.45 reward received: -6.87
action taken: -0.64 reward received: -5.40
action taken: -0.82 reward received: -4.18
action taken: -0.98 reward received: -3.21
action taken: -1.13 reward received: -2.47
action taken: -1.25 rew

action taken: -1.22 reward received: -2.03
action taken: -1.33 reward received: -1.60
action taken: -1.41 reward received: -1.31
action taken: -1.45 reward received: -1.13
action taken: -1.44 reward received: -1.03
action taken: -1.32 reward received: -1.01
action taken: -1.16 reward received: -1.07
action taken: -0.98 reward received: -1.23
action taken: -0.81 reward received: -1.50
action taken: -0.62 reward received: -1.92
action taken: -0.46 reward received: -2.54
action taken: -0.34 reward received: -3.40
action taken: -0.20 reward received: -4.54
action taken: -0.10 reward received: -6.01
action taken: -0.04 reward received: -7.81
action taken: -0.00 reward received: -9.88
action taken: 0.02 reward received: -12.11
action taken: 0.01 reward received: -14.02
action taken: 0.06 reward received: -12.06
action taken: 0.15 reward received: -10.08
action taken: 0.28 reward received: -8.17
action taken: 0.44 reward received: -6.45
action taken: 0.61 reward received: -4.98
action taken: 

action taken: 0.90 reward received: -3.26
action taken: 1.08 reward received: -2.46
action taken: 1.24 reward received: -1.86
action taken: 1.38 reward received: -1.43
action taken: 1.50 reward received: -1.13
action taken: 1.62 reward received: -0.92
action taken: 1.74 reward received: -0.78
action taken: 1.86 reward received: -0.69
action taken: 1.97 reward received: -0.65
action taken: 2.06 reward received: -0.64
action taken: 2.06 reward received: -0.66
action taken: 2.03 reward received: -0.72
action taken: 1.89 reward received: -0.80
action taken: 1.83 reward received: -0.94
action taken: 1.66 reward received: -1.14
action taken: 1.49 reward received: -1.42
action taken: 1.28 reward received: -1.83
action taken: 1.23 reward received: -2.39
action taken: 1.21 reward received: -3.13
action taken: 1.10 reward received: -4.10
action taken: 0.80 reward received: -5.32
action taken: 0.66 reward received: -6.85
action taken: 0.46 reward received: -8.63
action taken: 0.30 reward received

action taken: 0.14 reward received: -11.91
action taken: 0.06 reward received: -13.71
action taken: -0.02 reward received: -11.80
action taken: -0.15 reward received: -9.89
action taken: -0.31 reward received: -8.07
action taken: -0.49 reward received: -6.42
action taken: -0.67 reward received: -5.00
action taken: -0.85 reward received: -3.85
action taken: -1.01 reward received: -2.95
action taken: -1.15 reward received: -2.26
action taken: -1.27 reward received: -1.76
action taken: -1.36 reward received: -1.41
action taken: -1.43 reward received: -1.17
action taken: -1.46 reward received: -1.02
action taken: -1.44 reward received: -0.96
action taken: -1.33 reward received: -0.96
action taken: -1.17 reward received: -1.05
action taken: -0.97 reward received: -1.22


KeyboardInterrupt: 