# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

def generate_task(n_states, b, seed=None, n_actions=2):
    """Generate a random episodic task.

    For every (state, action) pair the task stores:

    - a reward for the terminating transition;
    - ``b`` distinct successor states drawn uniformly from ``range(n_states)``;
    - one reward per successor.

    All rewards are standard-normal draws.

    Args:
        n_states: Number of states in the task.
        b: Branching factor, i.e. the number of distinct successors per
            (state, action). Must satisfy ``1 <= b <= n_states``.
        seed: If not None, NumPy's *global* RNG is reseeded with it. Any later
            ``np.random`` calls in the program therefore also become
            reproducible.
        n_actions: Number of actions per state (default 2).

    Returns:
        ``(r_term, next_states, rewards)``:

        - ``r_term`` has shape ``(n_states, n_actions)``;
        - ``next_states`` has shape ``(n_states, n_actions, b)`` and dtype int32;
        - ``rewards`` has shape ``(n_states, n_actions, b)``.

    Raises:
        ValueError: If ``b`` is outside ``[1, n_states]``.
    """
    if not 1 <= b <= n_states:
        raise ValueError(
            f"branching factor b={b} must be in [1, n_states={n_states}]"
        )
    if seed is not None:
        np.random.seed(seed)
    r_term = np.random.normal(0, 1, (n_states, n_actions))

    # Legacy np.random.choice(..., replace=False) shuffles the whole
    # population on every call, which is O(n_states**2) work for the whole
    # task. Instead, draw every successor with replacement in one vectorised
    # call. Then redo only the rows that happen to contain a duplicate. For
    # b << n_states those rows are rare. Redoing them with choice() keeps
    # each row a uniformly random selection of b distinct states.
    next_states = np.random.randint(
        0, n_states, size=(n_states, n_actions, b), dtype=np.int32
    )
    sorted_ns = np.sort(next_states, axis=-1)
    has_dup = np.any(sorted_ns[..., 1:] == sorted_ns[..., :-1], axis=-1)
    for s, a in zip(*np.nonzero(has_dup)):
        next_states[s, a] = np.random.choice(n_states, size=b, replace=False)

    rewards = np.random.normal(0, 1, (n_states, n_actions, b))
    return r_term, next_states, rewards

def evaluate_policy(policy, r_term, next_states, rewards, b, n_states, max_iter=1000, tol=1e-4):
    """Evaluate a deterministic policy using iterative policy evaluation.

    Each sweep applies, synchronously for every state s with a = policy[s]::

        V(s) = 0.1 * r_term[s, a]
               + 0.9 / b * sum_k (rewards[s, a, k] + V(next_states[s, a, k]))

    Sweeps stop once the largest per-state change is below ``tol``, or after
    ``max_iter`` sweeps.

    Args:
        policy: Action index for each state (indexable by 0..n_states-1).
        r_term: Terminal-transition reward per (state, action).
        next_states: Integer successor indices per (state, action), ``b`` each.
        rewards: Reward for each successor transition, same layout as
            ``next_states``.
        b: Branching factor (divisor of the successor term).
        n_states: Number of states.
        max_iter: Maximum number of sweeps. If ``<= 0``, zeros are returned.
        tol: Convergence threshold on the max absolute change per sweep.

    Returns:
        np.ndarray of shape ``(n_states,)`` with the value estimates.
    """
    states = np.arange(n_states)
    actions = np.asarray(policy)[states]
    # Everything that depends only on the fixed policy is hoisted out of the
    # sweep loop; each sweep is then a single vectorised gather + sum.
    term_part = 0.1 * r_term[states, actions]
    successors = next_states[states, actions]          # one row of b indices per state
    reward_sum = rewards[states, actions].sum(axis=1)
    scale = 0.9 / b

    V = np.zeros(n_states)
    for _ in range(max_iter):
        V_new = term_part + scale * (reward_sum + V[successors].sum(axis=1))
        # initial=0.0 makes an empty state set count as converged instead of
        # raising on np.max of an empty array.
        converged = np.max(np.abs(V_new - V), initial=0.0) < tol
        V = V_new
        if converged:
            break
    return V

def run_uniform(r_term, next_states, rewards, b, n_states, max_updates, eval_interval):
    """Apply expected updates to state-action pairs in a fixed cyclic order.

    Update number ``k`` (0-based) targets the pair with flat index
    ``k % (2 * n_states)``. Pairs are ordered state by state, and within a
    state action 0 comes before action 1. Each update performs::

        Q[s, a] = 0.1 * r_term[s, a]
                  + 0.9 / b * (sum(rewards[s, a]) + sum_k max_a' Q[next_states[s, a, k], a'])

    After every ``eval_interval`` updates, the greedy policy of Q is
    evaluated with ``evaluate_policy``. The value of state 0 is recorded.

    Returns:
        List of recorded start-state values, one per evaluation point.
    """
    Q = np.zeros((n_states, 2))
    n_pairs = 2 * n_states
    start_values = []
    for step in range(max_updates):
        # Flat pair index -> (state, action), matching the state-major order.
        state, action = divmod(step % n_pairs, 2)
        successors = next_states[state, action]
        backup = np.sum(rewards[state, action]) + np.sum(Q[successors].max(axis=1))
        Q[state, action] = 0.1 * r_term[state, action] + 0.9 / b * backup

        if (step + 1) % eval_interval:
            continue
        greedy = Q.argmax(axis=1)
        V = evaluate_policy(greedy, r_term, next_states, rewards, b, n_states)
        start_values.append(V[0])
    return start_values

