In [56]:
import os
import random
import time

import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3 import DQN
from stable_baselines3.common.env_checker import check_env

import Pyleste.CelesteUtils as utils
from Pyleste.Carts.Celeste import Celeste
from Pyleste.PICO8 import PICO8

In [None]:
# auxiliary functions rewritten by me
def compute_displacement(player):
    """Estimate the pixel offset the player is about to move by.

    The raw offset on each axis is the rounded sum of the sub-pixel remainder
    (``player.rem``) and the speed (``player.spd``), widened by one pixel in
    the direction of that offset. It is then walked back one pixel at a time,
    against the direction of the speed, while ``player.is_solid`` reports a
    collision: first horizontally (with a vertical offset of 0), then
    vertically (using the already-corrected horizontal offset).

    Args:
        player: object exposing ``rem.x``/``rem.y``, ``spd.x``/``spd.y`` and
            an ``is_solid(dx, dy)`` method.

    Returns:
        tuple: the ``(dx, dy)`` offset after the collision walk-back.
    """
    def sign(value):
        return 1 if value > 0 else -1 if value < 0 else 0

    dx = round(player.rem.x + player.spd.x)
    dy = round(player.rem.y + player.spd.y)
    dx, dy = dx + sign(dx), dy + sign(dy)

    # The walk-back step is the sign of the speed. With a zero speed component
    # the step is 0, so a solid probe could never be escaped and the loop
    # would spin forever; in that case the offset is returned uncorrected.
    step_x = sign(player.spd.x)
    while step_x and player.is_solid(dx, 0):
        dx -= step_x

    step_y = sign(player.spd.y)
    while step_y and player.is_solid(dx, dy):
        dy -= step_y

    return dx, dy

def action_restrictions(game, player):
    """Work out which kinds of input are worth considering for the player.

    All collision probes are made at the offset returned by
    ``compute_displacement(player)``.

    Args:
        game: game object providing the ``balloon``, ``fruit`` and
            ``fly_fruit`` object types passed to ``player.check``.
        player: player object (reads ``spd.x``, ``p_jump``, ``grace``,
            ``djump`` and calls ``is_solid`` / ``check``).

    Returns:
        tuple: ``(h_movement, can_jump, can_dash)``.
    """
    off_x, off_y = compute_displacement(player)

    # Left/right inputs are only considered at low horizontal speed.
    h_movement = abs(player.spd.x) <= 1

    # Jump: jump must not already be flagged in p_jump, and either grace is
    # still above 1 or something solid is 3 px to a side or 1 px below.
    can_jump = not player.p_jump and (
        player.grace - 1 > 0
        or player.is_solid(off_x - 3, off_y)
        or player.is_solid(off_x + 3, off_y)
        or player.is_solid(off_x, off_y + 1)
    )

    # Dash: a dash is available (djump > 0), or the player is above something
    # solid, or overlaps a balloon / fruit / fly_fruit object.
    can_dash = (
        player.djump > 0
        or player.is_solid(off_x, off_y + 1)
        or player.check(game.balloon, 0, 0)
        or player.check(game.fruit, 0, 0)
        or player.check(game.fly_fruit, 0, 0)
    )

    return h_movement, can_jump, can_dash

def eval_actions(h_movement, can_jump, can_dash):
    """Return the button-state bitmasks allowed by the given restrictions.

    Button states (bitmask - decimal - buttons):
        0b000000 -  0 - no input
        0b000001 -  1 - l
        0b000010 -  2 - r
        0b010000 - 16 - z
        0b010001 - 17 - l + z
        0b010010 - 18 - r + z
        0b100000 - 32 - x
        0b100001 - 33 - l + x
        0b100010 - 34 - r + x
        0b100100 - 36 - u + x
        0b100101 - 37 - u + l + x
        0b100110 - 38 - u + r + x
        0b101000 - 40 - d + x
        0b101001 - 41 - d + l + x
        0b101010 - 42 - d + r + x

    Args:
        h_movement: whether plain l / r inputs (alone or with z) are included.
        can_jump: whether the z combinations are included.
        can_dash: whether the nine x combinations are included.

    Returns:
        list[int]: bitmasks ordered as no-input/directions, then z, then x.
    """
    left, right = 0b000001, 0b000010
    up, down = 0b000100, 0b001000
    z_btn, x_btn = 0b010000, 0b100000

    sideways = (left, right) if h_movement else ()

    allowed = [0b000000, *sideways]
    if can_jump:
        allowed += [z_btn, *(z_btn | side for side in sideways)]
    if can_dash:
        # x combined with every vertical choice, and for each every horizontal one.
        allowed += [
            x_btn | vertical | horizontal
            for vertical in (0, up, down)
            for horizontal in (0, left, right)
        ]
    return allowed


