In [1]:
!pip install gym
!pip install stable_baselines3



In [2]:
from gym import Env
from gym.spaces import Discrete, Box
import random
import numpy as np
import torch
from stable_baselines3 import SAC, PPO
import gym

In [3]:
t = np.array([0.0, 0.0])
type(t[0])

numpy.float64

In [57]:
### Define a two line manifold environment ###
class simEnv(Env):
    def __init__(self):
        # Action agent can take
        self.action_space = Box(np.array([-1, -1]), np.array([1, 1]))
        # Work space
        self.observation_space = Box(np.array([-10, -10, 0]), np.array([10, 10, 1]))
        # Initial starting point
        self.state = np.array([0.0, 0.0, 0])
        # Goal state
        self.goal = np.array([10.0, 10.0, 1])
        # Intersection point
        self.intersection = np.array([10.0, 0.0, 0])
    
    def step(self, action):
        ### Projection ###
        if self.state[2] == 1:
            _action = np.array([0.0, action[1]])
        else:
            _action = np.array([action[0], 0.0])
        self.state[0] += _action[0]
        self.state[1] += _action[1]
        if self.state[0] == 10:
            self.state[2] = 1
        done = False
        # Design rewards
        if self.state[1] != 0 and self.state[0] != 10: # Fall off the manifold
            # negative of l2 distance to current manifold reward
            if self.state[2] == 0:
                reward = -np.abs(self.state[1])
            else:
                reward = -np.abs(self.state[0] - 10)
            done = True
        elif np.array_equal(self.state, self.intersection): # Get to the intersection point
            if self.state[2] == 1:
                reward = -np.linalg.norm(self.state - self.intersection) / 2
            else:
                reward = 10
        elif np.array_equal(self.state, self.goal): # Reach the goal 
            reward = 100
            done = True
        else: # Direct the agent to intersection point / goal
            if self.state[2] == 1:
                reward = -np.linalg.norm(self.state - self.goal) / 2
#                 reward = 1 / np.linalg.norm(self.state - self.goal)
            else: 
                reward = -np.linalg.norm(self.state - self.intersection)
#                 reward = 1 / np.linalg.norm(self.state - self.intersection)
        # Set placeholder for info 
        info = {}
        return self.state, reward, done, info
    
    def render(self):
        pass
    
    def reset(self):
        self.state = np.array([0.0,0.0, 0])
        return self.state

In [67]:
env = simEnv()

