In [4]:
!pip install stable-baselines3 gymnasium numpy matplotlib --quiet


In [9]:
import numpy as np
import matplotlib.pyplot as plt

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

from stardew_mine_env import StardewMineEnv


In [None]:
def make_env(seed=0):
    """Build a zero-argument factory that creates a seeded ``StardewMineEnv``.

    Nothing is constructed when ``make_env`` itself is called; every call to
    the returned callable creates a new ``StardewMineEnv(seed=seed)``.

    Args:
        seed: Value forwarded to the ``StardewMineEnv`` constructor.

    Returns:
        A callable taking no arguments and returning a new environment.
    """
    return lambda: StardewMineEnv(seed=seed)

# Wrap one seeded environment factory in a DummyVecEnv, then echo the wrapper.
env = DummyVecEnv(env_fns=[make_env(seed=0)])

env


In [None]:
# PPO agent trained on the vectorized env built earlier in the notebook.
# NOTE(review): "MultiInputPolicy" presumably matches a Dict observation space
# in StardewMineEnv -- confirm against the environment definition.
model = PPO(
    policy="MultiInputPolicy",
    env=env,
    learning_rate=3e-4,
    n_steps=2048,       # rollout length collected per update
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    seed=0,             # same seed as make_env(0), so training runs are repeatable
    verbose=1,
    tensorboard_log="./ppo_logs/",
)
model


In [None]:
# Training budget in environment steps; raise to 200k or more for a longer run.
TIMESTEPS = 100_000

model.learn(
    total_timesteps=TIMESTEPS,
)
print("Training complete!")


In [None]:
# Persist the trained agent to disk under the base name "ppo_stardew".
model.save(path="ppo_stardew")
print("Model saved as ppo_stardew.zip")


In [None]:
# Restore the saved agent and attach it to the existing vectorized env.
model = PPO.load(path="ppo_stardew", env=env)
print("Model loaded!")


In [None]:
# Roll out one episode (capped at 200 steps) with the deterministic policy,
# rendering each step inline and accumulating the reward of env index 0.
obs = env.reset()

total_reward = 0.0
terminated = False
truncated = False

for step in range(200):
    action, _states = model.predict(obs, deterministic=True)
    # DummyVecEnv follows the SB3 VecEnv API: step() returns the 4-tuple
    # (obs, rewards, dones, infos), not Gymnasium's 5-tuple. Unpacking it as
    # (..., terminated, truncated) would bind the non-empty infos list to
    # `truncated` and end the loop after the very first step.
    obs, rewards, dones, infos = env.step(action)
    total_reward += float(rewards[0])    # rewards is vectorized (1-env)

    # Recover Gymnasium-style flags from the vectorized outputs.
    done = bool(dones[0])
    truncated = done and bool(infos[0].get("TimeLimit.truncated", False))
    terminated = done and not truncated

    # Render to notebook
    # NOTE: a VecEnv resets a finished sub-env automatically, so the frame
    # shown after the final step is that of the freshly reset environment.
    display(env.envs[0].render())

    if done:
        break

print("Episode finished with total reward:", total_reward)


In [None]:
def evaluate_agent(model, episodes=10, vec_env=None):
    """Play several episodes with a deterministic policy and collect the returns.

    Args:
        model: Object exposing ``predict(obs, deterministic=True)`` that
            returns an ``(action, state)`` pair.
        episodes: Number of episodes to run.
        vec_env: Vectorized environment to evaluate in. Defaults to the
            notebook-level ``env``. Only the sub-environment at index 0 is
            scored.

    Returns:
        A list with one total reward (float) per episode.
    """
    if vec_env is None:
        vec_env = env  # fall back to the notebook-level vectorized env

    rewards = []

    for _ in range(episodes):
        obs = vec_env.reset()
        total_reward = 0.0
        done = False

        while not done:
            action, _ = model.predict(obs, deterministic=True)
            # VecEnv.step returns (obs, rewards, dones, infos). Treating the
            # fourth item as a `truncated` flag makes the loop stop after one
            # step, because the infos list is always non-empty (truthy).
            obs, step_rewards, dones, _infos = vec_env.step(action)
            total_reward += float(step_rewards[0])
            done = bool(dones[0])

        rewards.append(total_reward)

    return rewards

# Score the loaded agent over ten evaluation episodes and summarise the returns.
scores = evaluate_agent(model=model, episodes=10)
print("Scores:", scores)
print("Mean score:", np.mean(scores))


In [None]:
# Histogram of the per-episode rewards stored in `scores`.
fig, ax = plt.subplots(figsize=(6, 4))
ax.hist(scores, bins=10)
ax.set_title("PPO Agent Rewards Across 10 Episodes")
ax.set_xlabel("Reward")
ax.set_ylabel("Count")
plt.show()
