In [None]:
# Paste your entire code here
# Then press SHIFT+ENTER
# cat_env.py
import math
import json
import random
import numpy as np
import pandas as pd
from gymnasium import Env, spaces
from scipy.stats import norm

# ----------------------------
# Some IRT / EAP Helper Functions
# ----------------------------

def irt_probability(theta, a, b, c):
    """
    3PL item response function:
      P(correct) = c + (1 - c) / [1 + exp(-1.7*a*(theta - b))]

    Args:
        theta: Ability value (scalar).
        a: Item discrimination.
        b: Item difficulty.
        c: Item guessing parameter (lower asymptote of the curve).

    Returns:
        Probability of a correct response.

    The logistic term is evaluated in a numerically stable way: math.exp
    raises OverflowError for large positive arguments, which the textbook
    formula hits when theta lies far below b.
    """
    z = 1.7 * a * (theta - b)
    if z >= 0:
        # exp(-z) <= 1 here, so the textbook form cannot overflow.
        return c + (1 - c) / (1 + math.exp(-z))
    # For negative z use the algebraically equal form exp(z) / (1 + exp(z));
    # its exponent is negative, so it can only underflow harmlessly to 0.
    ez = math.exp(z)
    return c + (1 - c) * (ez / (1 + ez))

def eap_estimate(responses, a_vals, b_vals, c_vals, nqt=31, prior_mean=0, prior_sd=1):
    """
    Grid-based EAP (Expected A Posteriori) for theta estimation.

    Args:
        responses: Item scores (1 = correct, 0 = incorrect), one per item.
        a_vals: Item discriminations, aligned with ``responses``.
        b_vals: Item difficulties, aligned with ``responses``.
        c_vals: Item guessing parameters, aligned with ``responses``.
        nqt: Number of equally spaced quadrature points on [-4, 4].
        prior_mean: Mean of the normal prior on theta.
        prior_sd: Standard deviation of the normal prior on theta.

    Returns:
        The posterior mean of theta over the quadrature grid. If the
        posterior carries no mass on the grid (e.g. the prior underflows
        everywhere), the prior mean clipped to the grid range is returned.
    """
    grid_lo, grid_hi = -4, 4
    thetas = np.linspace(grid_lo, grid_hi, nqt)
    prior = norm.pdf(thetas, loc=prior_mean, scale=prior_sd)

    likelihood = np.ones(nqt)
    for (r, a, b, c) in zip(responses, a_vals, b_vals, c_vals):
        # Probability of correct at each grid point
        p_grid = c + (1 - c) / (1 + np.exp(-1.7 * a * (thetas - b)))
        likelihood *= p_grid**r * (1 - p_grid)**(1 - r)

        # Rescale so the largest grid value is 1. Only the *shape* of the
        # likelihood matters for the posterior mean, and without rescaling
        # the running product shrinks geometrically with the item count
        # until it is lost against any absolute epsilon (or underflows).
        peak = likelihood.max(initial=0.0)
        if peak > 0:
            likelihood /= peak

    posterior = likelihood * prior
    total = posterior.sum()
    if not total > 0:
        # No usable posterior mass (all-zero or NaN): fall back to the prior.
        return np.float64(np.clip(prior_mean, grid_lo, grid_hi))

    # Normalise by the actual mass; an additive epsilon here would bias the
    # estimate toward zero whenever the mass is small.
    posterior /= total
    return np.sum(thetas * posterior)

def fisher_information(a, b, c, theta):
    """
    Item information of a 3PL item evaluated at ``theta``.

    With L = 1.7 * a * (theta - b) the value returned is
      2.89 * a^2 * (1 - c) / ((c + e^L) * (1 + e^-L)^2)
    where 2.89 is the squared scaling constant 1.7^2.
    """
    logit = 1.7 * a * (theta - b)
    scale_sq = 2.89  # 1.7 ** 2
    numerator = scale_sq * (a**2) * (1 - c)
    rising_term = c + np.exp(logit)
    falling_term_sq = (1 + np.exp(-logit)) ** 2
    return numerator / (rising_term * falling_term_sq)