def run_onpolicy(r_term, next_states, rewards, b, n_states, max_updates, eval_interval, start_state=0):
    """Apply expected updates along simulated epsilon-greedy episodes.

    Every episode starts in ``start_state``. At each visited state:

    1. An action is chosen: uniformly at random with probability 0.1,
       otherwise greedily from Q.
    2. That (state, action) pair receives the same expected update used by
       ``run_uniform``.
    3. The episode ends with probability 0.1. Otherwise it moves to a
       uniformly chosen successor.

    Episodes repeat until ``max_updates`` updates have been made. Draws come
    from NumPy's global RNG.

    After every ``eval_interval`` updates, the greedy policy of Q is
    evaluated and the value of state 0 is recorded.

    Returns:
        List of recorded start-state values, one per evaluation point.
    """
    num_actions = 2
    Q = np.zeros((n_states, num_actions))
    start_values = []
    n_done = 0
    while n_done < max_updates:
        s = start_state
        while n_done < max_updates:
            explore = np.random.rand() < 0.1
            a = np.random.randint(0, num_actions) if explore else np.argmax(Q[s])

            successors = next_states[s, a]
            backup = np.sum(rewards[s, a]) + np.sum(Q[successors].max(axis=1))
            Q[s, a] = 0.1 * r_term[s, a] + 0.9 / b * backup

            episode_over = np.random.rand() < 0.1
            if not episode_over:
                s = np.random.choice(successors)
            n_done += 1

            if n_done % eval_interval == 0:
                greedy = Q.argmax(axis=1)
                V = evaluate_policy(greedy, r_term, next_states, rewards, b, n_states)
                start_values.append(V[0])
            if episode_over:
                break
    return start_values

# Experiment configuration.
n_states = 10000           # states per randomly generated task
max_updates = 1000         # expected updates per method per task
eval_interval = 50         # evaluate the greedy policy every this many updates
n_tasks = 30               # independent random tasks averaged per branching factor
branching_factors = [1, 3]
methods = ['uniform', 'onpolicy']
# NOTE(review): with 2 actions there are 2 * n_states = 20,000 state-action
# pairs, so 1000 uniform-sweep updates only reach the first 1000 of them
# (states 0-499). Increase max_updates if a full sweep is intended.

# One zero curve (one entry per evaluation point) per (branching factor, method).
results = {
    bf: {name: np.zeros(max_updates // eval_interval) for name in methods}
    for bf in branching_factors
}

# Main experiment loop: for each branching factor, average the start-state
# value curves of both update schemes over n_tasks random tasks.
for b in branching_factors:
    print(f"Branching factor b = {b}")
    avg_uniform, avg_onpolicy = (np.zeros(max_updates // eval_interval) for _ in range(2))

    for task_idx in tqdm(range(n_tasks)):
        # generate_task reseeds NumPy's global RNG with task_idx. That makes
        # the task reproducible, and also the random draws run_onpolicy makes
        # afterwards.
        r_term, next_states, rewards = generate_task(n_states, b, seed=task_idx)

        uniform_vals = run_uniform(r_term, next_states, rewards, b, n_states, max_updates, eval_interval)
        onpolicy_vals = run_onpolicy(r_term, next_states, rewards, b, n_states, max_updates, eval_interval)

        avg_uniform += np.asarray(uniform_vals)
        avg_onpolicy += np.asarray(onpolicy_vals)

    # Convert per-task sums into means in place, then publish them.
    avg_uniform /= n_tasks
    avg_onpolicy /= n_tasks
    results[b].update(uniform=avg_uniform, onpolicy=avg_onpolicy)

# Plot results: one panel per branching factor, stacked vertically, each
# comparing the averaged start-state value of both update schemes.
# Titles and panel count follow branching_factors / n_states instead of being
# hard-coded, so changing the configuration cannot break or mislabel the plot.
x = np.arange(eval_interval, max_updates + 1, eval_interval)
n_panels = max(len(branching_factors), 1)
plt.figure(figsize=(12, 4 * n_panels))

for panel, b in enumerate(branching_factors, start=1):
    plt.subplot(n_panels, 1, panel)
    plt.plot(x, results[b]['uniform'], label='Uniform', color='blue')
    plt.plot(x, results[b]['onpolicy'], label='On-policy', color='red')
    plt.title(f'Branching Factor $b={b}$ ({n_states:,} states)')
    plt.ylabel('Value of Start State')
    plt.legend()
    plt.grid(True)

# Only the bottom (last-drawn) panel carries the shared x-axis label.
plt.xlabel('Number of Expected Updates')

plt.tight_layout()
plt.savefig('figure_8_8_replication.png')
plt.show()