### 1. Import Dependencies

In [1]:
# gym dependencies
import gym 
from gym import Env
from gym.spaces import Discrete, Box, Dict, Tuple, MultiBinary, MultiDiscrete 

# helpers
import numpy as np
import random
import os

# stable baselines dependencies
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

### 2. Types of Spaces 

In [3]:
Discrete(3).sample() # Discrete(3) has 3 possible values (0, 1, 2), one per action; sample() picks one at random

1

In [4]:
# Box: a space for continuous variables

# 0 is the lower bound of every element
# 1 is the upper bound of every element
# (3,3) is the shape

Box(0,1, shape=(3,3))

# so every sample from this space is a 3x3 array (3 rows of 3 values)

Box([[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]], [[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]], (3, 3), float32)

In [5]:
Box(0,1, shape=(3,3)).sample()
# The output is a single 3x3 float32 array: 3 rows with 3 values each, every value between 0 and 1

array([[0.04529582, 0.13172327, 0.33639795],
       [0.66682327, 0.3226328 , 0.4317734 ],
       [0.6358906 , 0.4071296 , 0.9997765 ]], dtype=float32)

In [6]:
Box(0,1, shape=(3,)).sample()
# The output is a one-dimensional float32 array with 3 values

array([0.72536945, 0.35852256, 0.19191903], dtype=float32)

In [7]:
# a Tuple space combines several (possibly different) spaces into one
Tuple((Discrete(3), Box(0, 1, shape=(3,))))

Tuple(Discrete(3), Box([0. 0. 0.], [1. 1. 1.], (3,), float32))

In [8]:
# sampling returns a tuple with one random Discrete value and one random Box array
Tuple((Discrete(3), Box(0, 1, shape=(3,)))).sample()

(0, array([0.05210012, 0.7755449 , 0.537649  ], dtype=float32))

In [9]:
# a Dict space is built from a dictionary of spaces; this one has 2 keys: 'height' and 'speed'
Dict({'height':Discrete(2), 'speed':Box(0,100,shape=(1,))})

Dict(height:Discrete(2), speed:Box([0.], [100.], (1,), float32))

In [10]:
# sampling returns one random value per key ('height' and 'speed')
Dict({'height':Discrete(2), 'speed':Box(0,100,shape=(1,))}).sample()

OrderedDict([('height', 0), ('speed', array([40.82437], dtype=float32))])

In [11]:
# MultiBinary(4): each sample is an array of 4 values, each either 0 or 1
MultiBinary(4).sample()

array([1, 0, 1, 1], dtype=int8)

In [12]:
# returns one discrete value per position; each value is smaller than the
# corresponding provided size (here: < 5, < 2 and < 2)
MultiDiscrete([5, 2, 2]).sample()

array([0, 1, 0], dtype=int64)

### 3. Building an Environment

- Build an agent to give us the best shower possible
- The temperature starts at a random value (38 ± 3 degrees)
- The ideal temperature is between 37 and 39 degrees (inclusive)

In [None]:
# Action mapping:
# 0: decrease temp by 1
# 1: no change
# 2: increase temp by 1

