In [29]:
import sys
import os
import numpy as np
import pandas as pd

sys.path.append("..")


from env.stardew_mine_env import StardewMineEnv

from evaluation.visualization_tools import (
    plot_reward_curve,
    plot_training_curves,
    plot_heatmap,
    plot_bar
)

from stable_baselines3 import PPO


In [30]:
from stable_baselines3 import PPO
model = PPO.load("../models/ppo_miningbot.zip")
print("Model loaded successfully!")


Model loaded successfully!


In [31]:
def evaluate_agent(agent, env, episodes=20, is_ppo=True):
    rewards = []
    ores = []
    floors = []
    energies = []
    heatmaps = []

    # infer grid size robustly (from env.grid, observation_space, or SIZE)
    if hasattr(env, "grid"):
        h, w = env.grid.shape
    else:
        os = getattr(env, 'observation_space', None)
        if os is not None and hasattr(os, 'spaces') and 'local_view' in os.spaces:
            h, w = os.spaces['local_view'].shape[:2]
        else:
            h = w = getattr(env, 'SIZE', getattr(env, 'size', 10))

    def _normalize_obs(obs, obs_space):
        # Ensure each dict observation field is a numpy array with the
        # shape expected by the observation space (no 0-d scalars).
        if not isinstance(obs, dict) or obs_space is None:
            return obs

        norm = {}
        spaces = getattr(obs_space, 'spaces', {})
        for k, v in obs.items():
            space = spaces.get(k)
            arr = np.asarray(v)
            # cast dtype when possible
            try:
                if space is not None and hasattr(space, 'dtype'):
                    arr = arr.astype(space.dtype)
            except Exception:
                pass

            # expand 0-d scalars into shape expected by the Box
            try:
                if arr.shape == () and space is not None:
                    arr = arr.reshape(space.shape)
                elif space is not None and arr.shape != space.shape and arr.size == np.prod(space.shape):
                    arr = arr.reshape(space.shape)
            except Exception:
                pass

            norm[k] = arr

        return norm

    for ep in range(episodes):
        # reset may return (obs, info) (Gymnasium) or obs (older Gym)
        reset_result = env.reset()
        if isinstance(reset_result, tuple):
            # common Gymnasium signature: (obs, info)
            obs = reset_result[0]
            info = reset_result[1] if len(reset_result) > 1 else {}
        else:
            obs = reset_result
            info = {}

        total_reward = 0
        visit_count = np.zeros((h, w), dtype=int)

        # track initial ore count so we can report collected ore
        try:
            initial_ore = int(np.sum(env.grid == getattr(env, 'ORE', 4)))
        except Exception:
            initial_ore = 0

        last_pos = info.get('agent_pos', None)

        done = False
        while not done:

            # Determine current player location (prefer info, then obs)
            if last_pos is not None:
                x, y = last_pos
            else:
                x = y = None
                if isinstance(obs, dict):
                    al = obs.get('agent_location', None)
                    if al is not None:
                        try:
                            x = int(round(float(al[0])))
                            y = int(round(float(al[1])))
                        except Exception:
                            x = y = None
                elif isinstance(obs, np.ndarray):
                    try:
                        y, x = np.unravel_index(obs.argmax(), obs.shape)
                        x = int(x); y = int(y)
                    except Exception:
                        x = y = None

            if x is not None and 0 <= x < w and 0 <= y < h:
                visit_count[y][x] += 1

            # Choose action. Normalize observations to match SB3 expectations
            obs_for_agent = _normalize_obs(obs, getattr(env, 'observation_space', None))
            if is_ppo:
                action, _ = agent.predict(obs_for_agent)
            else:
                # support legacy agents that expect the raw obs
                action = agent.act(obs_for_agent) if hasattr(agent, 'act') else agent.predict(obs_for_agent)[0]

            # Step env (support Gymnasium and Gym signatures)
            step_result = env.step(action)
            if isinstance(step_result, tuple):
                if len(step_result) == 5:
                    obs, reward, terminated, truncated, info = step_result
                    done = bool(terminated or truncated)
                elif len(step_result) == 4:
                    obs, reward, done, info = step_result
                else:
                    # unexpected tuple length; best-effort unpack
                    obs = step_result[0]
                    reward = float(step_result[1]) if len(step_result) > 1 else 0.0
                    done = False
                    info = step_result[-1] if len(step_result) > 2 else {}
            else:
                obs = step_result
                reward = 0.0
                done = False
                info = {}

            total_reward += float(reward)

            # Update position (preferred source)
            last_pos = info.get('agent_pos', None)

        # Episode finished: record stats
        rewards.append(total_reward)
        try:
            remaining_ore = int(np.sum(env.grid == getattr(env, 'ORE', 4)))
            ores.append(initial_ore - remaining_ore)
        except Exception:
            ores.append(info.get('ore_collected', 0))
        floors.append(getattr(env, 'floor', info.get('floor', 0)))
        energies.append(getattr(env, 'energy', info.get('energy', 0)))
        heatmaps.append(visit_count)

    return rewards, ores, floors, energies, heatmaps


In [32]:
env = StardewMineEnv(size=10)

ppo_rewards, ppo_ores, ppo_floors, ppo_energy, ppo_heatmaps = evaluate_agent(
    model, env, episodes=20, is_ppo=True
)


ValueError: Error: Unexpected observation shape () for Box environment, please use (1,) or (n_env, 1) for the observation shape.

In [None]:
results = pd.DataFrame({
    "agent": ["PPO"]*20,
    "reward": ppo_rewards,
    "ore": ppo_ores,
    "floor": ppo_floors,
    "energy": ppo_energy
})

results.to_csv("../results/stats.csv", index=False)
results.head()


In [None]:
plot_bar(
    x=["PPO"],
    heights=[np.mean(ppo_rewards)],
    xlabel="Algorithm",
    ylabel="Average Reward",
    title="PPO vs DQN Average Episode Reward"
)


In [None]:
plot_heatmap(ppo_heatmaps[0], "PPO Heatmap (Episode 1)")