In [68]:
device = torch.device(
        "cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = PPO('MlpPolicy', env, verbose=1, device=device, ent_coef=.15)
model = model.learn(total_timesteps=500000, eval_freq=1000)

Using cuda:0 device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
-----------------------------
| time/              |      |
|    fps             | 1454 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 1036        |
|    iterations           | 2           |
|    time_elapsed         | 3           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008659707 |
|    clip_fraction        | 0.0512      |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.89       |
|    explained_variance   | 0.00197     |
|    learning_rate        | 0.0003      |
|    loss                 | 2.28e+04    |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.00405    |
|    std             

-----------------------------------------
| time/                   |             |
|    fps                  | 842         |
|    iterations           | 12          |
|    time_elapsed         | 29          |
|    total_timesteps      | 24576       |
| train/                  |             |
|    approx_kl            | 0.018478064 |
|    clip_fraction        | 0.134       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.93       |
|    explained_variance   | -1.8e-05    |
|    learning_rate        | 0.0003      |
|    loss                 | 5.11        |
|    n_updates            | 110         |
|    policy_gradient_loss | -0.0103     |
|    std                  | 1.77        |
|    value_loss           | 10.4        |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 840         |
|    iterations           | 13          |
|    time_elapsed         | 31    

-----------------------------------------
| time/                   |             |
|    fps                  | 827         |
|    iterations           | 23          |
|    time_elapsed         | 56          |
|    total_timesteps      | 47104       |
| train/                  |             |
|    approx_kl            | 0.011552835 |
|    clip_fraction        | 0.12        |
|    clip_range           | 0.2         |
|    entropy_loss         | -4.36       |
|    explained_variance   | 0.00017     |
|    learning_rate        | 0.0003      |
|    loss                 | -0.49       |
|    n_updates            | 220         |
|    policy_gradient_loss | 0.00235     |
|    std                  | 2.31        |
|    value_loss           | 0.556       |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 826         |
|    iterations           | 24          |
|    time_elapsed         | 59    

-----------------------------------------
| time/                   |             |
|    fps                  | 822         |
|    iterations           | 34          |
|    time_elapsed         | 84          |
|    total_timesteps      | 69632       |
| train/                  |             |
|    approx_kl            | 0.009591767 |
|    clip_fraction        | 0.126       |
|    clip_range           | 0.2         |
|    entropy_loss         | -4.62       |
|    explained_variance   | -0.0138     |
|    learning_rate        | 0.0003      |
|    loss                 | -0.524      |
|    n_updates            | 330         |
|    policy_gradient_loss | 0.00683     |
|    std                  | 2.85        |
|    value_loss           | 0.346       |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 822         |
|    iterations           | 35          |
|    time_elapsed         | 87    

-----------------------------------------
| time/                   |             |
|    fps                  | 820         |
|    iterations           | 45          |
|    time_elapsed         | 112         |
|    total_timesteps      | 92160       |
| train/                  |             |
|    approx_kl            | 0.020046826 |
|    clip_fraction        | 0.131       |
|    clip_range           | 0.2         |
|    entropy_loss         | -5.08       |
|    explained_variance   | -0.0148     |
|    learning_rate        | 0.0003      |
|    loss                 | -0.641      |
|    n_updates            | 440         |
|    policy_gradient_loss | 0.00793     |
|    std                  | 3.91        |
|    value_loss           | 0.298       |
-----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 819        |
|    iterations           | 46         |
|    time_elapsed         | 114       

------------------------------------------
| time/                   |              |
|    fps                  | 817          |
|    iterations           | 56           |
|    time_elapsed         | 140          |
|    total_timesteps      | 114688       |
| train/                  |              |
|    approx_kl            | 0.0071741873 |
|    clip_fraction        | 0.104        |
|    clip_range           | 0.2          |
|    entropy_loss         | -5.56        |
|    explained_variance   | -0.0043      |
|    learning_rate        | 0.0003       |
|    loss                 | -0.631       |
|    n_updates            | 550          |
|    policy_gradient_loss | 0.00424      |
|    std                  | 5.73         |
|    value_loss           | 0.29         |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 817         |
|    iterations           | 57          |
|    time_elaps

-----------------------------------------
| time/                   |             |
|    fps                  | 818         |
|    iterations           | 67          |
|    time_elapsed         | 167         |
|    total_timesteps      | 137216      |
| train/                  |             |
|    approx_kl            | 0.009378334 |
|    clip_fraction        | 0.148       |
|    clip_range           | 0.2         |
|    entropy_loss         | -6.04       |
|    explained_variance   | -0.000461   |
|    learning_rate        | 0.0003      |
|    loss                 | -0.792      |
|    n_updates            | 660         |
|    policy_gradient_loss | 0.00683     |
|    std                  | 8.39        |
|    value_loss           | 0.297       |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 818         |
|    iterations           | 68          |
|    time_elapsed         | 170   

-----------------------------------------
| time/                   |             |
|    fps                  | 814         |
|    iterations           | 78          |
|    time_elapsed         | 196         |
|    total_timesteps      | 159744      |
| train/                  |             |
|    approx_kl            | 0.015984695 |
|    clip_fraction        | 0.158       |
|    clip_range           | 0.2         |
|    entropy_loss         | -6.56       |
|    explained_variance   | 0.00893     |
|    learning_rate        | 0.0003      |
|    loss                 | -0.739      |
|    n_updates            | 770         |
|    policy_gradient_loss | 0.00259     |
|    std                  | 12.4        |
|    value_loss           | 0.4         |
-----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 814        |
|    iterations           | 79         |
|    time_elapsed         | 198       

------------------------------------------
| time/                   |              |
|    fps                  | 795          |
|    iterations           | 89           |
|    time_elapsed         | 229          |
|    total_timesteps      | 182272       |
| train/                  |              |
|    approx_kl            | 0.0070422404 |
|    clip_fraction        | 0.162        |
|    clip_range           | 0.2          |
|    entropy_loss         | -7.05        |
|    explained_variance   | 0.00619      |
|    learning_rate        | 0.0003       |
|    loss                 | -0.901       |
|    n_updates            | 880          |
|    policy_gradient_loss | 0.00649      |
|    std                  | 17.6         |
|    value_loss           | 0.307        |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 795         |
|    iterations           | 90          |
|    time_elaps

------------------------------------------
| time/                   |              |
|    fps                  | 795          |
|    iterations           | 100          |
|    time_elapsed         | 257          |
|    total_timesteps      | 204800       |
| train/                  |              |
|    approx_kl            | 0.0070938906 |
|    clip_fraction        | 0.144        |
|    clip_range           | 0.2          |
|    entropy_loss         | -7.52        |
|    explained_variance   | 2.66e-05     |
|    learning_rate        | 0.0003       |
|    loss                 | -1.01        |
|    n_updates            | 990          |
|    policy_gradient_loss | 0.0126       |
|    std                  | 24.9         |
|    value_loss           | 0.283        |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 795         |
|    iterations           | 101         |
|    time_elaps

-----------------------------------------
| time/                   |             |
|    fps                  | 790         |
|    iterations           | 111         |
|    time_elapsed         | 287         |
|    total_timesteps      | 227328      |
| train/                  |             |
|    approx_kl            | 0.014872646 |
|    clip_fraction        | 0.133       |
|    clip_range           | 0.2         |
|    entropy_loss         | -7.98       |
|    explained_variance   | 0.000369    |
|    learning_rate        | 0.0003      |
|    loss                 | -0.955      |
|    n_updates            | 1100        |
|    policy_gradient_loss | 0.00797     |
|    std                  | 35.4        |
|    value_loss           | 0.25        |
-----------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 788          |
|    iterations           | 112          |
|    time_elapsed         | 29

-----------------------------------------
| time/                   |             |
|    fps                  | 788         |
|    iterations           | 122         |
|    time_elapsed         | 316         |
|    total_timesteps      | 249856      |
| train/                  |             |
|    approx_kl            | 0.027479012 |
|    clip_fraction        | 0.139       |
|    clip_range           | 0.2         |
|    entropy_loss         | -8.46       |
|    explained_variance   | 0.000249    |
|    learning_rate        | 0.0003      |
|    loss                 | -1.14       |
|    n_updates            | 1210        |
|    policy_gradient_loss | 0.00803     |
|    std                  | 50.3        |
|    value_loss           | 0.256       |
-----------------------------------------
---------------------------------------
| time/                   |           |
|    fps                  | 788       |
|    iterations           | 123       |
|    time_elapsed         | 319       |
| 

-----------------------------------------
| time/                   |             |
|    fps                  | 789         |
|    iterations           | 133         |
|    time_elapsed         | 345         |
|    total_timesteps      | 272384      |
| train/                  |             |
|    approx_kl            | 0.013089741 |
|    clip_fraction        | 0.17        |
|    clip_range           | 0.2         |
|    entropy_loss         | -8.89       |
|    explained_variance   | 0.00105     |
|    learning_rate        | 0.0003      |
|    loss                 | -1.22       |
|    n_updates            | 1320        |
|    policy_gradient_loss | 0.00863     |
|    std                  | 69          |
|    value_loss           | 0.221       |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 789         |
|    iterations           | 134         |
|    time_elapsed         | 347   

-----------------------------------------
| time/                   |             |
|    fps                  | 790         |
|    iterations           | 144         |
|    time_elapsed         | 373         |
|    total_timesteps      | 294912      |
| train/                  |             |
|    approx_kl            | 0.005444949 |
|    clip_fraction        | 0.135       |
|    clip_range           | 0.2         |
|    entropy_loss         | -9.37       |
|    explained_variance   | 0.00164     |
|    learning_rate        | 0.0003      |
|    loss                 | -1.35       |
|    n_updates            | 1430        |
|    policy_gradient_loss | 0.0112      |
|    std                  | 96.1        |
|    value_loss           | 0.179       |
-----------------------------------------
---------------------------------------
| time/                   |           |
|    fps                  | 790       |
|    iterations           | 145       |
|    time_elapsed         | 375       |
| 

------------------------------------------
| time/                   |              |
|    fps                  | 790          |
|    iterations           | 155          |
|    time_elapsed         | 401          |
|    total_timesteps      | 317440       |
| train/                  |              |
|    approx_kl            | 0.0060365126 |
|    clip_fraction        | 0.129        |
|    clip_range           | 0.2          |
|    entropy_loss         | -9.83        |
|    explained_variance   | 0.00196      |
|    learning_rate        | 0.0003       |
|    loss                 | -1.33        |
|    n_updates            | 1540         |
|    policy_gradient_loss | 0.00668      |
|    std                  | 134          |
|    value_loss           | 0.229        |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 789         |
|    iterations           | 156         |
|    time_elaps

-----------------------------------------
| time/                   |             |
|    fps                  | 787         |
|    iterations           | 166         |
|    time_elapsed         | 431         |
|    total_timesteps      | 339968      |
| train/                  |             |
|    approx_kl            | 0.013338969 |
|    clip_fraction        | 0.152       |
|    clip_range           | 0.2         |
|    entropy_loss         | -10.3       |
|    explained_variance   | 0.0268      |
|    learning_rate        | 0.0003      |
|    loss                 | -1.47       |
|    n_updates            | 1650        |
|    policy_gradient_loss | 0.00751     |
|    std                  | 186         |
|    value_loss           | 0.163       |
-----------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 787          |
|    iterations           | 167          |
|    time_elapsed         | 43

-----------------------------------------
| time/                   |             |
|    fps                  | 785         |
|    iterations           | 177         |
|    time_elapsed         | 461         |
|    total_timesteps      | 362496      |
| train/                  |             |
|    approx_kl            | 0.003550647 |
|    clip_fraction        | 0.159       |
|    clip_range           | 0.2         |
|    entropy_loss         | -10.8       |
|    explained_variance   | 0.0252      |
|    learning_rate        | 0.0003      |
|    loss                 | -1.48       |
|    n_updates            | 1760        |
|    policy_gradient_loss | 0.0127      |
|    std                  | 252         |
|    value_loss           | 0.286       |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 785         |
|    iterations           | 178         |
|    time_elapsed         | 464   

------------------------------------------
| time/                   |              |
|    fps                  | 786          |
|    iterations           | 188          |
|    time_elapsed         | 489          |
|    total_timesteps      | 385024       |
| train/                  |              |
|    approx_kl            | 0.0047416436 |
|    clip_fraction        | 0.128        |
|    clip_range           | 0.2          |
|    entropy_loss         | -11.3        |
|    explained_variance   | 0.0042       |
|    learning_rate        | 0.0003       |
|    loss                 | -1.63        |
|    n_updates            | 1870         |
|    policy_gradient_loss | 0.00763      |
|    std                  | 363          |
|    value_loss           | 0.133        |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 786          |
|    iterations           | 189          |
|    time_e

-----------------------------------------
| time/                   |             |
|    fps                  | 784         |
|    iterations           | 199         |
|    time_elapsed         | 519         |
|    total_timesteps      | 407552      |
| train/                  |             |
|    approx_kl            | 0.012387662 |
|    clip_fraction        | 0.102       |
|    clip_range           | 0.2         |
|    entropy_loss         | -11.8       |
|    explained_variance   | 0.0206      |
|    learning_rate        | 0.0003      |
|    loss                 | -1.69       |
|    n_updates            | 1980        |
|    policy_gradient_loss | 0.00311     |
|    std                  | 534         |
|    value_loss           | 0.139       |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 784         |
|    iterations           | 200         |
|    time_elapsed         | 522   

-----------------------------------------
| time/                   |             |
|    fps                  | 784         |
|    iterations           | 210         |
|    time_elapsed         | 547         |
|    total_timesteps      | 430080      |
| train/                  |             |
|    approx_kl            | 0.011874098 |
|    clip_fraction        | 0.123       |
|    clip_range           | 0.2         |
|    entropy_loss         | -12.3       |
|    explained_variance   | 0.0244      |
|    learning_rate        | 0.0003      |
|    loss                 | -1.8        |
|    n_updates            | 2090        |
|    policy_gradient_loss | 0.00629     |
|    std                  | 767         |
|    value_loss           | 0.109       |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 785         |
|    iterations           | 211         |
|    time_elapsed         | 550   

-----------------------------------------
| time/                   |             |
|    fps                  | 784         |
|    iterations           | 221         |
|    time_elapsed         | 576         |
|    total_timesteps      | 452608      |
| train/                  |             |
|    approx_kl            | 0.015799213 |
|    clip_fraction        | 0.13        |
|    clip_range           | 0.2         |
|    entropy_loss         | -12.8       |
|    explained_variance   | 0.0277      |
|    learning_rate        | 0.0003      |
|    loss                 | -1.86       |
|    n_updates            | 2200        |
|    policy_gradient_loss | 0.00801     |
|    std                  | 1.14e+03    |
|    value_loss           | 0.0979      |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 784         |
|    iterations           | 222         |
|    time_elapsed         | 579   

-----------------------------------------
| time/                   |             |
|    fps                  | 785         |
|    iterations           | 232         |
|    time_elapsed         | 605         |
|    total_timesteps      | 475136      |
| train/                  |             |
|    approx_kl            | 0.007339552 |
|    clip_fraction        | 0.0928      |
|    clip_range           | 0.2         |
|    entropy_loss         | -13.4       |
|    explained_variance   | -0.00549    |
|    learning_rate        | 0.0003      |
|    loss                 | -1.9        |
|    n_updates            | 2310        |
|    policy_gradient_loss | 0.00295     |
|    std                  | 1.69e+03    |
|    value_loss           | 0.203       |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 785         |
|    iterations           | 233         |
|    time_elapsed         | 607   

-----------------------------------------
| time/                   |             |
|    fps                  | 785         |
|    iterations           | 243         |
|    time_elapsed         | 633         |
|    total_timesteps      | 497664      |
| train/                  |             |
|    approx_kl            | 0.010617437 |
|    clip_fraction        | 0.12        |
|    clip_range           | 0.2         |
|    entropy_loss         | -13.9       |
|    explained_variance   | 0.0309      |
|    learning_rate        | 0.0003      |
|    loss                 | -2.05       |
|    n_updates            | 2420        |
|    policy_gradient_loss | 0.00578     |
|    std                  | 2.41e+03    |
|    value_loss           | 0.0889      |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 785         |
|    iterations           | 244         |
|    time_elapsed         | 635   

In [71]:
obs = env.reset()

for _ in range(100):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    print('=====')
    print(obs)
    print(rewards)

=====
[1. 0. 0.]
-9.0
=====
[2. 0. 0.]
-8.0
=====
[3. 0. 0.]
-7.0
=====
[4. 0. 0.]
-6.0
=====
[5. 0. 0.]
-5.0
=====
[6. 0. 0.]
-4.0
=====
[7. 0. 0.]
-3.0
=====
[8. 0. 0.]
-2.0
=====
[9. 0. 0.]
-1.0
=====
[10.  0.  1.]
-5.0
=====
[10.  1.  1.]
-4.5
=====
[10.  0.  1.]
-5.0
=====
[10. -1.  1.]
-5.5
=====
[10.  0.  1.]
-5.0
=====
[10. -1.  1.]
-5.5
=====
[10.  0.  1.]
-5.0
=====
[10. -1.  1.]
-5.5
=====
[10.  0.  1.]
-5.0
=====
[10. -1.  1.]
-5.5
=====
[10.  0.  1.]
-5.0
=====
[10. -1.  1.]
-5.5
=====
[10. -2.  1.]
-6.0
=====
[10. -1.  1.]
-5.5
=====
[10. -2.  1.]
-6.0
=====
[10. -3.  1.]
-6.5
=====
[10. -2.  1.]
-6.0
=====
[10. -1.  1.]
-5.5
=====
[10. -2.  1.]
-6.0
=====
[10. -1.  1.]
-5.5
=====
[10.  0.  1.]
-5.0
=====
[10. -1.  1.]
-5.5
=====
[10. -2.  1.]
-6.0
=====
[10. -3.  1.]
-6.5
=====
[10. -2.  1.]
-6.0
=====
[10. -3.  1.]
-6.5
=====
[10. -2.  1.]
-6.0
=====
[10. -1.  1.]
-5.5
=====
[10.  0.  1.]
-5.0
=====
[10.  1.  1.]
-4.5
=====
[10.  2.  1.]
-4.0
=====
[10.  1.  1.]
-4.5
==

In [None]:
# env = simEnv()

# device = torch.device(
#         "cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# model = SAC('MlpPolicy', env, verbose=1, device=device)
# model = model.learn(total_timesteps=500000, eval_freq=1000)

Using cuda:0 device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [None]:
# obs = env.reset()

# for _ in range(100):
#     action, _states = model.predict(obs)
#     obs, rewards, dones, info = env.step(action)
#     print(obs)
#     print(rewards)