In [1]:
from pathlib import Path
from RMZ3Env import make_RMZ3Env
from gymnasium.wrappers import NormalizeObservation
from env_setup import *

PATH = Path("test_result")
PATH.mkdir(parents=True, exist_ok=True)

env_config = {
        "gba_rom": GBA_ROM,
        "gba_sav": GBA_SAV,
        "max_run_time": MAX_TIME,
        "include_lives_count": INCLUDE_LIVES,
        "render_mode": "rgb_array",
        "frameskip": FRAMESKIP,
        "mgba_silence": SILENCE,
        "to_resize": RESIZE,
        "scrn_w": SCRN_W,
        "scrn_h": SCRN_H,
        "to_grayscale": GARYSCALE,
        "record": True,
        "record_path": PATH.joinpath("videos/")}

env = make_RMZ3Env(**env_config)
env = NormalizeObservation(env)

  logger.warn(


In [2]:
import lzma, pickle
from CNNLSTM import CNNLSTM
from evotorch.neuroevolution.net.vecrl import Policy

model = CNNLSTM(env.observation_space.shape, env.action_space.n, HIDDEN_SIZE, NUM_LAYERS)

with lzma.open("EA/PGPE_CNNLSTM/searcher_status_2025-03-27 16_06_02.186381.xz", "rb") as f:
    status = pickle.load(f)

In [3]:
import numpy as np
import torch
import pandas as pd
test_solutions = ["center", "pop_best", "best"]
policy = Policy(model)
results = {}
for solution in test_solutions:
    print("Solution from PGPE model: ", solution)
    if "PGPE_" + solution not in results:
        results["PGPE_" + solution] = pd.DataFrame(columns=["Rewards", "Best stages - checkpoint", "Total play time"])
    policy.set_parameters(status[solution])
    for i in range(20):
        env.set_wrapper_attr("name_prefix", f"PGPE_{solution}_{i}")
        obs, info = env.reset()
        while True:
            action = np.argmax(policy(torch.as_tensor(obs, dtype=torch.float32, device="cpu")))
            obs, reward, terminated, truncated, info = env.step(action)
            env.render()
            if terminated or truncated:
                rewards = info["total_rewards"]
                curr_stage = info["current_stage"]
                curr_checkpoint = info["current_checkpoint"]
                total_play_time = info["total_play_time"]
                results["PGPE_" + solution].loc[len(results["PGPE_" + solution])] = [rewards, (curr_stage, curr_checkpoint), total_play_time]
                env.close()
                policy.reset()
                break

Solution from PGPE model:  center
Solution from PGPE model:  pop_best
Solution from PGPE model:  best


In [4]:
from sb3_contrib import RecurrentPPO
rppo = RecurrentPPO.load("RL/RecurrentPPO/RPPO_model_2025-03-26 13_11_00.498926")
results["RPPO"] = pd.DataFrame(columns=["Rewards", "Best stages - checkpoint", "Total play time"])
for i in range(20):
    env.set_wrapper_attr("name_prefix", f"RPPO_{solution}_{i}")
    obs, info = env.reset()
    lstm_states = None
    # Episode start signals are used to reset the lstm states
    episode_start = None
    while True:
        action, lstm_states = rppo.predict(obs, state=lstm_states, episode_start=episode_start, deterministic = True)
        obs, reward, terminated, truncated, info = env.step(action)
        episode_starts = truncated | terminated
        env.render()
        if terminated or truncated:
            rewards = info["total_rewards"]
            curr_stage = info["current_stage"]
            curr_checkpoint = info["current_checkpoint"]
            total_play_time = info["total_play_time"]
            results["RPPO"].loc[len(results["RPPO"])] = [rewards, (curr_stage, curr_checkpoint), total_play_time]
            env.close()
            policy.reset()
            break


  return F.linear(input, self.weight, self.bias)


In [5]:
results

{'PGPE_center':        Rewards Best stages - checkpoint  Total play time
 0   661.040165                   (1, 3)        50.000000
 1   660.058805                   (1, 3)        50.066667
 2   665.002116                   (1, 3)        50.050000
 3   663.037248                   (1, 3)        50.066667
 4   662.022812                   (1, 3)        50.000000
 5   658.035317                   (1, 3)        50.050000
 6   657.106763                   (1, 3)        50.000000
 7   660.038262                   (1, 3)        50.050000
 8   661.042533                   (1, 3)        50.016667
 9   659.029757                   (1, 3)        50.050000
 10  660.000868                   (1, 3)        50.016667
 11  660.022491                   (1, 3)        50.066667
 12  665.971625                   (1, 3)        50.033333
 13  656.064090                   (1, 3)        50.066667
 14  657.065935                   (1, 3)        50.033333
 15  661.046248                   (1, 3)        50.033333

In [6]:
for sol, v in results.items():
    print(sol, ":")
    print("Mean reward: ", v["Rewards"].mean())
    print("Best stage-checkpoint: ", v["Best stages - checkpoint"].max())
    print("Mean play time: ", v["Total play time"].mean())




PGPE_center :
Mean reward:  660.337377406842
Best stage-checkpoint:  (1, 3)
Mean play time:  50.035833333333876
PGPE_pop_best :
Mean reward:  658.7081504273775
Best stage-checkpoint:  (1, 3)
Mean play time:  50.02333333333387
PGPE_best :
Mean reward:  661.1148312422805
Best stage-checkpoint:  (1, 3)
Mean play time:  50.02750000000054
RPPO :
Mean reward:  643.876301545062
Best stage-checkpoint:  (1, 3)
Mean play time:  50.02916666666721
