In [5]:
import time
import gymnasium
import numpy as np
from enum import Enum
from gymnasium import spaces
from stable_baselines3 import PPO, DQN
from stable_baselines3.common.env_util import make_vec_env
import os

In [23]:
# Discrete actions for GoLeftEnv. The integer values are the raw actions
# that GoLeftEnv.step compares against (RIGHT -> +1 cell, LEFT -> -1 cell).
Actions = Enum("Actions", [("RIGHT", 0), ("LEFT", 1)])


class GoLeftEnv(gymnasium.Env):
    """One-dimensional "go left" corridor environment.

    The agent starts at position ``road_length - 1`` on a road whose positions
    run from ``0`` to ``road_length`` and moves one cell LEFT or RIGHT per step.
    Reaching position 0 ends the episode (``terminated``); the episode is cut
    off (``truncated``) once ``step_limit = road_length + 10`` steps have been
    taken.

    Episode bookkeeping (``step_limit``, ``current_step``, ``current_reward``)
    lives in ``self.metadata`` so callers can read, e.g.,
    ``env.metadata["current_reward"]``.
    """

    # Template only: each instance works on its own copy (see __init__), because
    # these counters are mutated every step and must not be shared between envs.
    metadata = {"render_modes": ["console"], "step_limit": 0, "current_step": 0, "current_reward": 0}

    def __init__(self, render_mode=None, road_length=16):
        """Build the environment.

        Args:
            render_mode: ``"console"`` enables text rendering; any other value
                makes :meth:`render` raise ``NotImplementedError``.
            road_length: Highest reachable position. The agent starts at
                ``road_length - 1``.
        """
        # Per-instance copy of the class-level dict. Mutating the shared class
        # attribute let a second env (e.g. road_length=256) overwrite the step
        # limit and counters of every other GoLeftEnv instance.
        self.metadata = dict(type(self).metadata)
        self.render_mode = render_mode
        self.road_length = road_length
        self.agent_position = self.road_length - 1

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(low=0, high=self.road_length, shape=(1,), dtype=np.float32)

        self.metadata["step_limit"] = self.road_length + 10

    def _get_obs(self):
        """Return the agent position as a float32 array of shape (1,), matching observation_space."""
        return np.array([self.agent_position], dtype=np.float32)

    def reset(self, *, seed=None, options=None, **kwargs):
        """Start a new episode with the agent at ``road_length - 1``.

        Args:
            seed: Forwarded to ``gymnasium.Env.reset`` so ``self.np_random`` is seeded.
            options: Accepted for Gymnasium API compatibility; unused.
            **kwargs: Ignored; kept so existing keyword-only callers keep working.

        Returns:
            ``(observation, info)``: a float32 observation of shape (1,) and an empty dict.
        """
        super().reset(seed=seed)
        self.metadata["current_reward"] = 0
        self.metadata["current_step"] = 0
        self.agent_position = self.road_length - 1
        return self._get_obs(), {}

    def step(self, action):
        """Apply ``action`` and advance the episode by one step.

        Args:
            action: ``Actions.LEFT.value`` (move toward 0) or ``Actions.RIGHT.value``.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` in Gymnasium order:
            ``terminated`` is True when the agent reaches position 0 and
            ``truncated`` is True once ``step_limit`` steps have been taken.

        Raises:
            ValueError: If ``action`` is neither LEFT nor RIGHT.
        """
        if action == Actions.LEFT.value:
            self.agent_position -= 1
        elif action == Actions.RIGHT.value:
            self.agent_position += 1
        else:
            raise ValueError(f"Received invalid action: {action}")

        # np.clip returns a NumPy scalar; keep the position a plain int like reset() does.
        self.agent_position = int(np.clip(self.agent_position, 0, self.road_length))

        # Goal reached: natural end of the episode.
        terminated = self.agent_position == 0

        # The reward uses the step index *before* this step is counted.
        reward = self._calculate_reward(action)
        self.metadata["current_reward"] += reward

        self.metadata["current_step"] += 1
        # Time limit: cut off after exactly step_limit steps. The previous
        # pre-increment `>` check only fired on step step_limit + 2.
        truncated = self.metadata["current_step"] >= self.metadata["step_limit"]

        info = {}
        return self._get_obs(), reward, terminated, truncated, info

    def _calculate_reward(self, action):
        """Return the reward for the step just taken.

        Reaching position 0 yields a fixed ``1e5`` bonus. Otherwise LEFT earns
        ``current_step * road_length * 3e-2`` and RIGHT earns
        ``current_step * road_length * -3e-5``, so both scale linearly with the
        number of steps already taken in the episode.
        """
        if self.agent_position == 0:
            return 1e5
        elif action == Actions.LEFT.value:
            return self.metadata["current_step"] * self.road_length * (3e-2)
        else:
            return self.metadata["current_step"] * self.road_length * (-3e-5)

    def render(self):
        """Print the road as dots with a ``+`` at the agent position.

        Raises:
            NotImplementedError: If ``render_mode`` is not ``"console"``.
        """
        if self.render_mode != "console":
            raise NotImplementedError()
        # Agent is a "+", every other cell of positions 0..road_length is a "."
        print("." * self.agent_position, end="")
        print("+", end="")
        print("." * (self.road_length - self.agent_position))

    def close(self):
        """Nothing to release; present to satisfy the Env interface."""
        pass


