In [None]:
import random
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Optional, Union

import gym
import numpy as np
from gym import Env
from gym.envs.registration import EnvSpec
from gym.spaces import Box, Discrete
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.type_aliases import GymStepReturn

In [None]:
class ShowerEnv(Env):
    """Toy shower-temperature control task using the classic Gym step API.

    Each step the agent lowers, keeps, or raises the water temperature by one
    degree. Holding the temperature at 38 earns 100, at 37 or 39 earns 10, and
    any other temperature costs -1. An episode ends after ``SHOWER_LENGTH``
    steps.

    Observations are shape-(1,) float32 arrays so they match the declared
    ``observation_space``; the temperature is clipped to
    ``[MIN_TEMP, MAX_TEMP]`` so it never leaves those bounds.

    Args:
        temperature_noise: If > 0, after each step the temperature is shifted
            by a random integer in ``[-temperature_noise, temperature_noise]``.
            Defaults to 0 (deterministic dynamics).
    """

    # Inclusive range the starting temperature is drawn from on every reset.
    START_TEMP_RANGE = (5, 70)
    # Number of steps in one episode (one step = one "second" of shower).
    SHOWER_LENGTH = 200
    # Bounds of the observation space; the state is clipped into this range.
    MIN_TEMP = 0
    MAX_TEMP = 100

    def __init__(self, temperature_noise: int = 0):
        # Actions we can take: 0 = down, 1 = stay, 2 = up
        self.action_space = Discrete(3)
        # A single temperature reading. Explicit float32 bounds/dtype avoid
        # gym's "precision lowered by casting to float32" warning.
        self.observation_space = Box(
            low=np.array([self.MIN_TEMP], dtype=np.float32),
            high=np.array([self.MAX_TEMP], dtype=np.float32),
            dtype=np.float32,
        )
        self.temperature_noise = temperature_noise
        # Initialises self.state (start temperature) and self.shower_length.
        self.reset()

    def _get_obs(self) -> np.ndarray:
        """Return the current temperature as a shape-(1,) float32 observation."""
        return np.array([self.state], dtype=np.float32)

    def _clip_temp(self, temp: int) -> int:
        """Clamp ``temp`` into ``[MIN_TEMP, MAX_TEMP]`` so observations stay in bounds."""
        return max(self.MIN_TEMP, min(self.MAX_TEMP, temp))

    def step(self, action: int):
        """Apply one action and advance the episode by one step.

        Args:
            action: 0 lowers the temperature by 1, 1 keeps it, 2 raises it by 1.

        Returns:
            ``(observation, reward, done, info)`` in the classic Gym format.
        """
        # Map action {0, 1, 2} to a temperature change of {-1, 0, +1}. int()
        # normalises numpy scalar actions coming from a VecEnv.
        self.state = self._clip_temp(self.state + int(action) - 1)
        # Reduce shower length by 1 second
        self.shower_length -= 1

        # Reward: peak at 38, smaller bonus at 37/39, small penalty elsewhere.
        if self.state == 38:
            reward = 100
        elif 37 <= self.state <= 39:
            reward = 10
        else:
            reward = -1

        done = self.shower_length <= 0

        # Optional noise is applied after the reward is computed, so it affects
        # the next observation but not the reward just earned.
        if self.temperature_noise > 0:
            self.state = self._clip_temp(
                self.state
                + random.randint(-self.temperature_noise, self.temperature_noise)
            )

        return self._get_obs(), reward, done, {}

    def render(self, mode: str = "human"):
        """Print the current temperature (the only rendering this env supports)."""
        print(self.state)

    def reset(self):
        """Start a new episode with a random temperature and a full shower length.

        Returns:
            The initial observation (shape-(1,) float32 array).
        """
        self.state = random.randint(*self.START_TEMP_RANGE)
        self.shower_length = self.SHOWER_LENGTH
        return self._get_obs()

In [None]:
# Training PPO for 10M steps is expensive, so the trained policy is cached to
# disk and reloaded on later runs (delete MODEL_PATH to force retraining, e.g.
# after changing ShowerEnv).
SEED = 42
TOTAL_TIMESTEPS = 10_000_000  # SB3 expects an int here; `1e7` is a float
MODEL_PATH = Path("ppo_shower.zip")

env = make_vec_env(ShowerEnv, n_envs=1, seed=SEED)
if MODEL_PATH.exists():
    model = PPO.load(MODEL_PATH, env=env)
else:
    # `seed` makes SB3 seed its pseudo-random generators for a reproducible run.
    model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="tblog", seed=SEED)
    model.learn(total_timesteps=TOTAL_TIMESTEPS)
    model.save(MODEL_PATH)

In [None]:
# Roll out the trained policy for one full episode (ShowerEnv showers last
# 200 steps) and report the total reward it collects.
N_EVAL_STEPS = 200
obs = env.reset()
episode_reward = 0.0
for _ in range(N_EVAL_STEPS):
    # deterministic=True takes the policy's most likely action instead of
    # sampling, so the rollout reflects what the agent actually learned.
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = env.step(action)
    # The VecEnv returns one reward per sub-env; there is exactly one here.
    episode_reward += float(rewards[0])
    env.render()

env.close()
episode_reward