In [3]:
# train_agent.py
import numpy as np
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from cat_env import CATEnv, generate_item_bank

def make_env():
    """Build and return a fresh ``CATEnv`` instance.

    The item bank holds 100 items and is generated with a fixed seed (123),
    so every call produces an environment backed by the same bank. The
    simulated agent's true ability is pinned at +1.5 for demonstration
    purposes, and each episode is configured with ``min_items=10`` and
    ``max_items=30``.
    """
    bank = generate_item_bank(num_items=100, seed=123)
    return CATEnv(
        item_bank_df=bank,
        agent_true_ability=1.5,
        min_items=10,
        max_items=30,
    )

def main():
    """Train a PPO agent on the CAT environment, then run one test episode.

    Training uses a single-environment ``DummyVecEnv`` built from
    ``make_env``. After ``model.learn`` finishes, the policy is rolled out
    deterministically on a separate, un-vectorized environment and the
    accumulated reward plus the first component of the final observation
    are printed.

    The rollout tolerates both environment API generations:

    * ``reset()`` returning either ``obs`` or ``(obs, info)``;
    * ``step()`` returning either ``(obs, reward, done, info)`` or
      ``(obs, reward, terminated, truncated, info)``.
    """
    # Plain (non-vectorized) env reserved for the post-training rollout.
    env = make_env()
    # Wrap environment in a VecEnv for Stable-Baselines
    vec_env = DummyVecEnv([make_env])

    # Create model
    model = PPO(
        policy="MlpPolicy",
        env=vec_env,
        verbose=1,
        n_steps=64,
        batch_size=32,
        learning_rate=1e-3,
        gamma=0.99,
        ent_coef=0.0,
    )

    # Train
    model.learn(total_timesteps=5000)

    # Test the learned policy
    reset_result = env.reset()
    # Gymnasium-style reset yields (obs, info); legacy Gym yields obs alone.
    obs = reset_result[0] if isinstance(reset_result, tuple) else reset_result
    total_reward = 0
    done = False
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        step_result = env.step(action)
        if len(step_result) == 5:
            # Gymnasium-style step: the episode ends on either flag.
            obs, reward, terminated, truncated, _info = step_result
            done = bool(terminated or truncated)
        else:
            # Legacy Gym-style step with a single done flag.
            obs, reward, done, _info = step_result
        total_reward += reward
    print(f"Test completed with total reward={total_reward:.3f} and final theta={obs[0]:.3f}")



ModuleNotFoundError: No module named 'stable_baselines3'