In [None]:
import sys
import os


# Make the parent of the working directory importable so that project packages
# (e.g. the `env` package imported further down) can be found.
# NOTE(review): assumes the notebook is launched from a sub-folder of the project root.
PROJECT_ROOT = os.path.abspath("..")
# Guard keeps the cell idempotent: re-running it must not stack duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)


In [None]:
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor

import matplotlib.pyplot as plt
import os
from datetime import datetime

# Import your environment
from env.stardew_mine_env import StardewMineEnv

In [None]:
def make_env(monitor_file="ppo_miningbot.monitor.csv"):
    """Build one monitored ``StardewMineEnv`` instance for training.

    Args:
        monitor_file: Path passed to ``Monitor`` as ``filename``. When no
            filename is given, ``Monitor`` only keeps episode statistics in
            memory, so the ``*.monitor.csv`` file that the reward-curve cells
            look for in the working directory would never be written. Note
            that ``Monitor`` overwrites an existing file of the same name.

    Returns:
        The environment wrapped in ``Monitor``.
    """
    env = StardewMineEnv(size=10, max_floor=10, max_energy=100, local_view_size=5)
    env = Monitor(env, filename=monitor_file)  # Logs episode rewards to disk
    return env


# Single-environment vectorized wrapper around the factory above.
env = DummyVecEnv([make_env])


In [None]:
import os
from datetime import datetime


# File name of this run's checkpoint, tagged with the current local time.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_filename = "ppo_miningbot_" + timestamp + ".zip"

# Checkpoints go into a "models" folder that sits next to the working directory.
base_dir = os.getcwd()
parent_dir = os.path.split(base_dir)[0]
models_dir = os.path.join(parent_dir, "models")
os.makedirs(models_dir, exist_ok=True)

save_path = os.path.join(models_dir, model_filename)
print("Full save path for model:", save_path)



In [None]:
# `run_name` has to exist before it is interpolated into the TensorBoard path
# below (otherwise this cell raises NameError on a fresh kernel). Reusing the
# checkpoint file name without ".zip" lets logs and saved model be matched up.
run_name = os.path.splitext(model_filename)[0]

model = PPO(
    policy="MultiInputPolicy",  # presumably the env exposes a Dict observation space; confirm
    env=env,
    verbose=1,
    learning_rate=3e-4,
    batch_size=64,   # minibatch size for each gradient step
    n_steps=2048,    # environment steps collected per rollout before an update
    gamma=0.99,      # discount factor
    tensorboard_log=f"logs/{run_name}",
)



In [None]:
TIMESTEPS = 200_000  # adjust for your experiment size

model.learn(total_timesteps=TIMESTEPS)
model.save(f"{save_path}/ppo_miningbot")
print("Training complete!")


In [None]:
model = PPO.load(f"{save_path}/ppo_miningbot")

test_env = StardewMineEnv()
obs, _ = test_env.reset()

total_reward = 0
done = False

while not done:
    action, _ = model.predict(obs)
    obs, reward, done, truncated, info = test_env.step(action)
    total_reward += reward

print("Test episode reward:", total_reward)

In [None]:
import pandas as pd

# Collect the Monitor logs found in the working directory. os.listdir returns
# entries in arbitrary order, so sort them to make `monitor_file[0]` (used by
# the plotting cell) deterministic across runs and platforms.
monitor_file = sorted(f for f in os.listdir(".") if f.endswith(".monitor.csv"))
monitor_file

In [None]:
# Fail with an actionable message instead of a bare IndexError when no log exists.
if not monitor_file:
    raise FileNotFoundError(
        "No '*.monitor.csv' file found in the working directory; "
        "pass a filename to Monitor so episode statistics are written to disk."
    )

# The first line of a Monitor CSV is a '#'-prefixed JSON header, hence skiprows=1.
df = pd.read_csv(monitor_file[0], skiprows=1)

# Explicit figure/axes instead of the implicit pyplot state machine.
fig, ax = plt.subplots()
ax.plot(df["r"])  # r = reward
ax.set_xlabel("Episode")
ax.set_ylabel("Reward")
ax.set_title("PPO Training Reward Curve")
plt.show()