def get_possible_actions(p8):
    """List the button states worth trying for the current game state.

    Args:
        p8: emulator object whose ``game.get_player()`` returns the player
            (or a falsy value when there is none).

    Returns:
        list[int]: ``[]`` when there is no player, ``[0]`` (no input) while
        ``dash_time`` is non-zero, otherwise the result of ``eval_actions``
        applied to ``action_restrictions``.
    """
    player = p8.game.get_player()

    # dead RIP :(
    if not player:
        return []

    if player.dash_time == 0:
        restrictions = action_restrictions(p8.game, player)
        return eval_actions(*restrictions)

    # dash_time is non-zero: only the no-input state is offered
    return [0b000000]

# Room loaded by the environment and the (x, y) position used as its goal.
LEVEL_ID = 0
LEVEL_GOAL = (108, 0)
x_goal = LEVEL_GOAL[0]
y_goal = LEVEL_GOAL[1]

# Episode budget: one minute worth of frames at FPS frames per second.
FPS = 30
MAX_FRAMES = 60 * FPS

# Amount subtracted from the reward on every step.
TIME_PENALTY = 1

# Every button state the agent may choose from; the discrete action index
# is the position in this list.
ALL_ACTIONS = [
    0b000000,  # no input
    0b000001,  # l
    0b000010,  # r
    0b010000,  # z
    0b010001,  # l + z
    0b010010,  # r + z
    0b100000,  # x
    0b100001,  # l + x
    0b100010,  # r + x
    0b100100,  # u + x
    0b100101,  # u + l + x
    0b100110,  # u + r + x
    0b101000,  # d + x
    0b101001,  # d + l + x
    0b101010,  # d + r + x
]

class CelesteEnv(gym.Env):
    """Gymnasium environment wrapping one Celeste room emulated by Pyleste.

    Observation:
        float32 vector ``[x, y, rem.x, rem.y, spd.x, spd.y]`` of the player.
        When the game has no player object, the values are taken from the
        game's ``last_player_pos`` / ``last_player_rem`` / ``last_player_spd``.

    Action:
        ``Discrete(len(ALL_ACTIONS))``; the index selects a button-state
        bitmask from ``ALL_ACTIONS``.

    Episode end:
        ``terminated`` when there is no player object or the goal check
        passes; ``truncated`` once ``MAX_FRAMES`` steps have been taken.
    """

    # Gymnasium reads the "render_modes" key (the old gym key was "render.modes").
    metadata = {"render_modes": ["human"]}

    def __init__(self, time_penalty=TIME_PENALTY):
        """Create the emulator, load the room and define the spaces.

        Args:
            time_penalty: amount subtracted from the reward on every step.
        """
        super().__init__()
        self.time_penalty = time_penalty
        self.p8 = PICO8(Celeste)
        utils.load_room(self.p8, LEVEL_ID)
        utils.skip_player_spawn(self.p8)

        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(6,), dtype=np.float32
        )

        self.action_space = spaces.Discrete(len(ALL_ACTIONS))

        self.frame_count = 0
        # Created here as well so step() works even if reset() was not called.
        self.current_episode_actions = []
        # Highest single-step reward seen so far (across episodes) and the
        # button states of the episode up to and including that step.
        self.best_reward = -np.inf
        self.best_episode = []

    def reset(self, seed=None, options=None):
        """Reload the room and start a new episode.

        Returns:
            tuple: ``(observation, info)`` with an empty info dict.
        """
        # Lets gym.Env seed its random generator, as the Gymnasium API expects.
        super().reset(seed=seed)
        utils.load_room(self.p8, LEVEL_ID)
        utils.skip_player_spawn(self.p8)
        self.frame_count = 0
        self.current_episode_actions = []
        return self._get_obs(), {}

    def _get_obs(self):
        """Return the 6-element float32 observation for the current state."""
        p = self.p8.game.get_player()
        if not p:
            # No player object: use the last values recorded on the game.
            return np.array([
                *self.p8.game.last_player_pos,
                *self.p8.game.last_player_rem,
                *self.p8.game.last_player_spd
            ], dtype=np.float32)
        return np.array([
            p.x, p.y, p.rem.x, p.rem.y, p.spd.x, p.spd.y
        ], dtype=np.float32)

    def step(self, action_idx):
        """Apply one button state, advance the emulator by one step.

        Args:
            action_idx: index into ``ALL_ACTIONS``.

        Returns:
            tuple: ``(obs, reward, terminated, truncated, info)``.
        """
        action_value = ALL_ACTIONS[action_idx]
        self.p8.set_btn_state(action_value)
        self.p8.step()

        # Count this step before testing the limit, so an episode lasts at
        # most MAX_FRAMES steps rather than MAX_FRAMES + 1.
        self.current_episode_actions.append(action_value)
        self.frame_count += 1

        obs = self._get_obs()
        reward = self._compute_reward()
        terminated = self._check_termination()
        truncated = self.frame_count >= MAX_FRAMES

        # save best episode: keyed on the best single-step reward, storing the
        # actions that led up to it
        if reward > self.best_reward:
            self.best_reward = reward
            self.best_episode = self.current_episode_actions.copy()

        return obs, reward, terminated, truncated, {}

    def _compute_reward(self):
        """Compute the shaped reward for the current state.

        Combines a death penalty, a fruit bonus, position and speed terms,
        the per-step time penalty, a distance-to-goal term and two region
        bonuses; the total is divided by 256.
        """
        p = self.p8.game.get_player()
        reward = 0
        if not p:
            # death penalty
            reward -= 50
            x = self.p8.game.last_player_pos[0]
            y = self.p8.game.last_player_pos[1]
            spdx = self.p8.game.last_player_spd[0]
            spdy = self.p8.game.last_player_spd[1]
        else:
            x, y = p.x, p.y
            spdx, spdy = p.spd.x, p.spd.y
            # check if got strawberry
            if p.check(self.p8.game.fruit, 0, 0):
                reward += 200

        # larger x, smaller y, positive x speed and negative y speed all
        # increase the reward
        reward += (x) * 0.5
        reward += (128 - y) * 1
        reward += 1 * spdx
        reward += (-1) * 2 * spdy

        # time penalty
        reward -= self.time_penalty

        # distance bonus
        dist = ((x - x_goal) ** 2 + (y - y_goal) ** 2) ** 0.5
        reward += (160 - dist) * 0.5

        # large bonus for reaching y <= 26 with x >= 90
        if y <= 26 and x >= 90:
            reward += 300

        # (changed) even larger bonus for y <= 10 with x >= 100
        if y <= 10 and x >= 100:
            reward += 500

        reward /= 256

        return reward

    def _check_termination(self):
        """Return True when there is no player object or the goal is reached."""
        p = self.p8.game.get_player()
        if not p:
            return True
        # goal check with vertical speed <= -0.5
        # (changed) 5-pixel tolerance around LEVEL_GOAL on both axes; the y
        # bound is goal_y + 5 so it stays correct for a non-zero goal y.
        goal_x, goal_y = LEVEL_GOAL
        if (p.x >= goal_x - 5 and p.y <= goal_y + 5 and p.spd.y <= -0.5):
            print("I'M ACTUALLY FINISHED")
            return True
        return False

    def render(self, mode="human"):
        """Print the game object."""
        print(self.p8.game)

    def save_best_episode(self):
        """Write ``best_episode`` to a timestamped text file, one value per line.

        The file goes to ``./out/lvl{LEVEL_ID}/``; the directory is created
        if it does not exist yet.
        """
        out_dir = f"./out/lvl{LEVEL_ID}/"
        os.makedirs(out_dir, exist_ok=True)
        # get current time
        filename = out_dir + time.strftime("best_episode_%Y%m%d_%H%M%S.txt")
        with open(filename, "w") as f:
            for a in self.best_episode:
                f.write(f"{a}\n")


