In [None]:
# ---------------- custom env ----------------
from gymnasium import Env
from gymnasium.spaces import Discrete, Box
import numpy as np
import random


class CustomEnv(Env):
  """Toy shower-temperature environment following the Gymnasium `Env` API.

  The state is a single integer temperature. On every step the agent picks one
  of three actions (0 lowers the temperature by one, 1 keeps it, 2 raises it by
  one); afterwards random noise of -1, 0 or +1 is added. A step earns +1 when
  the temperature right after the action (before the noise) lies in [37, 39]
  and -1 otherwise. An episode is cut off after 60 steps.

  Observations are float32 arrays of shape (1,), matching `observation_space`.
  """

  def __init__(self):
    self.action_space = Discrete(3)
    self.observation_space = Box(
        low=np.array([0], dtype=np.float32),
        high=np.array([100], dtype=np.float32),
        dtype=np.float32,
    )
    self.state = self._sample_start_temperature()
    self.shower_length = 60

  def _sample_start_temperature(self):
    """Draw the initial temperature: 38 plus uniform integer noise in [-3, 3]."""
    # Use the environment's own generator so that reset(seed=...) actually
    # controls the randomness (the global `random` module ignores that seed).
    return 38 + int(self.np_random.integers(-3, 3, endpoint=True))

  def _get_obs(self):
    """Return the current temperature in the layout declared by `observation_space`."""
    return np.array([self.state], dtype=np.float32)

  def step(self, action):
    """Apply `action`, advance one time step and return the Gymnasium 5-tuple.

    Args:
      action: 0, 1 or 2. A Python int, a NumPy scalar or a size-1 array is
        accepted.

    Returns:
      (observation, reward, terminated, truncated, info). `terminated` is
      always False because the task has no terminal state; `truncated` turns
      True once the 60-step time limit is exhausted.
    """
    # Coerce to a plain int so an array-typed action cannot turn `self.state`
    # into an ndarray.
    action = int(np.asarray(action).item())

    self.state += action - 1
    self.shower_length -= 1

    # Reward is judged on the temperature before the noise is applied.
    reward = 1 if 37 <= self.state <= 39 else -1

    # Running out of time is a truncation, not a terminal state of the task.
    terminated = False
    truncated = self.shower_length <= 0

    self.state += int(self.np_random.integers(-1, 1, endpoint=True))
    info = {}

    return self._get_obs(), reward, terminated, truncated, info

  def render(self):
    """Rendering is not implemented; this is a no-op."""
    pass

  def reset(self, *, seed=None, options=None):
    """Start a new episode and return `(observation, info)`.

    Args:
      seed: Optional seed forwarded to `Env.reset`, which seeds `self.np_random`.
      options: Accepted for API compatibility; unused.
    """
    super().reset(seed=seed)
    self.state = self._sample_start_temperature()
    self.shower_length = 60
    # Return the real starting temperature, not a placeholder of zeros.
    return self._get_obs(), {}



In [None]:
# Build the environment and show one random draw from the action space and one
# from the observation space (one value per output line).
env = CustomEnv()
print(env.action_space.sample(), env.observation_space.sample(), sep="\n")

In [None]:
# Sanity check: play a few episodes with uniformly random actions and report
# the total reward collected in each one.
episodes = 10
for ep in range(1, episodes+1):
  obs, info = env.reset()
  done = False
  score = 0

  while not done:
    env.render()
    action = env.action_space.sample()
    # step() returns five values (obs, reward, terminated, truncated, info);
    # the episode is over when either flag is set.
    n_state, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    score += reward

  # Report the accumulated episode score, not just the last step's reward.
  print(ep, score)

In [None]:
!pip install stable_baselines3

In [None]:
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv

def env_maker():
  """Zero-argument factory that builds a fresh CustomEnv."""
  return CustomEnv()


# DummyVecEnv takes a list of environment factories; a single one is used here.
env = DummyVecEnv([env_maker])

# Train a DQN agent with an MLP policy on the CPU for 10,000 environment steps.
model = DQN(policy="MlpPolicy", env=env, verbose=1, device="cpu")
model.learn(total_timesteps=10_000)

In [None]:
import gym
# Roll out one episode with the trained model on a fresh, unwrapped environment.
test_env = CustomEnv()
obs, info = test_env.reset()

print("Action space:", test_env.action_space)

done = False
while not done:
    # Present the observation as a batch of one, shaped (1, *obs_shape), so
    # predict() sees the same layout on every step regardless of whether the
    # environment handed back a scalar or an array.
    batched_obs = np.asarray(obs, dtype=np.float32).reshape(
        1, *test_env.observation_space.shape
    )
    action, _states = model.predict(batched_obs)
    print("Predicted action:", action, "Type:", type(action), "Shape:", action.shape)

    # Check against the gymnasium space classes (imported in the first cell).
    # The env is built from gymnasium spaces, and the legacy `gym.spaces`
    # classes are unrelated types, so isinstance against them never matches.
    if isinstance(test_env.action_space, Discrete):
        action = int(action.item())
    elif isinstance(test_env.action_space, Box):
        action = action.flatten()

    obs, reward, terminated, truncated, info = test_env.step(action)
    done = terminated or truncated

print("Final info:", info)