In [1]:
!pip install gym
!pip install stable_baselines3



In [2]:
from gym import Env
from gym.spaces import Discrete, Box
import random
import numpy as np
import torch
from stable_baselines3 import SAC, PPO
import gym

In [36]:
### Define a two line manifold environment ###
class simEnv(Env):
    def __init__(self):
        # Action agent can take
        self.action_space = Discrete(4)
        # Work space
        self.observation_space = Box(np.array([-10, 10]), np.array([-10, 10]))
        # Initial starting point
        self.state = np.array([0, 0])
        # Goal state
        self.goal = np.array([10, 10])
        # Intersection point
        self.intersection = np.array([10, 0])
        # Indicator whether next manifold is reached or not
        self.next_manifold = False
    
    def step(self, action):
        # action == 0 => moving up
        if action == 0:
            self.state[1] += 1
        # action == 1 => moving down 
        elif action == 1:
            self.state[1] -= 1
        # action == 2 => moving right
        elif action == 2:
            self.state[0] += 1
        # action == 3 => moving left
        elif action == 3:
            self.state[0] -= 1
        # If action is not valid, error out
        else:
            return
        if self.state[0] == 10:
            self.next_manifold = True
        done = False
        # Design rewards
        if self.state[1] != 0 and self.state[0] != 10: # Fall off the manifold
            reward = -100
        elif np.array_equal(self.state, self.intersection): # Get to the intersection point
            if self.next_manifold:
                reward = 1 / np.linalg.norm(self.state - self.goal)
            else: 
                reward = 10
        elif np.array_equal(self.state, self.goal): # Reach the goal 
            reward = 100
            done = True
        else: # Direct the agent to intersection point / goal
            if self.next_manifold:
                reward = 1 / np.linalg.norm(self.state - self.goal)
            else: 
                reward = 1 / np.linalg.norm(self.state - self.intersection)
        # Set placeholder for info 
        info = {}
        return self.state, reward, done, info
    
    def render(self):
        pass
    
    def reset(self):
        self.state = np.array([0,0])
        return self.state

In [37]:
env = simEnv()

In [38]:
type(env.action_space.sample())

int

In [39]:
# Test the environment
episodes = 10
for e in range(episodes):
    state = env.reset()
    done = False
    score = 0
    length = 0
    
    while not done and length < 10:
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score += reward
        length += 1
    print("Episode:{} Score:{}".format(e, score))

Episode:0 Score:-899.8888888888889
Episode:1 Score:-899.9
Episode:2 Score:-899.9090909090909
Episode:3 Score:-1000
Episode:4 Score:-799.8
Episode:5 Score:-1000
Episode:6 Score:-399.5031635031635
Episode:7 Score:-1000
Episode:8 Score:-899.9090909090909
Episode:9 Score:-1000