In [48]:
class ShowerEnv(Env):
    """Toy Gym environment: keep a shower's temperature in a comfortable band.

    Actions (``Discrete(3)``):
        0 -> lower the temperature by 1 degree
        1 -> leave the temperature unchanged
        2 -> raise the temperature by 1 degree

    Observation (``Box`` of shape ``(1,)``, float32):
        the current temperature, kept inside [MIN_TEMP, MAX_TEMP].

    Reward:
        +1 for every step that ends with the temperature inside
        [COMFORT_LOW, COMFORT_HIGH] (inclusive), -1 otherwise.

    Episode length:
        EPISODE_LENGTH steps; ``shower_length`` counts the remaining ones.
    """

    # Bounds of the observation space (degrees).
    MIN_TEMP = 0.0
    MAX_TEMP = 100.0
    # Inclusive band that earns a positive reward.
    COMFORT_LOW = 37
    COMFORT_HIGH = 39
    # Start temperature is START_TEMP +/- START_SPREAD (integer degrees).
    START_TEMP = 38
    START_SPREAD = 3
    # Number of steps in one episode.
    EPISODE_LENGTH = 60

    def __init__(self):
        # Actions we can take: down, stay, up
        self.action_space = Discrete(3)
        # One-element temperature array; float32 bounds and dtype are given
        # explicitly so the space does not have to cast integer bounds.
        self.observation_space = Box(
            low=np.array([self.MIN_TEMP], dtype=np.float32),
            high=np.array([self.MAX_TEMP], dtype=np.float32),
            dtype=np.float32,
        )
        # Start temperature, stored in the same shape/dtype that reset() uses
        # so the state is a valid observation even before the first reset().
        self.state = self._random_start_temperature()
        # Remaining shower length (steps)
        self.shower_length = self.EPISODE_LENGTH

    def _random_start_temperature(self):
        """Return a fresh float32 array of shape (1,) holding a random start temperature."""
        offset = random.randint(-self.START_SPREAD, self.START_SPREAD)
        return np.array([self.START_TEMP + offset], dtype=np.float32)

    def step(self, action):
        """Apply ``action`` for one time step.

        Args:
            action: 0, 1 or 2 (see the class docstring for the mapping).

        Returns:
            A 4-tuple ``(observation, reward, done, info)`` where
            ``observation`` is the new temperature array, ``reward`` is +1 or
            -1, ``done`` is True once the shower length is used up, and
            ``info`` is an empty dict.
        """
        # action - 1 maps {0, 1, 2} onto a temperature change of {-1, 0, +1}.
        # A new array is built on every step (instead of mutating in place) so
        # observations handed out earlier are never changed behind the
        # caller's back; clipping keeps the observation inside the declared
        # observation_space bounds.
        self.state = np.clip(
            self.state + (action - 1), self.MIN_TEMP, self.MAX_TEMP
        ).astype(np.float32)
        # Reduce shower length by 1 step
        self.shower_length -= 1

        # Calculate reward: inside the comfortable band or not
        temperature = float(self.state[0])
        if self.COMFORT_LOW <= temperature <= self.COMFORT_HIGH:
            reward = 1
        else:
            reward = -1

        # Check if shower is done
        done = self.shower_length <= 0

        # NOTE: per-step temperature noise (random.randint(-1, 1)) is
        # currently disabled.

        # Set placeholder for info
        info = {}

        # Return step information
        return self.state, reward, done, info

    def render(self, mode="human"):
        """No visualisation is implemented; ``mode`` is accepted and ignored."""
        pass

    def reset(self):
        """Start a new episode and return the initial observation."""
        # Reset shower temperature
        self.state = self._random_start_temperature()
        # Reset shower time
        self.shower_length = self.EPISODE_LENGTH
        return self.state

In [49]:
# create one instance of the custom shower environment
env=ShowerEnv()

In [77]:
# draw a random observation from the declared temperature range (a 1-element float32 array)
env.observation_space.sample()

array([28.749876], dtype=float32)

In [73]:
# start a new episode; returns the initial temperature as a 1-element float32 array
env.reset()

array([41.], dtype=float32)

In [57]:
from stable_baselines3.common.env_checker import check_env

In [58]:
# validate the custom environment with Stable Baselines3's environment checker;
# warn=True asks it to also emit warnings
check_env(env, warn=True)

### 4. Test Environment

In [59]:
# Baseline before training: play a few episodes with random actions and
# print the total reward collected in each one.
episodes = 5
for episode in range(1, episodes+1):
    # fresh episode: initial observation, running score, termination flag
    state = env.reset()
    score = 0
    done = False

    # the env is stepped at least once per episode, then until it reports done
    while True:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score += reward
        if done:
            break
    print('Episode:{} Score:{}'.format(episode, score))
env.close()

Episode:1 Score:-60
Episode:2 Score:-20
Episode:3 Score:-60
Episode:4 Score:-54
Episode:5 Score:-14


