In [5]:
!pip install gym
!pip install stable_baselines3



In [6]:
from gym import Env
from gym.spaces import Discrete, Box
import random
import numpy as np
import torch
from stable_baselines3 import SAC, PPO
import gym

In [7]:
t = np.array([0.0, 0.0])
type(t[0])

numpy.float64

In [15]:
### Define a two line manifold environment ###
class simEnv(Env):
    def __init__(self):
        # Action agent can take
        self.action_space = Box(np.array([-1, -1]), np.array([1, 1]))
        # Work space
        self.observation_space = Box(np.array([-10, -10]), np.array([10, 10]))
        # Initial starting point
        self.state = np.array([0.0, 0.0])
        # Goal state
        self.goal = np.array([10.0, 10.0])
        # Intersection point
        self.intersection = np.array([10.0, 0.0])
        # Indicator whether next manifold is reached or not
        self.next_manifold = False
    
    def step(self, action):
        ### Projection ###
        if self.next_manifold:
            _action = np.array([0.0, action[1]])
        else:
            _action = np.array([action[0], 0.0])
        self.state += _action
        if self.state[0] == 10:
            self.next_manifold = True
        done = False
        # Design rewards
        if self.state[1] != 0 and self.state[0] != 10: # Fall off the manifold
            # negative of l2 distance to current manifold reward
            if not self.next_manifold:
                reward = -np.abs(self.state[1])
            else:
                reward = -np.abs(self.state[0] - 10)
            reward = -100
            done = True
        elif np.array_equal(self.state, self.intersection): # Get to the intersection point
            if self.next_manifold:
                reward = 1 / np.linalg.norm(self.state - self.goal)
            else:
                reward = 10
        elif np.array_equal(self.state, self.goal): # Reach the goal 
            reward = 100
            done = True
        else: # Direct the agent to intersection point / goal
            if self.next_manifold:
                reward = -np.linalg.norm(self.state - self.goal)
#                 reward = 1 / np.linalg.norm(self.state - self.goal)
            else: 
                reward = -np.linalg.norm(self.state - self.intersection)
#                 reward = 1 / np.linalg.norm(self.state - self.intersection)
        # Set placeholder for info 
        info = {}
        return self.state, reward, done, info
    
    def render(self):
        pass
    
    def reset(self):
        self.state = np.array([0.0,0.0])
        self.next_manifold = False
        return self.state

In [16]:
env = simEnv()

