In [None]:
#https://sourceforge.net/projects/swig/files/swigwin/swigwin-4.0.2/swigwin-4.0.2.zip/download?use_mirror=ixpeering

# 1. Import Dependencies

In [1]:
import gym 
from gym import Env
from gym.spaces import Discrete, Box, Dict, Tuple, MultiBinary, MultiDiscrete 
import numpy as np
import random
import os
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.evaluation import evaluate_policy

# 2. Types of Spaces

In [7]:
Discrete(3).sample()

0

In [4]:
Box(0,1,shape=(3,3)).sample()

array([[0.49781933, 0.3406956 , 0.536787  ],
       [0.9906441 , 0.11294025, 0.92704916],
       [0.5587072 , 0.47451192, 0.55221   ]], dtype=float32)

In [8]:
Box(0,255,shape=(3,3), dtype=int).sample()

array([[178,   5, 199],
       [ 30,  32,  68],
       [ 49, 111, 139]])

In [9]:
Tuple((Discrete(2), Box(0,100, shape=(1,)))).sample()

(1, array([80.25087], dtype=float32))

In [10]:
Dict({'height':Discrete(2), "speed":Box(0,100, shape=(1,))}).sample()

OrderedDict([('height', 0), ('speed', array([17.396807], dtype=float32))])

In [12]:
MultiBinary(4).sample()

array([1, 0, 1, 1], dtype=int8)

In [31]:
MultiDiscrete([5,2,2]).sample()

array([4, 0, 1], dtype=int64)

# 3. Building an Environment

In [47]:
class ShowerEnv(Env):
    def __init__(self):
        # Actions we can take, down, stay, up
        self.action_space = Discrete(3)
        # Temperature array
        self.observation_space = Box(low=np.array([0]), high=np.array([100]))
        # Set start temp
        self.state = 38 + random.randint(-3,3)
        # Set shower length
        self.shower_length = 60
        
    def step(self, action):
        # Apply action
        # 0 -1 = -1 temperature
        # 1 -1 = 0 
        # 2 -1 = 1 temperature 
        self.state += action -1 
        # Reduce shower length by 1 second
        self.shower_length -= 1 
        
        # Calculate reward
        if self.state >=37 and self.state <=39: 
            reward =1 
        else: 
            reward = -1 
        
        # Check if shower is done
        if self.shower_length <= 0: 
            done = True
        else:
            done = False
        
        # Apply temperature noise
        #self.state += random.randint(-1,1)
        # Set placeholder for info
        info = {}
        
        # Return step information
        return self.state, reward, done, info

    def render(self, mode=None):
        # Implement viz
        pass
    
    def reset(self):
        # Reset shower temperature
        self.state = np.array([38 + random.randint(-3,3)]).astype(float)
        # Reset shower time
        self.shower_length = 60 
        return self.state

In [48]:
env=ShowerEnv()

  logger.warn(


In [36]:
env.observation_space.sample()

array([0.38186643], dtype=float32)

In [37]:
env.reset()

array([40.])

In [38]:
from stable_baselines3.common.env_checker import check_env

In [39]:
check_env(env, warn=True)

AssertionError: The observation returned by the `reset()` method does not match the given observation space

# 4. Test Environment

In [41]:
episodes = 5
for episode in range(1, episodes+1):
    state = env.reset()
    done = False
    score = 0 
    
    while not done:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score+=reward
    print('Episode:{} Score:{}'.format(episode, score))
env.close()

Episode:1 Score:-42
Episode:2 Score:-60
Episode:3 Score:-30
Episode:4 Score:16
Episode:5 Score:-26


In [None]:
env.close()

# 5. Train Model

In [42]:
log_path = os.path.join('Training', 'Logs')

In [43]:
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log=log_path)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [44]:
model.learn(total_timesteps=400000)