def standard_error(a_vals, b_vals, c_vals, theta):
    """
    Standard error of a theta estimate: 1 / sqrt(total test information).

    The test information is the sum of ``fisher_information`` over the items
    described by the aligned sequences ``a_vals``, ``b_vals`` and ``c_vals``.
    A tiny constant keeps the division finite when no items are supplied.
    """
    per_item_info = [
        fisher_information(disc, diff, guess, theta)
        for disc, diff, guess in zip(a_vals, b_vals, c_vals)
    ]

    # Accumulate left to right with plain additions so the floating-point
    # result does not depend on any summation algorithm.
    accumulated = 0.0
    for info in per_item_info:
        accumulated = accumulated + info

    return 1.0 / np.sqrt(accumulated + 1e-12)


# ----------------------------
# Generate an Item Bank (similar to your code)
# ----------------------------
def generate_item_bank(num_items=100, seed=42):
    """
    Build a reproducible random item bank as a pandas DataFrame.

    Both the ``random`` module and NumPy's global generator are seeded with
    ``seed``. Each row holds an ``item_id`` (1-based), rounded ``difficulty``,
    ``discrimination``, ``guessing`` and ``slip`` values, and a JSON string
    mapping one to three learning-objective keys to 0 or 1.
    """
    random.seed(seed)
    np.random.seed(seed)

    # The order of these draws determines the generated values; keep it.
    b_draws = np.clip(np.random.normal(0, 1, num_items), -3, 3)
    a_draws = np.clip(np.random.lognormal(0, 0.5, num_items), 0, 3)
    c_draws = np.random.uniform(0.05, 0.3, num_items)
    s_draws = np.random.uniform(0.01, 0.3, num_items)

    objective_pool = ["lo1", "lo2", "lo3", "lo4"]

    def _three_dp(value):
        # Store plain Python floats rounded to three decimals.
        return round(float(value), 3)

    records = []
    parameter_rows = zip(b_draws, a_draws, c_draws, s_draws)
    for position, (b_val, a_val, c_val, s_val) in enumerate(parameter_rows):
        # Pick 1-3 objectives, then flag each of them with a random 0/1.
        how_many = random.randint(1, 3)
        picked = random.sample(objective_pool, how_many)
        flags = {}
        for objective in picked:
            flags[objective] = random.choice([0, 1])

        record = {
            "item_id": position + 1,
            "difficulty": _three_dp(b_val),
            "discrimination": _three_dp(a_val),
            "guessing": _three_dp(c_val),
            "slip": _three_dp(s_val),
            "learning_objectives": json.dumps(flags),
        }
        records.append(record)

    return pd.DataFrame(records)