In [17]:
device = torch.device(
        "cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = PPO('MlpPolicy', env, verbose=1, device=device)
model = model.learn(total_timesteps=500000, eval_freq=1000)

Using cuda:0 device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
-----------------------------
| time/              |      |
|    fps             | 1531 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 1130         |
|    iterations           | 2            |
|    time_elapsed         | 3            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0068806177 |
|    clip_fraction        | 0.0671       |
|    clip_range           | 0.2          |
|    entropy_loss         | -2.85        |
|    explained_variance   | -0.00666     |
|    learning_rate        | 0.0003       |
|    loss                 | 8.83e+04     |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.00439     |
|    

-----------------------------------------
| time/                   |             |
|    fps                  | 833         |
|    iterations           | 12          |
|    time_elapsed         | 29          |
|    total_timesteps      | 24576       |
| train/                  |             |
|    approx_kl            | 0.020257652 |
|    clip_fraction        | 0.178       |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.54       |
|    explained_variance   | 2.93e-05    |
|    learning_rate        | 0.0003      |
|    loss                 | 1.01        |
|    n_updates            | 110         |
|    policy_gradient_loss | -0.0217     |
|    std                  | 0.862       |
|    value_loss           | 2.39        |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 823         |
|    iterations           | 13          |
|    time_elapsed         | 32    

-----------------------------------------
| time/                   |             |
|    fps                  | 803         |
|    iterations           | 23          |
|    time_elapsed         | 58          |
|    total_timesteps      | 47104       |
| train/                  |             |
|    approx_kl            | 0.014747739 |
|    clip_fraction        | 0.165       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.87       |
|    explained_variance   | -0.00922    |
|    learning_rate        | 0.0003      |
|    loss                 | 0.152       |
|    n_updates            | 220         |
|    policy_gradient_loss | 0.00424     |
|    std                  | 0.674       |
|    value_loss           | 0.381       |
-----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 802        |
|    iterations           | 24         |
|    time_elapsed         | 61        

----------------------------------------
| time/                   |            |
|    fps                  | 808        |
|    iterations           | 34         |
|    time_elapsed         | 86         |
|    total_timesteps      | 69632      |
| train/                  |            |
|    approx_kl            | 0.01502686 |
|    clip_fraction        | 0.23       |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.33      |
|    explained_variance   | -0.00252   |
|    learning_rate        | 0.0003     |
|    loss                 | 0.114      |
|    n_updates            | 330        |
|    policy_gradient_loss | -0.00534   |
|    std                  | 0.589      |
|    value_loss           | 0.273      |
----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 809         |
|    iterations           | 35          |
|    time_elapsed         | 88          |
|    total_

------------------------------------------
| time/                   |              |
|    fps                  | 813          |
|    iterations           | 45           |
|    time_elapsed         | 113          |
|    total_timesteps      | 92160        |
| train/                  |              |
|    approx_kl            | 0.0146023035 |
|    clip_fraction        | 0.27         |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.799       |
|    explained_variance   | -0.00409     |
|    learning_rate        | 0.0003       |
|    loss                 | 0.0251       |
|    n_updates            | 440          |
|    policy_gradient_loss | 0.0182       |
|    std                  | 0.53         |
|    value_loss           | 0.0843       |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 813         |
|    iterations           | 46          |
|    time_elaps

-----------------------------------------
| time/                   |             |
|    fps                  | 816         |
|    iterations           | 56          |
|    time_elapsed         | 140         |
|    total_timesteps      | 114688      |
| train/                  |             |
|    approx_kl            | 0.056721803 |
|    clip_fraction        | 0.334       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.426      |
|    explained_variance   | -0.00878    |
|    learning_rate        | 0.0003      |
|    loss                 | -0.00972    |
|    n_updates            | 550         |
|    policy_gradient_loss | 0.0371      |
|    std                  | 0.53        |
|    value_loss           | 0.0369      |
-----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 816        |
|    iterations           | 57         |
|    time_elapsed         | 142       

-----------------------------------------
| time/                   |             |
|    fps                  | 817         |
|    iterations           | 67          |
|    time_elapsed         | 167         |
|    total_timesteps      | 137216      |
| train/                  |             |
|    approx_kl            | 0.047867052 |
|    clip_fraction        | 0.375       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.223      |
|    explained_variance   | 0.000244    |
|    learning_rate        | 0.0003      |
|    loss                 | -0.00524    |
|    n_updates            | 660         |
|    policy_gradient_loss | 0.00689     |
|    std                  | 0.533       |
|    value_loss           | 0.0439      |
-----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 818        |
|    iterations           | 68         |
|    time_elapsed         | 170       

----------------------------------------
| time/                   |            |
|    fps                  | 819        |
|    iterations           | 78         |
|    time_elapsed         | 194        |
|    total_timesteps      | 159744     |
| train/                  |            |
|    approx_kl            | 0.08102324 |
|    clip_fraction        | 0.439      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.0196    |
|    explained_variance   | 0.00123    |
|    learning_rate        | 0.0003     |
|    loss                 | 0.0996     |
|    n_updates            | 770        |
|    policy_gradient_loss | 0.0147     |
|    std                  | 0.539      |
|    value_loss           | 0.0306     |
----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 819         |
|    iterations           | 79          |
|    time_elapsed         | 197         |
|    total_

----------------------------------------
| time/                   |            |
|    fps                  | 820        |
|    iterations           | 89         |
|    time_elapsed         | 222        |
|    total_timesteps      | 182272     |
| train/                  |            |
|    approx_kl            | 0.04076627 |
|    clip_fraction        | 0.371      |
|    clip_range           | 0.2        |
|    entropy_loss         | 0.216      |
|    explained_variance   | 0.00106    |
|    learning_rate        | 0.0003     |
|    loss                 | 0.0541     |
|    n_updates            | 880        |
|    policy_gradient_loss | 0.00657    |
|    std                  | 0.539      |
|    value_loss           | 0.0323     |
----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 819        |
|    iterations           | 90         |
|    time_elapsed         | 224        |
|    total_times

----------------------------------------
| time/                   |            |
|    fps                  | 819        |
|    iterations           | 100        |
|    time_elapsed         | 250        |
|    total_timesteps      | 204800     |
| train/                  |            |
|    approx_kl            | 0.13126053 |
|    clip_fraction        | 0.478      |
|    clip_range           | 0.2        |
|    entropy_loss         | 0.307      |
|    explained_variance   | 3.11e-05   |
|    learning_rate        | 0.0003     |
|    loss                 | 0.384      |
|    n_updates            | 990        |
|    policy_gradient_loss | 0.0624     |
|    std                  | 0.537      |
|    value_loss           | 0.019      |
----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 819         |
|    iterations           | 101         |
|    time_elapsed         | 252         |
|    total_

----------------------------------------
| time/                   |            |
|    fps                  | 823        |
|    iterations           | 111        |
|    time_elapsed         | 275        |
|    total_timesteps      | 227328     |
| train/                  |            |
|    approx_kl            | 0.13510087 |
|    clip_fraction        | 0.52       |
|    clip_range           | 0.2        |
|    entropy_loss         | 0.416      |
|    explained_variance   | -0.000431  |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0104    |
|    n_updates            | 1100       |
|    policy_gradient_loss | 0.0965     |
|    std                  | 0.535      |
|    value_loss           | 0.0158     |
----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 824        |
|    iterations           | 112        |
|    time_elapsed         | 278        |
|    total_times

----------------------------------------
| time/                   |            |
|    fps                  | 829        |
|    iterations           | 122        |
|    time_elapsed         | 301        |
|    total_timesteps      | 249856     |
| train/                  |            |
|    approx_kl            | 0.17246573 |
|    clip_fraction        | 0.498      |
|    clip_range           | 0.2        |
|    entropy_loss         | 0.514      |
|    explained_variance   | -0.000654  |
|    learning_rate        | 0.0003     |
|    loss                 | 0.0326     |
|    n_updates            | 1210       |
|    policy_gradient_loss | 0.0794     |
|    std                  | 0.54       |
|    value_loss           | 0.012      |
----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 830        |
|    iterations           | 123        |
|    time_elapsed         | 303        |
|    total_times

---------------------------------------
| time/                   |           |
|    fps                  | 834       |
|    iterations           | 133       |
|    time_elapsed         | 326       |
|    total_timesteps      | 272384    |
| train/                  |           |
|    approx_kl            | 0.1552765 |
|    clip_fraction        | 0.484     |
|    clip_range           | 0.2       |
|    entropy_loss         | 0.734     |
|    explained_variance   | 0.000544  |
|    learning_rate        | 0.0003    |
|    loss                 | 0.007     |
|    n_updates            | 1320      |
|    policy_gradient_loss | 0.0607    |
|    std                  | 0.528     |
|    value_loss           | 0.00776   |
---------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 835        |
|    iterations           | 134        |
|    time_elapsed         | 328        |
|    total_timesteps      | 274432 

----------------------------------------
| time/                   |            |
|    fps                  | 838        |
|    iterations           | 144        |
|    time_elapsed         | 351        |
|    total_timesteps      | 294912     |
| train/                  |            |
|    approx_kl            | 0.16441903 |
|    clip_fraction        | 0.627      |
|    clip_range           | 0.2        |
|    entropy_loss         | 0.767      |
|    explained_variance   | -0.000319  |
|    learning_rate        | 0.0003     |
|    loss                 | 0.0994     |
|    n_updates            | 1430       |
|    policy_gradient_loss | 0.103      |
|    std                  | 0.526      |
|    value_loss           | 0.00883    |
----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 839        |
|    iterations           | 145        |
|    time_elapsed         | 353        |
|    total_times

----------------------------------------
| time/                   |            |
|    fps                  | 842        |
|    iterations           | 155        |
|    time_elapsed         | 376        |
|    total_timesteps      | 317440     |
| train/                  |            |
|    approx_kl            | 0.12044103 |
|    clip_fraction        | 0.59       |
|    clip_range           | 0.2        |
|    entropy_loss         | 0.837      |
|    explained_variance   | 0.000546   |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0056    |
|    n_updates            | 1540       |
|    policy_gradient_loss | 0.0701     |
|    std                  | 0.516      |
|    value_loss           | 0.0114     |
----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 842        |
|    iterations           | 156        |
|    time_elapsed         | 379        |
|    total_times

----------------------------------------
| time/                   |            |
|    fps                  | 845        |
|    iterations           | 166        |
|    time_elapsed         | 402        |
|    total_timesteps      | 339968     |
| train/                  |            |
|    approx_kl            | 0.21594197 |
|    clip_fraction        | 0.554      |
|    clip_range           | 0.2        |
|    entropy_loss         | 0.943      |
|    explained_variance   | 4.49e-05   |
|    learning_rate        | 0.0003     |
|    loss                 | 0.0668     |
|    n_updates            | 1650       |
|    policy_gradient_loss | 0.135      |
|    std                  | 0.507      |
|    value_loss           | 0.00473    |
----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 845        |
|    iterations           | 167        |
|    time_elapsed         | 404        |
|    total_times

--------------------------------------
| time/                   |          |
|    fps                  | 847      |
|    iterations           | 177      |
|    time_elapsed         | 427      |
|    total_timesteps      | 362496   |
| train/                  |          |
|    approx_kl            | 38.88595 |
|    clip_fraction        | 0.759    |
|    clip_range           | 0.2      |
|    entropy_loss         | 1.06     |
|    explained_variance   | 0.00276  |
|    learning_rate        | 0.0003   |
|    loss                 | 0.295    |
|    n_updates            | 1760     |
|    policy_gradient_loss | 0.163    |
|    std                  | 0.502    |
|    value_loss           | 0.258    |
--------------------------------------
---------------------------------------
| time/                   |           |
|    fps                  | 848       |
|    iterations           | 178       |
|    time_elapsed         | 429       |
|    total_timesteps      | 364544    |
| train/           

----------------------------------------
| time/                   |            |
|    fps                  | 850        |
|    iterations           | 188        |
|    time_elapsed         | 452        |
|    total_timesteps      | 385024     |
| train/                  |            |
|    approx_kl            | 0.10969092 |
|    clip_fraction        | 0.431      |
|    clip_range           | 0.2        |
|    entropy_loss         | 0.966      |
|    explained_variance   | 9e-06      |
|    learning_rate        | 0.0003     |
|    loss                 | 1.08       |
|    n_updates            | 1870       |
|    policy_gradient_loss | 0.0244     |
|    std                  | 0.508      |
|    value_loss           | 10.3       |
----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 850        |
|    iterations           | 189        |
|    time_elapsed         | 455        |
|    total_times

----------------------------------------
| time/                   |            |
|    fps                  | 852        |
|    iterations           | 199        |
|    time_elapsed         | 478        |
|    total_timesteps      | 407552     |
| train/                  |            |
|    approx_kl            | 0.07308593 |
|    clip_fraction        | 0.412      |
|    clip_range           | 0.2        |
|    entropy_loss         | 0.878      |
|    explained_variance   | 1.29e-05   |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0212    |
|    n_updates            | 1980       |
|    policy_gradient_loss | 0.0128     |
|    std                  | 0.524      |
|    value_loss           | 0.243      |
----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 852        |
|    iterations           | 200        |
|    time_elapsed         | 480        |
|    total_times

---------------------------------------
| time/                   |           |
|    fps                  | 854       |
|    iterations           | 210       |
|    time_elapsed         | 503       |
|    total_timesteps      | 430080    |
| train/                  |           |
|    approx_kl            | 0.0759561 |
|    clip_fraction        | 0.361     |
|    clip_range           | 0.2       |
|    entropy_loss         | 0.834     |
|    explained_variance   | 0.0136    |
|    learning_rate        | 0.0003    |
|    loss                 | 0.00987   |
|    n_updates            | 2090      |
|    policy_gradient_loss | 0.000393  |
|    std                  | 0.528     |
|    value_loss           | 0.123     |
---------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 854        |
|    iterations           | 211        |
|    time_elapsed         | 505        |
|    total_timesteps      | 432128 

-----------------------------------------
| time/                   |             |
|    fps                  | 856         |
|    iterations           | 221         |
|    time_elapsed         | 528         |
|    total_timesteps      | 452608      |
| train/                  |             |
|    approx_kl            | 0.061620496 |
|    clip_fraction        | 0.367       |
|    clip_range           | 0.2         |
|    entropy_loss         | 0.817       |
|    explained_variance   | 0.0701      |
|    learning_rate        | 0.0003      |
|    loss                 | 0.112       |
|    n_updates            | 2200        |
|    policy_gradient_loss | -0.00735    |
|    std                  | 0.536       |
|    value_loss           | 0.183       |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 856         |
|    iterations           | 222         |
|    time_elapsed         | 531   

-----------------------------------------
| time/                   |             |
|    fps                  | 857         |
|    iterations           | 232         |
|    time_elapsed         | 554         |
|    total_timesteps      | 475136      |
| train/                  |             |
|    approx_kl            | 0.042024303 |
|    clip_fraction        | 0.363       |
|    clip_range           | 0.2         |
|    entropy_loss         | 0.713       |
|    explained_variance   | 0.0626      |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0503      |
|    n_updates            | 2310        |
|    policy_gradient_loss | 0.00465     |
|    std                  | 0.581       |
|    value_loss           | 0.146       |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 857         |
|    iterations           | 233         |
|    time_elapsed         | 556   

--------------------------------------
| time/                   |          |
|    fps                  | 858      |
|    iterations           | 243      |
|    time_elapsed         | 579      |
|    total_timesteps      | 497664   |
| train/                  |          |
|    approx_kl            | 0.034869 |
|    clip_fraction        | 0.346    |
|    clip_range           | 0.2      |
|    entropy_loss         | 0.821    |
|    explained_variance   | 0.0582   |
|    learning_rate        | 0.0003   |
|    loss                 | 0.0112   |
|    n_updates            | 2420     |
|    policy_gradient_loss | -0.0153  |
|    std                  | 0.546    |
|    value_loss           | 0.11     |
--------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 859        |
|    iterations           | 244        |
|    time_elapsed         | 581        |
|    total_timesteps      | 499712     |
| train/     

In [18]:
obs = env.reset()

for _ in range(100):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    print('=====')
    print(obs)
    print(rewards)

=====
[1. 0.]
-9.0
=====
[2. 0.]
-8.0
=====
[3. 0.]
-7.0
=====
[4. 0.]
-6.0
=====
[5. 0.]
-5.0
=====
[6. 0.]
-4.0
=====
[7. 0.]
-3.0
=====
[7.78907472 0.        ]
-2.210925281047821
=====
[8.37890512 0.        ]
-1.6210948824882507
=====
[8.78755832 0.        ]
-1.2124416828155518
=====
[9.10008737 0.        ]
-0.8999126255512238
=====
[9.32068688 0.        ]
-0.6793131232261658
=====
[9.55418406 0.        ]
-0.44581593573093414
=====
[9.68067454 0.        ]
-0.3193254619836807
=====
[9.75936401 0.        ]
-0.24063599109649658
=====
[9.87143537 0.        ]
-0.12856463342905045
=====
[9.94522381 0.        ]
-0.05477619171142578
=====
[10.02678749  0.        ]
-0.026787489652633667
=====
[10.11129017  0.        ]
-0.11129017174243927
=====
[10.12255942  0.        ]
-0.12255942448973656
=====
[10.1243435  0.       ]
-0.12434350326657295
=====
[10.12768865  0.        ]
-0.12768864631652832
=====
[10.17239237  0.        ]
-0.1723923720419407
=====
[10.1929385  0.       ]
-0.192938502877950

In [None]:
# env = simEnv()

# device = torch.device(
#         "cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# model = SAC('MlpPolicy', env, verbose=1, device=device)
# model = model.learn(total_timesteps=500000, eval_freq=1000)

Using cuda:0 device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [None]:
# obs = env.reset()

# for _ in range(100):
#     action, _states = model.predict(obs)
#     obs, rewards, dones, info = env.step(action)
#     print(obs)
#     print(rewards)