In [40]:
device = torch.device(
        "cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = PPO('MlpPolicy', env, verbose=1, device=device)
model = model.learn(total_timesteps=500000, eval_freq=1000)

Using cuda:0 device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
-----------------------------
| time/              |      |
|    fps             | 1577 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 1179         |
|    iterations           | 2            |
|    time_elapsed         | 3            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0055554053 |
|    clip_fraction        | 0.00146      |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.38        |
|    explained_variance   | -0.000883    |
|    learning_rate        | 0.0003       |
|    loss                 | 1.3e+06      |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.00302     |
|    

-----------------------------------------
| time/                   |             |
|    fps                  | 972         |
|    iterations           | 13          |
|    time_elapsed         | 27          |
|    total_timesteps      | 26624       |
| train/                  |             |
|    approx_kl            | 0.008081218 |
|    clip_fraction        | 0.057       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.37       |
|    explained_variance   | 5.96e-08    |
|    learning_rate        | 0.0003      |
|    loss                 | 1.33e+06    |
|    n_updates            | 120         |
|    policy_gradient_loss | -0.0068     |
|    value_loss           | 2.59e+06    |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 969         |
|    iterations           | 14          |
|    time_elapsed         | 29          |
|    total_timesteps      | 28672 

------------------------------------------
| time/                   |              |
|    fps                  | 952          |
|    iterations           | 24           |
|    time_elapsed         | 51           |
|    total_timesteps      | 49152        |
| train/                  |              |
|    approx_kl            | 0.0017207661 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.37        |
|    explained_variance   | 2.98e-07     |
|    learning_rate        | 0.0003       |
|    loss                 | 1.36e+06     |
|    n_updates            | 230          |
|    policy_gradient_loss | -0.000681    |
|    value_loss           | 2.7e+06      |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 951          |
|    iterations           | 25           |
|    time_elapsed         | 53           |
|    total_

------------------------------------------
| time/                   |              |
|    fps                  | 909          |
|    iterations           | 35           |
|    time_elapsed         | 78           |
|    total_timesteps      | 71680        |
| train/                  |              |
|    approx_kl            | 0.0021956656 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.31        |
|    explained_variance   | -2.38e-07    |
|    learning_rate        | 0.0003       |
|    loss                 | 1.31e+06     |
|    n_updates            | 340          |
|    policy_gradient_loss | -0.000491    |
|    value_loss           | 2.66e+06     |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 909          |
|    iterations           | 36           |
|    time_elapsed         | 81           |
|    total_

------------------------------------------
| time/                   |              |
|    fps                  | 875          |
|    iterations           | 46           |
|    time_elapsed         | 107          |
|    total_timesteps      | 94208        |
| train/                  |              |
|    approx_kl            | 0.0052040117 |
|    clip_fraction        | 0.0126       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.25        |
|    explained_variance   | -1.19e-07    |
|    learning_rate        | 0.0003       |
|    loss                 | 1.31e+06     |
|    n_updates            | 450          |
|    policy_gradient_loss | -0.00225     |
|    value_loss           | 2.63e+06     |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 872         |
|    iterations           | 47          |
|    time_elapsed         | 110         |
|    total_times

------------------------------------------
| time/                   |              |
|    fps                  | 841          |
|    iterations           | 57           |
|    time_elapsed         | 138          |
|    total_timesteps      | 116736       |
| train/                  |              |
|    approx_kl            | 0.0034705722 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.14        |
|    explained_variance   | -1.19e-07    |
|    learning_rate        | 0.0003       |
|    loss                 | 1.28e+06     |
|    n_updates            | 560          |
|    policy_gradient_loss | -0.0018      |
|    value_loss           | 2.59e+06     |
------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 842           |
|    iterations           | 58            |
|    time_elapsed         | 141           |
|    t

------------------------------------------
| time/                   |              |
|    fps                  | 827          |
|    iterations           | 68           |
|    time_elapsed         | 168          |
|    total_timesteps      | 139264       |
| train/                  |              |
|    approx_kl            | 0.0007959268 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.06        |
|    explained_variance   | 1.19e-07     |
|    learning_rate        | 0.0003       |
|    loss                 | 1.27e+06     |
|    n_updates            | 670          |
|    policy_gradient_loss | 0.00398      |
|    value_loss           | 2.55e+06     |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 827         |
|    iterations           | 69          |
|    time_elapsed         | 170         |
|    total_times

-----------------------------------------
| time/                   |             |
|    fps                  | 836         |
|    iterations           | 79          |
|    time_elapsed         | 193         |
|    total_timesteps      | 161792      |
| train/                  |             |
|    approx_kl            | 0.006797335 |
|    clip_fraction        | 0           |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.2        |
|    explained_variance   | 0           |
|    learning_rate        | 0.0003      |
|    loss                 | 1.25e+06    |
|    n_updates            | 780         |
|    policy_gradient_loss | -0.00282    |
|    value_loss           | 2.52e+06    |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 837         |
|    iterations           | 80          |
|    time_elapsed         | 195         |
|    total_timesteps      | 163840

-------------------------------------------
| time/                   |               |
|    fps                  | 838           |
|    iterations           | 90            |
|    time_elapsed         | 219           |
|    total_timesteps      | 184320        |
| train/                  |               |
|    approx_kl            | 0.00040937398 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.18         |
|    explained_variance   | 1.19e-07      |
|    learning_rate        | 0.0003        |
|    loss                 | 1.23e+06      |
|    n_updates            | 890           |
|    policy_gradient_loss | -1.84e-05     |
|    value_loss           | 2.48e+06      |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 839           |
|    iterations           | 91            |
|    time_elapsed         | 222 

------------------------------------------
| time/                   |              |
|    fps                  | 841          |
|    iterations           | 101          |
|    time_elapsed         | 245          |
|    total_timesteps      | 206848       |
| train/                  |              |
|    approx_kl            | 0.0016040107 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.11        |
|    explained_variance   | -1.19e-07    |
|    learning_rate        | 0.0003       |
|    loss                 | 1.23e+06     |
|    n_updates            | 1000         |
|    policy_gradient_loss | -0.000544    |
|    value_loss           | 2.44e+06     |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 839          |
|    iterations           | 102          |
|    time_elapsed         | 248          |
|    total_

------------------------------------------
| time/                   |              |
|    fps                  | 841          |
|    iterations           | 112          |
|    time_elapsed         | 272          |
|    total_timesteps      | 229376       |
| train/                  |              |
|    approx_kl            | 0.0046574865 |
|    clip_fraction        | 0.0354       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1           |
|    explained_variance   | 0            |
|    learning_rate        | 0.0003       |
|    loss                 | 1.21e+06     |
|    n_updates            | 1110         |
|    policy_gradient_loss | -0.00289     |
|    value_loss           | 2.41e+06     |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 841          |
|    iterations           | 113          |
|    time_elapsed         | 274          |
|    total_

-------------------------------------------
| time/                   |               |
|    fps                  | 845           |
|    iterations           | 123           |
|    time_elapsed         | 297           |
|    total_timesteps      | 251904        |
| train/                  |               |
|    approx_kl            | 0.00064792274 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.05         |
|    explained_variance   | 2.38e-07      |
|    learning_rate        | 0.0003        |
|    loss                 | 1.19e+06      |
|    n_updates            | 1220          |
|    policy_gradient_loss | -0.000604     |
|    value_loss           | 2.37e+06      |
-------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 844          |
|    iterations           | 124          |
|    time_elapsed         | 300     

------------------------------------------
| time/                   |              |
|    fps                  | 848          |
|    iterations           | 134          |
|    time_elapsed         | 323          |
|    total_timesteps      | 274432       |
| train/                  |              |
|    approx_kl            | 0.0042110044 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.1         |
|    explained_variance   | 2.38e-07     |
|    learning_rate        | 0.0003       |
|    loss                 | 1.14e+06     |
|    n_updates            | 1330         |
|    policy_gradient_loss | -0.00173     |
|    value_loss           | 2.34e+06     |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 849          |
|    iterations           | 135          |
|    time_elapsed         | 325          |
|    total_

-------------------------------------------
| time/                   |               |
|    fps                  | 851           |
|    iterations           | 145           |
|    time_elapsed         | 348           |
|    total_timesteps      | 296960        |
| train/                  |               |
|    approx_kl            | 0.00080123683 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.14         |
|    explained_variance   | -1.19e-07     |
|    learning_rate        | 0.0003        |
|    loss                 | 1.14e+06      |
|    n_updates            | 1440          |
|    policy_gradient_loss | -0.000508     |
|    value_loss           | 2.3e+06       |
-------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 850          |
|    iterations           | 146          |
|    time_elapsed         | 351     

-----------------------------------------
| time/                   |             |
|    fps                  | 854         |
|    iterations           | 156         |
|    time_elapsed         | 374         |
|    total_timesteps      | 319488      |
| train/                  |             |
|    approx_kl            | 0.004656518 |
|    clip_fraction        | 0           |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.21       |
|    explained_variance   | 0           |
|    learning_rate        | 0.0003      |
|    loss                 | 1.14e+06    |
|    n_updates            | 1550        |
|    policy_gradient_loss | -0.00203    |
|    value_loss           | 2.27e+06    |
-----------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 854          |
|    iterations           | 157          |
|    time_elapsed         | 376          |
|    total_timesteps      | 3

------------------------------------------
| time/                   |              |
|    fps                  | 856          |
|    iterations           | 167          |
|    time_elapsed         | 399          |
|    total_timesteps      | 342016       |
| train/                  |              |
|    approx_kl            | 0.0011011169 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.03        |
|    explained_variance   | -1.19e-07    |
|    learning_rate        | 0.0003       |
|    loss                 | 1.11e+06     |
|    n_updates            | 1660         |
|    policy_gradient_loss | -0.000275    |
|    value_loss           | 2.23e+06     |
------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 855           |
|    iterations           | 168           |
|    time_elapsed         | 401           |
|    t

------------------------------------------
| time/                   |              |
|    fps                  | 857          |
|    iterations           | 178          |
|    time_elapsed         | 425          |
|    total_timesteps      | 364544       |
| train/                  |              |
|    approx_kl            | 0.0005997446 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.985       |
|    explained_variance   | 0            |
|    learning_rate        | 0.0003       |
|    loss                 | 1.1e+06      |
|    n_updates            | 1770         |
|    policy_gradient_loss | -0.00026     |
|    value_loss           | 2.2e+06      |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 858          |
|    iterations           | 179          |
|    time_elapsed         | 427          |
|    total_

------------------------------------------
| time/                   |              |
|    fps                  | 862          |
|    iterations           | 189          |
|    time_elapsed         | 448          |
|    total_timesteps      | 387072       |
| train/                  |              |
|    approx_kl            | 0.0023265795 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.09        |
|    explained_variance   | 0            |
|    learning_rate        | 0.0003       |
|    loss                 | 1.05e+06     |
|    n_updates            | 1880         |
|    policy_gradient_loss | -0.00062     |
|    value_loss           | 2.17e+06     |
------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 862           |
|    iterations           | 190           |
|    time_elapsed         | 451           |
|    t

-------------------------------------------
| time/                   |               |
|    fps                  | 865           |
|    iterations           | 200           |
|    time_elapsed         | 473           |
|    total_timesteps      | 409600        |
| train/                  |               |
|    approx_kl            | 0.00065068854 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.06         |
|    explained_variance   | 1.19e-07      |
|    learning_rate        | 0.0003        |
|    loss                 | 1.04e+06      |
|    n_updates            | 1990          |
|    policy_gradient_loss | 4.62e-05      |
|    value_loss           | 2.13e+06      |
-------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 865          |
|    iterations           | 201          |
|    time_elapsed         | 475     

------------------------------------------
| time/                   |              |
|    fps                  | 861          |
|    iterations           | 211          |
|    time_elapsed         | 501          |
|    total_timesteps      | 432128       |
| train/                  |              |
|    approx_kl            | 0.0015683668 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.873       |
|    explained_variance   | 1.19e-07     |
|    learning_rate        | 0.0003       |
|    loss                 | 1.06e+06     |
|    n_updates            | 2100         |
|    policy_gradient_loss | -0.00116     |
|    value_loss           | 2.1e+06      |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 860          |
|    iterations           | 212          |
|    time_elapsed         | 504          |
|    total_

-------------------------------------------
| time/                   |               |
|    fps                  | 859           |
|    iterations           | 222           |
|    time_elapsed         | 529           |
|    total_timesteps      | 454656        |
| train/                  |               |
|    approx_kl            | 0.00057588343 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -0.778        |
|    explained_variance   | -1.19e-07     |
|    learning_rate        | 0.0003        |
|    loss                 | 1.03e+06      |
|    n_updates            | 2210          |
|    policy_gradient_loss | -0.000279     |
|    value_loss           | 2.07e+06      |
-------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 859         |
|    iterations           | 223         |
|    time_elapsed         | 531         

------------------------------------------
| time/                   |              |
|    fps                  | 860          |
|    iterations           | 233          |
|    time_elapsed         | 554          |
|    total_timesteps      | 477184       |
| train/                  |              |
|    approx_kl            | 0.0001744245 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.76        |
|    explained_variance   | 0            |
|    learning_rate        | 0.0003       |
|    loss                 | 1e+06        |
|    n_updates            | 2320         |
|    policy_gradient_loss | -0.000483    |
|    value_loss           | 2.03e+06     |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 860         |
|    iterations           | 234         |
|    time_elapsed         | 556         |
|    total_times

------------------------------------------
| time/                   |              |
|    fps                  | 861          |
|    iterations           | 244          |
|    time_elapsed         | 579          |
|    total_timesteps      | 499712       |
| train/                  |              |
|    approx_kl            | 0.0003172434 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.671       |
|    explained_variance   | 0            |
|    learning_rate        | 0.0003       |
|    loss                 | 9.94e+05     |
|    n_updates            | 2430         |
|    policy_gradient_loss | 0.00277      |
|    value_loss           | 2e+06        |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 861         |
|    iterations           | 245         |
|    time_elapsed         | 582         |
|    total_times

In [41]:
obs = env.reset()
for _ in range(100):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    print(obs)
    print(rewards)

[0 1]
-100
[1 1]
-100
[2 1]
-100
[2 2]
-100
[1 2]
-100
[1 1]
-100
[0 1]
-100
[1 1]
-100
[1 2]
-100
[1 1]
-100
[1 2]
-100
[0 2]
-100
[1 2]
-100
[0 2]
-100
[0 3]
-100
[-1  3]
-100
[0 3]
-100
[0 4]
-100
[1 4]
-100
[0 4]
-100
[1 4]
-100
[0 4]
-100
[-1  4]
-100
[-2  4]
-100
[-1  4]
-100
[-2  4]
-100
[-3  4]
-100
[-4  4]
-100
[-5  4]
-100
[-6  4]
-100
[-7  4]
-100
[-8  4]
-100
[-9  4]
-100
[-10   4]
-100
[-11   4]
-100
[-12   4]
-100
[-13   4]
-100
[-14   4]
-100
[-15   4]
-100
[-16   4]
-100
[-17   4]
-100
[-18   4]
-100
[-19   4]
-100
[-20   4]
-100
[-21   4]
-100
[-22   4]
-100
[-23   4]
-100
[-24   4]
-100
[-25   4]
-100
[-26   4]
-100
[-27   4]
-100
[-28   4]
-100
[-28   5]
-100
[-27   5]
-100
[-28   5]
-100
[-29   5]
-100
[-30   5]
-100
[-31   5]
-100
[-32   5]
-100
[-32   6]
-100
[-33   6]
-100
[-34   6]
-100
[-35   6]
-100
[-36   6]
-100
[-37   6]
-100
[-38   6]
-100
[-39   6]
-100
[-40   6]
-100
[-41   6]
-100
[-42   6]
-100
[-41   6]
-100
[-42   6]
-100
[-43   6]
-100
[-44   6]
-10