# ----------------------------
# CATEnv: Our custom RL environment
# ----------------------------
class CATEnv(Env):
    """
    A Gymnasium-style environment for a simple CAT (computerized adaptive test).

    Action space: Discrete(num_items) -> select an item index from [0..num_items-1].
    Observation: [current_theta_estimate] (1D float32 array).

    An episode ends when either
      * min_items <= #administered and
        (SE_theta <= se_threshold or #administered >= max_items), or
      * every item in the bank has been administered.

    Reward: (change in the theta estimate) + 0.5 if the response was correct.

    NOTE: step() returns the 4-tuple (obs, reward, done, info), not the
    5-tuple (obs, reward, terminated, truncated, info) that Gymnasium's
    Env.step specifies. The 4-tuple is kept so existing callers keep working;
    code that expects the Gymnasium signature has to adapt the result.
    """

    def __init__(self,
                 item_bank_df: pd.DataFrame,
                 agent_true_ability: float = 0.0,
                 min_items: int = 10,
                 max_items: int = 30,
                 se_threshold: float = 0.3):
        """
        Args:
            item_bank_df: Item bank with at least the columns
                "discrimination", "difficulty" and "guessing". It is copied
                and re-indexed so actions map to positional rows.
            agent_true_ability: True theta used to simulate responses.
            min_items: Minimum number of items before the stopping rule applies.
            max_items: Number of administered items at which the test stops.
            se_threshold: Standard-error level at or below which the test
                stops (once ``min_items`` have been administered).
        """
        super().__init__()

        self.item_bank_df = item_bank_df.copy().reset_index(drop=True)
        self.num_items = len(self.item_bank_df)

        # The agent's *true* ability (used to simulate correct/incorrect).
        self.agent_true_ability = agent_true_ability

        # Stopping-rule configuration
        self.min_items = min_items
        self.max_items = max_items
        self.se_threshold = se_threshold

        # Define action/obs spaces
        self.action_space = spaces.Discrete(self.num_items)  # choose item index
        self.observation_space = spaces.Box(
            low=-4, high=4, shape=(1,), dtype=np.float32
        )

        self.reset()

    def reset(self, seed=None, options=None):
        """
        Start a new test session.

        Returns:
            (observation, info): the initial theta estimate as a float32
            array of shape (1,), and an empty info dict.
        """
        # Gymnasium expects subclasses to forward the seed so that the base
        # class can initialise its own random generator.
        super().reset(seed=seed)
        if seed is not None:
            # Responses and fallback item choices are drawn from the global
            # `random` / `numpy.random` generators, so seed those as well.
            self.seed(seed)

        self.administered_indices = []
        self.responses = []
        self.a_vals = []
        self.b_vals = []
        self.c_vals = []

        self.theta_est = 0.0
        self.steps = 0

        # Return initial observation
        return np.array([self.theta_est], dtype=np.float32), {}

    def seed(self, seed):
        """Seed the global `random` and `numpy.random` generators."""
        random.seed(seed)
        np.random.seed(seed)

    def step(self, action):
        """
        Administer one item and update the ability estimate.

        Args:
            action: integer in [0..num_items-1], the row position in
                item_bank_df. If that item was already administered, a
                random unused item is administered instead.

        Returns:
            (obs, reward, done, info). ``info`` holds "p_correct", "se" and
            "score" for an administered item, and is empty when the call
            could not administer anything because the bank is used up.
        """
        if action in self.administered_indices:
            # Re-selected item: fall back to a random unused one.
            valid_actions = [i for i in range(self.num_items) if i not in self.administered_indices]
            if not valid_actions:
                # No more items
                return np.array([self.theta_est], dtype=np.float32), 0.0, True, {}
            action = random.choice(valid_actions)

        row = self.item_bank_df.iloc[action]
        a = float(row["discrimination"])
        b = float(row["difficulty"])
        c = float(row["guessing"])

        # Simulate correct/incorrect with the agent's *true* ability:
        p_correct = irt_probability(self.agent_true_ability, a, b, c)
        r = 1 if np.random.rand() < p_correct else 0  # 0 or 1

        self.administered_indices.append(action)
        self.responses.append(r)
        self.a_vals.append(a)
        self.b_vals.append(b)
        self.c_vals.append(c)

        # EAP update
        old_theta = self.theta_est
        self.theta_est = eap_estimate(self.responses, self.a_vals, self.b_vals, self.c_vals)
        self.steps += 1

        # Stopping rule check
        se = standard_error(self.a_vals, self.b_vals, self.c_vals, self.theta_est)
        done = False
        if self.steps >= self.min_items and (se <= self.se_threshold or self.steps >= self.max_items):
            done = True
        if len(self.administered_indices) >= self.num_items:
            # The bank is used up: end the episode now instead of requiring
            # one more step() call that could not administer anything.
            done = True

        # Reward: difference in theta + (0.5 if correct, else 0)
        reward = (self.theta_est - old_theta) + (0.5 if r == 1 else 0)

        obs = np.array([self.theta_est], dtype=np.float32)
        info = {"p_correct": p_correct, "se": se, "score": r}
        return obs, float(reward), done, info