In [8]:
import random
import matplotlib.pyplot as plt
import numpy as np
import math
from collections import deque, defaultdict
from EDR_UTILITY_FUNCTIONS import *

def run_n_step_sarsa(initialEdges, goalEdges, totalSteps, nLookahead, epsilon, gamma, alpha, pGen, pSwap, maxAge, edr_window_size=100, convergence_epsilon=0.001, plot=True, bin_size=0.05):
    Q = qTable()
    q_value_diffs = []

    goal_success_queues = {goal: deque(maxlen=edr_window_size) for goal in goalEdges}

    raw_state = [(edge, -1) for edge in initialEdges]
    rolling_edrs = {goal: 0.0 for goal in goalEdges}
    current_state = get_augmented_state(raw_state, rolling_edrs, bin_size, goal_order=goalEdges)

    print("\nStarting SARSA run...")

    # Buffers for n-step
    state_buffer = deque(maxlen=nLookahead + 1)
    action_buffer = deque(maxlen=nLookahead + 1)
    reward_buffer = deque(maxlen=nLookahead)

    state_buffer.append(current_state)
    action_buffer.append(([], None))  # dummy action at t=0

    for t in range(totalSteps):
        prev_action = action_buffer[-1]
        current_state = performAction(prev_action, state_buffer[-1])
        current_state = ageEntanglements(current_state, maxAge)
        current_state = generateEntanglement(current_state, pGen)

        # Compute reward BEFORE next action
        reward = getReward(prev_action, goal_success_queues, t + 1, pSwap)
        reward_buffer.append(reward)
        state_buffer.append(current_state)

        # Update goal success queues
        for g in goal_success_queues:
            goal_success_queues[g].append(0)

        consumed_edges, goal = prev_action
        if goal is not None and consumed_edges:
            success = random.random() < (pSwap ** (len(consumed_edges) - 1))
            if success:
                goal_success_queues[goal][-1] = 1

        # Generate next action
        raw_ent_state, _ = current_state
        edr_snapshot = {
            goal: sum(goal_success_queues[goal]) / max(1, len(goal_success_queues[goal]))
            for goal in goalEdges
        }
        aug_state = get_augmented_state(raw_ent_state, edr_snapshot, bin_size, goal_order=goalEdges)
        next_action = getEpsilonGreedyAction(aug_state, Q, epsilon, goalEdges)
        action_buffer.append(next_action)

        # Perform n-step update if enough history
        if t >= nLookahead:
            G = 0.0
            for i in range(nLookahead):
                G += (gamma ** i) * reward_buffer[i]

            s_tau = state_buffer[0]
            a_tau = action_buffer[0]
            s_next = state_buffer[-1]
            a_next = action_buffer[-1]
            G += (gamma ** nLookahead) * Q.get_q_value(s_next, a_next)

            current_q = Q.get_q_value(s_tau, a_tau)
            new_q = current_q + alpha * (G - current_q)
            Q.set_q_value(s_tau, a_tau, new_q)
            q_value_diffs.append(abs(new_q - current_q))

        # Slide buffers
        # deque will automatically pop from the left once full, no manual trimming needed

    # Final updates for remaining state-action pairs in buffer
    T = totalSteps
    for t_rem in range(1, len(reward_buffer)):
        n = len(reward_buffer) - t_rem
        G = 0.0
        for i in range(n):
            G += (gamma ** i) * reward_buffer[t_rem + i]

        s_tau = state_buffer[t_rem]
        a_tau = action_buffer[t_rem]
        s_end = state_buffer[-1]
        a_end = action_buffer[-1]
        G += (gamma ** n) * Q.get_q_value(s_end, a_end)

        current_q = Q.get_q_value(s_tau, a_tau)
        new_q = current_q + alpha * (G - current_q)
        Q.set_q_value(s_tau, a_tau, new_q)
        q_value_diffs.append(abs(new_q - current_q))

    return current_state, Q, edr_snapshot


def train_sarsa_policy(
    edges, goal_edges, p_swap, p_gen, max_age,
    totalSteps, nLookahead, epsilon, gamma, alpha,
    edr_window_size, bin_size=0.05, seed=0):
    random.seed(seed)
    np.random.seed(seed)
    _, Q, _ = run_n_step_sarsa(
        initialEdges=edges,
        goalEdges=goal_edges,
        totalSteps=totalSteps,
        nLookahead=nLookahead,
        epsilon=epsilon,
        gamma=gamma,
        alpha=alpha,
        pGen=p_gen,
        pSwap=p_swap,
        maxAge=max_age,
        bin_size=0.05,
        edr_window_size=edr_window_size,
        plot=False
    )
    return Q


# === Training Configuration ===
edges = [(0, 1), (1, 2), (3, 2), (2, 4)]
goalEdges = [(3, 4), (0, 4)]
pSwap = 0.5
pGen = 0.5
maxAge = 2
totalSteps = 30000000
nLookahead = 1
epsilon = 0.2
gamma = 0.99
alpha = 0.2
bin_size = 0.01

_, Q, _ = run_n_step_sarsa(
    initialEdges=edges,
    goalEdges=goalEdges,
    totalSteps=totalSteps,
    nLookahead=nLookahead,
    epsilon=epsilon,
    gamma=gamma,
    alpha=alpha,
    pGen=pGen,
    pSwap=pSwap,
    maxAge=maxAge,
    edr_window_size=100,
    plot=False,
    bin_size =bin_size
)


#simulate_policy(Q_table =Q, edges=edges, goal_edges=goalEdges, p_swap=pSwap, p_gen=pGen, max_age=maxAge, num_steps=100000, edr_window_size=100, bin_size=bin_size, plot=True)






Starting SARSA run...


In [7]:
print(f"Total unique states visited: {len(Q.state_visits)}")
visits = list(Q.state_visits.values())
print(f"Average visits per state: {np.mean(visits):.2f}")
print(f"States visited only once: {sum(1 for v in visits if v == 1)}")
print(f"States visited less than 5 : {sum(1 for v in visits if v <= 5)}")


Total unique states visited: 50262
Average visits per state: 601.24
States visited only once: 5267
States visited less than 5 : 13808


In [9]:
print(f"Total unique states visited: {len(Q.state_visits)}")
visits = list(Q.state_visits.values())
print(f"Average visits per state: {np.mean(visits):.2f}")
print(f"States visited only once: {sum(1 for v in visits if v == 1)}")
print(f"States visited less than 5 : {sum(1 for v in visits if v <= 5)}")


Total unique states visited: 57759
Average visits per state: 1510.56
States visited only once: 5295
States visited less than 5 : 14028