In [58]:
# Build the environment and run the Stable-Baselines3 environment checker on
# it; the checker's return value is printed.
env = CelesteEnv()
check_result = check_env(env)
print(check_result)

None


In [59]:
# DQN hyperparameters, gathered in one place so they are easy to tune.
dqn_hyperparams = dict(
    learning_starts=1000,         # steps collected before training begins
    buffer_size=10_000,           # replay buffer capacity
    learning_rate=1e-3,
    batch_size=32,
    gamma=0.99,                   # discount factor
    exploration_fraction=0.9,     # share of training over which epsilon decays
    exploration_final_eps=0.05,   # final epsilon
    target_update_interval=500,   # steps between target-network updates
    verbose=1,
    seed=661,
)

model = DQN("MlpPolicy", env, **dqn_hyperparams)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [60]:
# Train for a fixed budget of environment steps, then write out the best
# action sequence recorded by the environment.
TOTAL_TIMESTEPS = 500_000
model.learn(total_timesteps=TOTAL_TIMESTEPS)
env.save_best_episode()

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 186      |
|    ep_rew_mean      | 58       |
|    exploration_rate | 0.998    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 25710    |
|    time_elapsed     | 0        |
|    total_timesteps  | 742      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 308      |
|    ep_rew_mean      | 86.5     |
|    exploration_rate | 0.995    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 5955     |
|    time_elapsed     | 0        |
|    total_timesteps  | 2464     |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.0134   |
|    n_updates        | 365      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean    

## TODO:
* deixar as recompensas gigantes um pouco mais autônomas (independentes do level)
* penalizar mais o tempo
* portar pra Ray RLlib pra ter action masking
* moranguinhos?