In [60]:
# NOTE(review): env.close() is also called at the end of the rollout cell above,
# so this call looks redundant; confirm before removing
env.close()

### 5. Train Model

In [62]:
# relative directory Training/Logs, used below as the TensorBoard log location
log_path = os.path.join('Training', 'Logs')

In [63]:
# PPO agent with the "MlpPolicy" policy acting on env; verbose=1 prints
# training information and TensorBoard logs are written under log_path
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log=log_path)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [64]:
# train the agent for 400,000 environment steps (long-running cell)
model.learn(total_timesteps=400000)

Logging to Training\Logs\PPO_9
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 60       |
|    ep_rew_mean     | -30.5    |
| time/              |          |
|    fps             | 1164     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | -30.1       |
| time/                   |             |
|    fps                  | 645         |
|    iterations           | 2           |
|    time_elapsed         | 6           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.009792256 |
|    clip_fraction        | 0.00747     |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.09       |
|    explained_variance   | 0.000138    |

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | -30.1       |
| time/                   |             |
|    fps                  | 501         |
|    iterations           | 11          |
|    time_elapsed         | 44          |
|    total_timesteps      | 22528       |
| train/                  |             |
|    approx_kl            | 0.005578422 |
|    clip_fraction        | 0.025       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.09       |
|    explained_variance   | -0.00921    |
|    learning_rate        | 0.0003      |
|    loss                 | 40.6        |
|    n_updates            | 100         |
|    policy_gradient_loss | -0.00305    |
|    value_loss           | 85.8        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | -22.9       |
| time/                   |             |
|    fps                  | 506         |
|    iterations           | 21          |
|    time_elapsed         | 84          |
|    total_timesteps      | 43008       |
| train/                  |             |
|    approx_kl            | 0.006107846 |
|    clip_fraction        | 0.0304      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.09       |
|    explained_variance   | 5.54e-05    |
|    learning_rate        | 0.0003      |
|    loss                 | 28.1        |
|    n_updates            | 200         |
|    policy_gradient_loss | -0.00133    |
|    value_loss           | 60.5        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | -18         |
| time/                   |             |
|    fps                  | 502         |
|    iterations           | 31          |
|    time_elapsed         | 126         |
|    total_timesteps      | 63488       |
| train/                  |             |
|    approx_kl            | 0.009343628 |
|    clip_fraction        | 0.148       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.02       |
|    explained_variance   | 0.0211      |
|    learning_rate        | 0.0003      |
|    loss                 | 35.9        |
|    n_updates            | 300         |
|    policy_gradient_loss | -0.00992    |
|    value_loss           | 61.7        |
-----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 60      

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 24.1        |
| time/                   |             |
|    fps                  | 502         |
|    iterations           | 41          |
|    time_elapsed         | 167         |
|    total_timesteps      | 83968       |
| train/                  |             |
|    approx_kl            | 0.014245154 |
|    clip_fraction        | 0.243       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.01       |
|    explained_variance   | -0.0054     |
|    learning_rate        | 0.0003      |
|    loss                 | 20.7        |
|    n_updates            | 400         |
|    policy_gradient_loss | -0.0196     |
|    value_loss           | 38.9        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 34.6        |
| time/                   |             |
|    fps                  | 508         |
|    iterations           | 51          |
|    time_elapsed         | 205         |
|    total_timesteps      | 104448      |
| train/                  |             |
|    approx_kl            | 0.012759803 |
|    clip_fraction        | 0.175       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.897      |
|    explained_variance   | -0.000278   |
|    learning_rate        | 0.0003      |
|    loss                 | 34.7        |
|    n_updates            | 500         |
|    policy_gradient_loss | -0.00874    |
|    value_loss           | 63.8        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 51.1        |
| time/                   |             |
|    fps                  | 506         |
|    iterations           | 61          |
|    time_elapsed         | 246         |
|    total_timesteps      | 124928      |
| train/                  |             |
|    approx_kl            | 0.008449558 |
|    clip_fraction        | 0.146       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.822      |
|    explained_variance   | 2.26e-06    |
|    learning_rate        | 0.0003      |
|    loss                 | 43          |
|    n_updates            | 600         |
|    policy_gradient_loss | 0.00745     |
|    value_loss           | 94.5        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 50.2        |
| time/                   |             |
|    fps                  | 504         |
|    iterations           | 71          |
|    time_elapsed         | 288         |
|    total_timesteps      | 145408      |
| train/                  |             |
|    approx_kl            | 0.014290873 |
|    clip_fraction        | 0.168       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.762      |
|    explained_variance   | -5.6e-06    |
|    learning_rate        | 0.0003      |
|    loss                 | 40.7        |
|    n_updates            | 700         |
|    policy_gradient_loss | 0.00402     |
|    value_loss           | 93.9        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 50          |
| time/                   |             |
|    fps                  | 502         |
|    iterations           | 81          |
|    time_elapsed         | 330         |
|    total_timesteps      | 165888      |
| train/                  |             |
|    approx_kl            | 0.006971726 |
|    clip_fraction        | 0.136       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.709      |
|    explained_variance   | -5.84e-05   |
|    learning_rate        | 0.0003      |
|    loss                 | 48.5        |
|    n_updates            | 800         |
|    policy_gradient_loss | 0.00329     |
|    value_loss           | 96.8        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 55.6        |
| time/                   |             |
|    fps                  | 501         |
|    iterations           | 91          |
|    time_elapsed         | 371         |
|    total_timesteps      | 186368      |
| train/                  |             |
|    approx_kl            | 0.012720045 |
|    clip_fraction        | 0.147       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.671      |
|    explained_variance   | 7.63e-06    |
|    learning_rate        | 0.0003      |
|    loss                 | 68          |
|    n_updates            | 900         |
|    policy_gradient_loss | 0.0105      |
|    value_loss           | 102         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 60           |
|    ep_rew_mean          | 54.6         |
| time/                   |              |
|    fps                  | 498          |
|    iterations           | 101          |
|    time_elapsed         | 415          |
|    total_timesteps      | 206848       |
| train/                  |              |
|    approx_kl            | 0.0076028947 |
|    clip_fraction        | 0.188        |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.728       |
|    explained_variance   | 1.45e-05     |
|    learning_rate        | 0.0003       |
|    loss                 | 68           |
|    n_updates            | 1000         |
|    policy_gradient_loss | 0.0143       |
|    value_loss           | 111          |
------------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_m

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 60         |
|    ep_rew_mean          | 56.9       |
| time/                   |            |
|    fps                  | 495        |
|    iterations           | 111        |
|    time_elapsed         | 458        |
|    total_timesteps      | 227328     |
| train/                  |            |
|    approx_kl            | 0.00933174 |
|    clip_fraction        | 0.153      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.636     |
|    explained_variance   | 8.34e-07   |
|    learning_rate        | 0.0003     |
|    loss                 | 46.5       |
|    n_updates            | 1100       |
|    policy_gradient_loss | 0.0128     |
|    value_loss           | 113        |
----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_m

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 49.2        |
| time/                   |             |
|    fps                  | 493         |
|    iterations           | 121         |
|    time_elapsed         | 501         |
|    total_timesteps      | 247808      |
| train/                  |             |
|    approx_kl            | 0.008093576 |
|    clip_fraction        | 0.141       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.618      |
|    explained_variance   | 6.32e-06    |
|    learning_rate        | 0.0003      |
|    loss                 | 46.8        |
|    n_updates            | 1200        |
|    policy_gradient_loss | 0.00337     |
|    value_loss           | 109         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 57.8        |
| time/                   |             |
|    fps                  | 491         |
|    iterations           | 131         |
|    time_elapsed         | 545         |
|    total_timesteps      | 268288      |
| train/                  |             |
|    approx_kl            | 0.017380279 |
|    clip_fraction        | 0.139       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.589      |
|    explained_variance   | -2.38e-05   |
|    learning_rate        | 0.0003      |
|    loss                 | 54.9        |
|    n_updates            | 1300        |
|    policy_gradient_loss | 0.0123      |
|    value_loss           | 114         |
-----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 60      

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 60           |
|    ep_rew_mean          | 56.6         |
| time/                   |              |
|    fps                  | 489          |
|    iterations           | 141          |
|    time_elapsed         | 589          |
|    total_timesteps      | 288768       |
| train/                  |              |
|    approx_kl            | 0.0128499875 |
|    clip_fraction        | 0.109        |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.582       |
|    explained_variance   | -3.7e-06     |
|    learning_rate        | 0.0003       |
|    loss                 | 45.9         |
|    n_updates            | 1400         |
|    policy_gradient_loss | 0.00758      |
|    value_loss           | 117          |
------------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_m

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 56.2        |
| time/                   |             |
|    fps                  | 490         |
|    iterations           | 151         |
|    time_elapsed         | 630         |
|    total_timesteps      | 309248      |
| train/                  |             |
|    approx_kl            | 0.022323143 |
|    clip_fraction        | 0.133       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.548      |
|    explained_variance   | -1.55e-05   |
|    learning_rate        | 0.0003      |
|    loss                 | 55.2        |
|    n_updates            | 1500        |
|    policy_gradient_loss | 0.00913     |
|    value_loss           | 114         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 57.1        |
| time/                   |             |
|    fps                  | 490         |
|    iterations           | 161         |
|    time_elapsed         | 671         |
|    total_timesteps      | 329728      |
| train/                  |             |
|    approx_kl            | 0.011408877 |
|    clip_fraction        | 0.13        |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.582      |
|    explained_variance   | -2.38e-07   |
|    learning_rate        | 0.0003      |
|    loss                 | 47.5        |
|    n_updates            | 1600        |
|    policy_gradient_loss | 0.0111      |
|    value_loss           | 118         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 60         |
|    ep_rew_mean          | 55.7       |
| time/                   |            |
|    fps                  | 491        |
|    iterations           | 171        |
|    time_elapsed         | 712        |
|    total_timesteps      | 350208     |
| train/                  |            |
|    approx_kl            | 0.07253586 |
|    clip_fraction        | 0.144      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.598     |
|    explained_variance   | 0          |
|    learning_rate        | 0.0003     |
|    loss                 | 55.6       |
|    n_updates            | 1700       |
|    policy_gradient_loss | 0.0147     |
|    value_loss           | 119        |
----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_m

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 55.1        |
| time/                   |             |
|    fps                  | 491         |
|    iterations           | 181         |
|    time_elapsed         | 753         |
|    total_timesteps      | 370688      |
| train/                  |             |
|    approx_kl            | 0.007174743 |
|    clip_fraction        | 0.114       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.539      |
|    explained_variance   | 0.000157    |
|    learning_rate        | 0.0003      |
|    loss                 | 70.5        |
|    n_updates            | 1800        |
|    policy_gradient_loss | 0.00778     |
|    value_loss           | 114         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 36.4        |
| time/                   |             |
|    fps                  | 491         |
|    iterations           | 191         |
|    time_elapsed         | 795         |
|    total_timesteps      | 391168      |
| train/                  |             |
|    approx_kl            | 0.011625508 |
|    clip_fraction        | 0.104       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.478      |
|    explained_variance   | 0.0123      |
|    learning_rate        | 0.0003      |
|    loss                 | 49.3        |
|    n_updates            | 1900        |
|    policy_gradient_loss | 0.00121     |
|    value_loss           | 76.6        |
-----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 60      

<stable_baselines3.ppo.ppo.PPO at 0x270dad596d0>

### 6. Save and Evaluate Model

In [65]:
# persist the trained model under the path 'PPO' (relative to the working directory)
model.save('PPO')

In [70]:
# run the trained policy for 10 evaluation episodes; the displayed tuple is
# (mean episode reward, standard deviation of the episode reward)
evaluate_policy(model, env, n_eval_episodes=10)

(24.0, 54.99090833947008)