In [None]:
import sys
import os


# Put the parent of the current working directory on the import path
# (presumably so the project-local `env` package can be imported below).
# The membership check keeps the cell idempotent: re-running it must not
# stack duplicate entries onto sys.path.
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
# Standard library
import itertools

# Third-party
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Stable Baselines 3 PPO
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv

# Your custom environment
from env.stardew_mine_env import StardewMineEnv

In [None]:
def train_ppo(env, learning_rate=3e-4, n_steps=2048, batch_size=64, gamma=0.99, total_timesteps=10000, seed=None, n_eval_episodes=10):
    """Train a PPO agent on ``env`` and return its evaluation reward.

    The environment is wrapped in a single-env ``DummyVecEnv``, a PPO model
    with the ``"MultiInputPolicy"`` policy is trained for ``total_timesteps``
    steps, and the trained model is then evaluated deterministically on that
    same wrapped environment.

    Parameters
    ----------
    env : environment instance
        Already-constructed environment to train and evaluate on. It is not
        closed here; the caller keeps ownership of it.
    learning_rate, n_steps, batch_size, gamma :
        Forwarded unchanged to the ``PPO`` constructor.
    total_timesteps : int
        Forwarded to ``model.learn``.
    seed : int or None
        Forwarded to the ``PPO`` constructor; ``None`` leaves the run unseeded.
    n_eval_episodes : int
        Number of episodes passed to ``evaluate_policy`` (default 10, the
        value that used to be hard-coded).

    Returns
    -------
    tuple
        ``(mean_reward, std_reward)`` as returned by ``evaluate_policy``.
    """
    # Bind the raw env through a default argument and keep the wrapper under
    # its own name, so the factory does not depend on a variable that is
    # rebound on the very line that creates it.
    vec_env = DummyVecEnv([lambda raw_env=env: raw_env])

    model = PPO(
        "MultiInputPolicy",
        vec_env,
        learning_rate=learning_rate,
        n_steps=n_steps,
        batch_size=batch_size,
        gamma=gamma,
        verbose=0,
        seed=seed,
    )

    model.learn(total_timesteps=total_timesteps)

    # NOTE(review): evaluation reuses the training env instance; there is no
    # separate evaluation environment.
    mean_reward, std_reward = evaluate_policy(
        model,
        vec_env,
        n_eval_episodes=n_eval_episodes,
        deterministic=True,
    )

    return mean_reward, std_reward


In [None]:
# Candidate values for each PPO hyperparameter explored by the grid search.
# Three values per axis across four axes gives 3 ** 4 = 81 combinations.
hyperparams = dict(
    learning_rate=[1e-4, 3e-4, 1e-3],
    n_steps=[512, 2048, 4096],
    batch_size=[32, 64, 128],
    gamma=[0.95, 0.99, 0.999],
)

In [None]:
# Every configuration is trained with the same seed so that differences in
# reward come from the hyperparameters rather than from random initialisation,
# and so that a fresh "Restart & Run All" reproduces the same table.
GRID_SEARCH_SEED = 42

results = []

# Cartesian product in the same order as four nested loops would produce:
# learning_rate outermost, gamma innermost.
search_space = itertools.product(
    hyperparams['learning_rate'],
    hyperparams['n_steps'],
    hyperparams['batch_size'],
    hyperparams['gamma'],
)

for lr, n, batch, g in search_space:
    # A fresh environment per configuration, released as soon as the run
    # finishes (or fails) so instances do not accumulate across the sweep.
    env = StardewMineEnv()
    try:
        mean_reward, std_reward = train_ppo(
            env,
            learning_rate=lr,
            n_steps=n,
            batch_size=batch,
            gamma=g,
            seed=GRID_SEARCH_SEED,
        )
    finally:
        env.close()

    results.append({
        "learning_rate": lr,
        "n_steps": n,
        "batch_size": batch,
        "gamma": g,
        "mean_reward": mean_reward,
        "std_reward": std_reward
    })
    print(f"Done: lr={lr}, n_steps={n}, batch={batch}, gamma={g}, reward={mean_reward:.2f}")


In [None]:
# One row per hyperparameter configuration; persist the table so the sweep
# does not have to be re-run just to re-plot.
df = pd.DataFrame(results)
df.to_csv("ppo_hyperparam_results.csv", index=False)

# Example plot: mean reward against learning rate. Each learning rate appears
# in many rows (one per combination of the other hyperparameters), which
# seaborn's lineplot aggregates into a single line.
fig, ax = plt.subplots()
sns.lineplot(data=df, x="learning_rate", y="mean_reward", marker="o", ax=ax)
# The learning rates span an order of magnitude, so a log axis spaces them
# evenly instead of crowding the small values together.
ax.set_xscale("log")
ax.set_xlabel("Learning rate")
ax.set_ylabel("Mean reward")
ax.set_title("Mean Reward vs Learning Rate")
plt.show()