In [ ]:
# Replay the trained PPO agent on the default 16-cell road, rendering each step.
model = PPO.load("./ppo_goleft.zip")

# Build the env inside the factory. The original `lambda: env` closed over the
# very name this line rebinds to the VecEnv, so any later call of the factory
# (or n_envs > 1) would not have produced a fresh GoLeftEnv.
env = make_vec_env(lambda: GoLeftEnv(render_mode="console"), n_envs=1)

obs = env.reset()

n_steps = 10000
for step in range(n_steps):
    # predict() returns (actions, recurrent_state); only the actions are used.
    action, _state = model.predict(obs)
    # VecEnv.step returns a 4-tuple (obs, rewards, dones, infos) and
    # auto-resets finished sub-environments.
    obs, reward, done, info = env.step(action)

    time.sleep(0.2)
    # "cls" exists only on Windows; use "clear" elsewhere.
    # NOTE(review): neither command clears Jupyter cell output.
    os.system("cls" if os.name == "nt" else "clear")
    env.render()
    print(f"current_reward: {env.metadata['current_reward']}")

env.close()


..............+..
current_reward: 0.0003
...............+.
current_reward: 0.00026999999999999995
..............+..
current_reward: 0.00057
.............+...
current_reward: 0.0008699999999999999
............+....
current_reward: 0.0011699999999999998
.............+...
current_reward: 0.0010199999999999999
............+....
current_reward: 0.0013199999999999998
...........+.....
current_reward: 0.0016199999999999997
..........+......
current_reward: 0.0019199999999999996
...........+.....
current_reward: 0.0016499999999999996
..........+......
current_reward: 0.0019499999999999995
...........+.....
current_reward: 0.0016199999999999995
..........+......
current_reward: 0.0019199999999999994
.........+.......
current_reward: 0.0022199999999999993
........+........
current_reward: 0.0025199999999999992
.......+.........
current_reward: 0.002819999999999999
......+..........
current_reward: 0.003119999999999999
.....+...........
current_reward: 0.003419999999999999
....+............
curre

In [31]:
# Replay the trained DQN agent on a single (non-vectorized) 256-cell road.
model = DQN.load("./dqn_goleft.zip")

env = GoLeftEnv(render_mode="console", road_length=256)

obs, info = env.reset()

n_steps = 300
for step in range(n_steps):
    # predict() returns (action, recurrent_state); only the action is used.
    action, _state = model.predict(obs)
    # Gymnasium order is (obs, reward, terminated, truncated, info):
    # here `done` = goal reached, `terminated` = step limit hit.
    obs, reward, done, terminated, info = env.step(action)

    time.sleep(0.2)
    # "cls" exists only on Windows; use "clear" elsewhere.
    # NOTE(review): neither command clears Jupyter cell output.
    os.system("cls" if os.name == "nt" else "clear")
    env.render()
    print(f"current_reward: {env.metadata['current_reward']}")
    if done or terminated:
        # Keep the fresh observation; otherwise the next predict() acts on the
        # final observation of the episode that just ended.
        obs, info = env.reset()

env.close()


..............................................................................................................................................................................................................................................................+..
current_reward: 0.0
.............................................................................................................................................................................................................................................................+...
current_reward: 7.68
............................................................................................................................................................................................................................................................+....
current_reward: 23.04
...................................................................................................................................................................