Logging to Training\Logs\PPO_5
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 60       |
|    ep_rew_mean     | -31.1    |
| time/              |          |
|    fps             | 509      |
|    iterations      | 1        |
|    time_elapsed    | 4        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | -28.3       |
| time/                   |             |
|    fps                  | 760         |
|    iterations           | 2           |
|    time_elapsed         | 5           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008984437 |
|    clip_fraction        | 0.0479      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.09       |
|    explained_variance   | 0.000158    |

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | -22.6       |
| time/                   |             |
|    fps                  | 1267        |
|    iterations           | 11          |
|    time_elapsed         | 17          |
|    total_timesteps      | 22528       |
| train/                  |             |
|    approx_kl            | 0.006362684 |
|    clip_fraction        | 0.0266      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1          |
|    explained_variance   | -0.000551   |
|    learning_rate        | 0.0003      |
|    loss                 | 36.3        |
|    n_updates            | 100         |
|    policy_gradient_loss | -0.00225    |
|    value_loss           | 76.8        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 60           |
|    ep_rew_mean          | -14.1        |
| time/                   |              |
|    fps                  | 1358         |
|    iterations           | 21           |
|    time_elapsed         | 31           |
|    total_timesteps      | 43008        |
| train/                  |              |
|    approx_kl            | 0.0032677683 |
|    clip_fraction        | 0.021        |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.856       |
|    explained_variance   | 0.0052       |
|    learning_rate        | 0.0003       |
|    loss                 | 35.8         |
|    n_updates            | 200          |
|    policy_gradient_loss | -0.00201     |
|    value_loss           | 76           |
------------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 60           |
|    ep_rew_mean          | 2.92         |
| time/                   |              |
|    fps                  | 1390         |
|    iterations           | 31           |
|    time_elapsed         | 45           |
|    total_timesteps      | 63488        |
| train/                  |              |
|    approx_kl            | 0.0021995537 |
|    clip_fraction        | 0.0451       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.62        |
|    explained_variance   | 0.000272     |
|    learning_rate        | 0.0003       |
|    loss                 | 43.1         |
|    n_updates            | 300          |
|    policy_gradient_loss | -0.00121     |
|    value_loss           | 93.1         |
------------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 26.3        |
| time/                   |             |
|    fps                  | 1410        |
|    iterations           | 41          |
|    time_elapsed         | 59          |
|    total_timesteps      | 83968       |
| train/                  |             |
|    approx_kl            | 0.009285607 |
|    clip_fraction        | 0.0778      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.619      |
|    explained_variance   | 0.000218    |
|    learning_rate        | 0.0003      |
|    loss                 | 37.2        |
|    n_updates            | 400         |
|    policy_gradient_loss | 0.000403    |
|    value_loss           | 68.3        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 45.6        |
| time/                   |             |
|    fps                  | 1425        |
|    iterations           | 51          |
|    time_elapsed         | 73          |
|    total_timesteps      | 104448      |
| train/                  |             |
|    approx_kl            | 0.018775795 |
|    clip_fraction        | 0.0976      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.557      |
|    explained_variance   | -4.97e-05   |
|    learning_rate        | 0.0003      |
|    loss                 | 45.7        |
|    n_updates            | 500         |
|    policy_gradient_loss | 0.00705     |
|    value_loss           | 85.4        |
-----------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 60  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 53.8        |
| time/                   |             |
|    fps                  | 1442        |
|    iterations           | 61          |
|    time_elapsed         | 86          |
|    total_timesteps      | 124928      |
| train/                  |             |
|    approx_kl            | 0.005297954 |
|    clip_fraction        | 0.103       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.594      |
|    explained_variance   | -3.22e-06   |
|    learning_rate        | 0.0003      |
|    loss                 | 48.4        |
|    n_updates            | 600         |
|    policy_gradient_loss | 0.00703     |
|    value_loss           | 110         |
-----------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 60  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 54.2        |
| time/                   |             |
|    fps                  | 1448        |
|    iterations           | 71          |
|    time_elapsed         | 100         |
|    total_timesteps      | 145408      |
| train/                  |             |
|    approx_kl            | 0.005629017 |
|    clip_fraction        | 0.0937      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.534      |
|    explained_variance   | -0.000735   |
|    learning_rate        | 0.0003      |
|    loss                 | 82.3        |
|    n_updates            | 700         |
|    policy_gradient_loss | 0.00908     |
|    value_loss           | 103         |
-----------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 60  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 46.5        |
| time/                   |             |
|    fps                  | 1454        |
|    iterations           | 81          |
|    time_elapsed         | 114         |
|    total_timesteps      | 165888      |
| train/                  |             |
|    approx_kl            | 0.104668505 |
|    clip_fraction        | 0.159       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.476      |
|    explained_variance   | -0.012      |
|    learning_rate        | 0.0003      |
|    loss                 | 31.1        |
|    n_updates            | 800         |
|    policy_gradient_loss | 0.0264      |
|    value_loss           | 56.6        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 60         |
|    ep_rew_mean          | 45.6       |
| time/                   |            |
|    fps                  | 1460       |
|    iterations           | 91         |
|    time_elapsed         | 127        |
|    total_timesteps      | 186368     |
| train/                  |            |
|    approx_kl            | 0.03741492 |
|    clip_fraction        | 0.108      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.337     |
|    explained_variance   | -0.0142    |
|    learning_rate        | 0.0003     |
|    loss                 | 25.5       |
|    n_updates            | 900        |
|    policy_gradient_loss | 0.0077     |
|    value_loss           | 60.8       |
----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_m

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 60         |
|    ep_rew_mean          | 40.5       |
| time/                   |            |
|    fps                  | 1464       |
|    iterations           | 101        |
|    time_elapsed         | 141        |
|    total_timesteps      | 206848     |
| train/                  |            |
|    approx_kl            | 0.03566791 |
|    clip_fraction        | 0.0911     |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.335     |
|    explained_variance   | 0.0115     |
|    learning_rate        | 0.0003     |
|    loss                 | 57.8       |
|    n_updates            | 1000       |
|    policy_gradient_loss | 0.014      |
|    value_loss           | 109        |
----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 60         |
|    ep_rew_mean

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 53.6        |
| time/                   |             |
|    fps                  | 1467        |
|    iterations           | 111         |
|    time_elapsed         | 154         |
|    total_timesteps      | 227328      |
| train/                  |             |
|    approx_kl            | 0.003404431 |
|    clip_fraction        | 0.077       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.395      |
|    explained_variance   | 0.00083     |
|    learning_rate        | 0.0003      |
|    loss                 | 52.8        |
|    n_updates            | 1100        |
|    policy_gradient_loss | 0.00806     |
|    value_loss           | 101         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60    

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 60         |
|    ep_rew_mean          | 52.7       |
| time/                   |            |
|    fps                  | 1471       |
|    iterations           | 121        |
|    time_elapsed         | 168        |
|    total_timesteps      | 247808     |
| train/                  |            |
|    approx_kl            | 0.03187746 |
|    clip_fraction        | 0.0715     |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.293     |
|    explained_variance   | -0.000485  |
|    learning_rate        | 0.0003     |
|    loss                 | 55.6       |
|    n_updates            | 1200       |
|    policy_gradient_loss | 0.00588    |
|    value_loss           | 119        |
----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_m

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 51.9        |
| time/                   |             |
|    fps                  | 1471        |
|    iterations           | 131         |
|    time_elapsed         | 182         |
|    total_timesteps      | 268288      |
| train/                  |             |
|    approx_kl            | 0.012721978 |
|    clip_fraction        | 0.0443      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.2        |
|    explained_variance   | -0.00649    |
|    learning_rate        | 0.0003      |
|    loss                 | 72.6        |
|    n_updates            | 1300        |
|    policy_gradient_loss | 0.00676     |
|    value_loss           | 117         |
-----------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 60  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 49.3        |
| time/                   |             |
|    fps                  | 1470        |
|    iterations           | 141         |
|    time_elapsed         | 196         |
|    total_timesteps      | 288768      |
| train/                  |             |
|    approx_kl            | 0.014467287 |
|    clip_fraction        | 0.0603      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.208      |
|    explained_variance   | 0.0159      |
|    learning_rate        | 0.0003      |
|    loss                 | 56.2        |
|    n_updates            | 1400        |
|    policy_gradient_loss | 0.00776     |
|    value_loss           | 102         |
-----------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 60  

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 60         |
|    ep_rew_mean          | 47.8       |
| time/                   |            |
|    fps                  | 1471       |
|    iterations           | 151        |
|    time_elapsed         | 210        |
|    total_timesteps      | 309248     |
| train/                  |            |
|    approx_kl            | 0.02806857 |
|    clip_fraction        | 0.0492     |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.209     |
|    explained_variance   | 0.00458    |
|    learning_rate        | 0.0003     |
|    loss                 | 74.5       |
|    n_updates            | 1500       |
|    policy_gradient_loss | 0.00351    |
|    value_loss           | 113        |
----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_m

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 10.4        |
| time/                   |             |
|    fps                  | 1471        |
|    iterations           | 161         |
|    time_elapsed         | 224         |
|    total_timesteps      | 329728      |
| train/                  |             |
|    approx_kl            | 0.013764681 |
|    clip_fraction        | 0.029       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.113      |
|    explained_variance   | 0.03        |
|    learning_rate        | 0.0003      |
|    loss                 | 66.5        |
|    n_updates            | 1600        |
|    policy_gradient_loss | -0.00204    |
|    value_loss           | 140         |
-----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 60      

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 60           |
|    ep_rew_mean          | 20           |
| time/                   |              |
|    fps                  | 1471         |
|    iterations           | 171          |
|    time_elapsed         | 237          |
|    total_timesteps      | 350208       |
| train/                  |              |
|    approx_kl            | 0.0003555752 |
|    clip_fraction        | 0.00898      |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.087       |
|    explained_variance   | 0.0793       |
|    learning_rate        | 0.0003       |
|    loss                 | 55.8         |
|    n_updates            | 1700         |
|    policy_gradient_loss | -0.00128     |
|    value_loss           | 109          |
------------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 60           |
|    ep_rew_mean          | 32.3         |
| time/                   |              |
|    fps                  | 1473         |
|    iterations           | 181          |
|    time_elapsed         | 251          |
|    total_timesteps      | 370688       |
| train/                  |              |
|    approx_kl            | 0.0011380955 |
|    clip_fraction        | 0.0238       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.113       |
|    explained_variance   | 0.0262       |
|    learning_rate        | 0.0003       |
|    loss                 | 63.3         |
|    n_updates            | 1800         |
|    policy_gradient_loss | 0.00124      |
|    value_loss           | 136          |
------------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_m

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 60          |
|    ep_rew_mean          | 13.6        |
| time/                   |             |
|    fps                  | 1473        |
|    iterations           | 191         |
|    time_elapsed         | 265         |
|    total_timesteps      | 391168      |
| train/                  |             |
|    approx_kl            | 0.012800582 |
|    clip_fraction        | 0.122       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.15       |
|    explained_variance   | -0.627      |
|    learning_rate        | 0.0003      |
|    loss                 | 36.9        |
|    n_updates            | 1900        |
|    policy_gradient_loss | -0.000721   |
|    value_loss           | 91.6        |
-----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 60      

<stable_baselines3.ppo.ppo.PPO at 0x2150d4cafa0>

# 6. Save Model

In [45]:
model.save('PPO')

In [54]:
evaluate_policy(model, env, n_eval_episodes=10, render=True)

(24.0, 54.99090833947008)