In [1]:
%pip install torch torchvision stable-baselines3 ultralytics gym numpy python-dotenv


Collecting torch
  Downloading torch-2.9.1-cp310-cp310-win_amd64.whl.metadata (30 kB)
Collecting torchvision
  Downloading torchvision-0.24.1-cp310-cp310-win_amd64.whl.metadata (5.9 kB)
Collecting stable-baselines3
  Downloading stable_baselines3-2.7.0-py3-none-any.whl.metadata (4.8 kB)
Collecting ultralytics
  Downloading ultralytics-8.3.228-py3-none-any.whl.metadata (37 kB)
Collecting gym
  Downloading gym-0.26.2.tar.gz (721 kB)
     ---------------------------------------- 0.0/721.7 kB ? eta -:--:--
     ---------------------------------------- 0.0/721.7 kB ? eta -:--:--
     -------------- ------------------------- 262.1/721.7 kB ? eta -:--:--
     ---------------------------------------- 721.7/721.7 kB 1.3 MB/s  0:00:00
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  

In [None]:
import os
import warnings

from dotenv import load_dotenv
from stable_baselines3 import PPO
# NOTE(review): "envionments" looks like a misspelling of "environments". It is
# kept unchanged because it has to match the package directory name -- confirm.
from envionments.threshold_refinement import ThresholdRefinementEnv
from utility.dataset import load_pascal_voc2007, load_custom_dataset

# Load variables from a local .env file (if present) before reading the config.
load_dotenv()

# --- Configuration: every value can be overridden via an environment variable ---
VOC_ROOT = os.getenv("VOC_ROOT", "data/voc")
IMAGE_DIR = os.getenv("IMAGE_DIR")
LABEL_DIR = os.getenv("LABEL_DIR")
DATA_LIMIT = int(os.getenv("RL_DATA_LIMIT", "500"))
TOTAL_TIMESTEPS = int(os.getenv("RL_TOTAL_TIMESTEPS", "20000"))
LEARNING_RATE = float(os.getenv("RL_LEARNING_RATE", "3e-4"))
# Rollout length passed to PPO as `n_steps`; previously hard-coded to 512.
N_STEPS = int(os.getenv("RL_N_STEPS", "512"))
# Fixed seed so repeated runs of this notebook are comparable. Stable-Baselines3
# does not promise bit-exact results across platforms or devices.
SEED = int(os.getenv("RL_SEED", "42"))

# --- Dataset selection ---
# A custom dataset is used only when BOTH directories are configured;
# otherwise the cell falls back to the local Pascal VOC 2007 copy.
if IMAGE_DIR and LABEL_DIR:
    dataset = load_custom_dataset(IMAGE_DIR, LABEL_DIR, annotation_format="yolo_txt", limit=DATA_LIMIT)
else:
    if IMAGE_DIR or LABEL_DIR:
        # A half-configured custom dataset is almost certainly a mistake, so
        # make the fallback visible instead of silently training on VOC.
        warnings.warn(
            "Only one of IMAGE_DIR / LABEL_DIR is set; ignoring it and "
            "falling back to Pascal VOC 2007 at VOC_ROOT.",
            stacklevel=2,
        )
    dataset = load_pascal_voc2007(VOC_ROOT, image_set="trainval", limit=DATA_LIMIT, download=False)

# --- Training ---
env = ThresholdRefinementEnv(dataset)
model = PPO(
    "MlpPolicy",
    env,
    verbose=1,
    learning_rate=LEARNING_RATE,
    n_steps=N_STEPS,
    seed=SEED,
)
model.learn(total_timesteps=TOTAL_TIMESTEPS)

# Persist the trained policy; `model` also stays in memory for the cells below.
model.save("rl_threshold_tuner")


In [None]:
import os

from dotenv import load_dotenv
from utility.evaluation import (
    evaluate_policy,
    summarize_stats,
    plot_threshold_trajectories,
)

# Load .env again so the evaluation setting is read in this cell as well.
load_dotenv()
EVAL_EPISODES = int(os.environ.get("RL_EVAL_EPISODES", "5"))

# `model` and `dataset` are the objects created by the training cell above.
stats = evaluate_policy(
    model,
    dataset,
    episodes=EVAL_EPISODES,
)
summary = summarize_stats(stats)
print(summary)

# The trailing semicolon keeps the call's return value out of the cell output.
plot_threshold_trajectories(stats);
