# **N-STEP SARSA CODE**

In [1]:
import random
import matplotlib.pyplot as plt
import numpy as np
import math
from collections import deque, defaultdict, Counter
#from linearApprox import *
import itertools


#  **UTILITY CODE**

In [2]:
def get_augmented_state(state, edrs, goal_order=None):
    if goal_order is None:
        goal_order = sorted(edrs.keys())
    edr_vector = tuple(edrs[goal] for goal in goal_order)
    sorted_state = sorted(state, key=lambda x: (x[0][0], x[0][1]))
    return (tuple(sorted_state), edr_vector)

def ageEntanglements(augmented_state, maxAge):
    ent_state, edr_bins = augmented_state
    new_state = []
    for edge, age in ent_state:
        if age >= 0:
            new_age = age + 1
            if new_age > maxAge:
                new_state.append((edge, -1))
            else:
                new_state.append((edge, new_age))
        else:
            new_state.append((edge, age))
    return (tuple(new_state), edr_bins)

def generateEntanglement(augmented_state, pGen, initial_edges):
    ent_state, edr_bins = augmented_state

    # Set of initial edges (normalized to (min(u,v), max(u,v)) form)
    initial_edges_set = { (min(u,v), max(u,v)) for (u,v) in initial_edges }

    new_state = []
    seen_edges = set()

    for edge, age in ent_state:
        normalized_edge = (min(edge[0], edge[1]), max(edge[0], edge[1]))
        
        if normalized_edge in initial_edges_set:
            # This is an initial edge, so it can regrow
            if age < 0:  # currently dead
                if random.random() < pGen:
                    new_state.append((normalized_edge, 1))
                else:
                    new_state.append((normalized_edge, -1))
            else:
                new_state.append((normalized_edge, age))
        else:
            # Non-initial edge: only keep if still alive
            if age >= 0:
                new_state.append((normalized_edge, age))
            # Otherwise, if dead and non-initial, don't add back at all (completely delete)
        
        seen_edges.add(normalized_edge)

    # Safety check: Add missing initial edges if completely missing
    for edge in initial_edges_set:
        if edge not in seen_edges:
            if random.random() < pGen:
                new_state.append((edge, 1))
            else:
                new_state.append((edge, -1))

    return (tuple(new_state), edr_bins)






def jains_index(edrs):
    """Compute Jain's Fairness Index."""
    if all(edr == 0 for edr in edrs.values()):
        return 0.0
    numerator = sum(edrs.values())**2
    denominator = len(edrs) * sum(v**2 for v in edrs.values())
    return numerator / denominator if denominator > 0 else 0.0

def featurize_state(state, goal_order, master_edge_list):
    ent_state, edrs = state
    edge_age_map = {edge: age for edge, age in ent_state}

    edge_features = []
    for edge in master_edge_list:
        age = edge_age_map.get(edge, -1)
        edge_features.append(age / 10.0 if age >= 0 else -1.0)

    edr_features = list(edrs)
    return np.array(edge_features + edr_features, dtype=np.float32)


class LinearQApproximator:
    def __init__(self, feature_size):
        self.weights = {}  # Dict[action_key] = weight_vector
        self.feature_size = feature_size

    def _action_key(self, action):
        consumed_paths, goal_list = action

        # Handle the no-op case: goal_list is None for no-op action
        if goal_list is None:
            return ((), None)  # Ensure this is hashable for no-op action

        # Sort paths and goals for consistency in hashing
        sorted_paths = tuple(sorted(tuple(sorted(path)) for path in consumed_paths))
        sorted_goals = tuple(sorted(goal_list))
        return (sorted_paths, sorted_goals)

    def _init_weights(self, action_key):
        if action_key not in self.weights:
            self.weights[action_key] = np.zeros(self.feature_size)

    def get_q_value(self, features, action):
        key = self._action_key(action)
        self._init_weights(key)
        return float(np.dot(self.weights[key], features))

    def update(self, features, action, target, alpha):
        key = self._action_key(action)
        self._init_weights(key)
        prediction = np.dot(self.weights[key], features)
        error = target - prediction
        self.weights[key] += alpha * error * features


def get_possible_multi_actionsold(ent_state, goalEdges, nestedSwaps=False, max_path_length=None):
    import itertools

    actions = []
    existing_edges = {tuple(sorted(edge)) for edge, age in ent_state if age >= 0}

    def find_paths(start, visited=None, path=None, depth=0):
        if visited is None:
            visited = set()
        if path is None:
            path = []

        paths = []
        for edge in existing_edges:
            if edge in visited:
                continue
            u, v = edge
            if u == start or v == start:
                next_node = v if u == start else u
                new_path = path + [edge]
                paths.append(new_path)
                if max_path_length is None or depth < max_path_length:
                    paths.extend(find_paths(next_node, visited | {edge}, new_path, depth + 1))
        return paths

    single_goal_actions = []
    for goal in goalEdges:
        start, end = goal
        paths = find_paths(start)
        for path in paths:
            if not path or len(path)<2:
                continue
            nodes = [n for edge in path for n in edge]
            counts = {node: nodes.count(node) for node in nodes}
            endpoints = [node for node, count in counts.items() if count == 1]
            if len(endpoints) != 2:
                continue
            if not nestedSwaps:
                if set(endpoints) != set(goal):
                    continue
            normalized_path = [tuple(sorted(e)) for e in path]
            single_goal_actions.append((normalized_path, goal))
            actions.append(([normalized_path], [goal]))

    for k in range(2, len(single_goal_actions) + 1):
        for combo in itertools.combinations(single_goal_actions, k):
            paths, goals = zip(*combo)
            flat_edges = [tuple(sorted(e)) for path in paths for e in path]
            if len(flat_edges) == len(set(flat_edges)):
                actions.append((list(paths), list(goals)))

    actions.append(([], None))
    return actions




def compute_reward(action, goal_success_queues, pSwap, mode="basic", alpha=1.0, noop_penalty=0.0):
    epsilon = 0.0001
    consumed_edges, goals = action
    if not goals or not consumed_edges:
        return -noop_penalty, False

    total_reward = 0.0
    any_success = False
    used_edges = set()

    for goal, path in zip(goals, consumed_edges):
        path_edges = set(path)
        if not path_edges.isdisjoint(used_edges):
            continue
        used_edges.update(path_edges)

        success_prob = pSwap ** (len(path) - 1)
        edr = sum(goal_success_queues[goal]) / len(goal_success_queues[goal]) + epsilon
        x = success_prob / edr
        success = (random.random() < success_prob)
        any_success = any_success or success

        if mode == "partial":
            base = math.log(1 + x)
            total_reward += base if success else 0.5 * base
        else:
            total_reward += math.log(1 + x) if success else 0.0

    return total_reward, any_success

def performAction(action, augmented_state, pSwap, nestedSwaps=False, system_goals=None):
    consumed_paths, _ = action
    ent_state, edr_bins = augmented_state
    new_state = list(ent_state)

    normalized_goal_edges = set((min(u, v), max(u, v)) for u, v in (system_goals or []))
    busy_nodes = set(u for (u, v), age in ent_state if age >= 0 for u, v in [(u, v)])

    used_edges = set()  # Tracks entanglements used this timestep (normalized)

    for path in consumed_paths:
        if not path:
            continue

        normalized_path = [tuple(sorted(e)) for e in path]

        # Prevent reusing entanglements already used this timestep
        if any(edge in used_edges for edge in normalized_path):
            continue

        consumed_ages = []
        for edge_to_consume in normalized_path:
            for i, (edge, age) in enumerate(new_state):
                if tuple(sorted(edge)) == edge_to_consume:
                    consumed_ages.append(age)
                    busy_nodes.discard(edge[0])
                    busy_nodes.discard(edge[1])
                    new_state[i] = (edge, -1)
                    break

        # Mark all edges in this path as used
        used_edges.update(normalized_path)

        # Attempt swap
        success_prob = pSwap ** (len(path) - 1)
        swap_success = random.random() < success_prob

        if not swap_success:
            continue

        # Determine new edge from endpoints
        nodes = [n for edge in normalized_path for n in edge]
        node_counts = {node: nodes.count(node) for node in nodes}
        endpoints = [node for node, count in node_counts.items() if count == 1]

        if len(endpoints) != 2:
            continue

        start, end = endpoints
        new_edge = (min(start, end), max(start, end))

        if new_edge in normalized_path:
            continue

        if start in busy_nodes or end in busy_nodes:
            continue

        new_age = max(consumed_ages) if consumed_ages else 0
        alive_edges = {edge for edge, age in new_state if age >= 0}

        if new_edge not in alive_edges:
            if nestedSwaps or new_edge in normalized_goal_edges:
                new_state.append((new_edge, new_age))
                busy_nodes.update([start, end])

    # Remove dead edges
    new_state = [pair for pair in new_state if pair[1] != -1]
    return (tuple(new_state), edr_bins)

def get_possible_multi_actions(ent_state, goalEdges, nestedSwaps=False, max_path_length=None):
    import itertools

    actions = []
    existing_edges = {tuple(sorted(edge)) for edge, age in ent_state if age >= 0}

    def find_paths(start, visited=None, path=None, depth=0):
        if visited is None:
            visited = set()
        if path is None:
            path = []

        paths = []
        for edge in existing_edges:
            if edge in visited:
                continue
            u, v = edge
            if u == start or v == start:
                next_node = v if u == start else u
                new_path = path + [edge]
                paths.append(new_path)
                if max_path_length is None or depth < max_path_length:
                    paths.extend(find_paths(next_node, visited | {edge}, new_path, depth + 1))
        return paths

    # Single-goal actions
    single_goal_actions = []
    for goal in goalEdges:
        start, end = goal
        paths = find_paths(start)
        for path in paths:
            if not path or len(path) < 2:  # 🛡️ Skip trivial 1-hop paths
                continue

            nodes = [n for edge in path for n in edge]
            counts = {node: nodes.count(node) for node in nodes}
            endpoints = [node for node, count in counts.items() if count == 1]
            if len(endpoints) != 2:
                continue

            if set(endpoints) != set(goal):
                continue

            normalized_path = [tuple(sorted(e)) for e in path]
            single_goal_actions.append((normalized_path, goal))
            actions.append(([normalized_path], [goal]))

    # Multi-goal disjoint actions
    for k in range(2, len(single_goal_actions) + 1):
        for combo in itertools.combinations(single_goal_actions, k):
            paths, goals = zip(*combo)

            if len(set(goals)) != len(goals):  # 🛡️ Prevent duplicate goals
                continue

            flat_edges = [tuple(sorted(e)) for path in paths for e in path]
            if len(flat_edges) == len(set(flat_edges)):  # 🛡️ Ensure disjoint paths
                actions.append((list(paths), list(goals)))

    # Always allow no-op action
    actions.append(([], None))

    return actions


# **SIMULATION CODE**

In [3]:
def simulate_policy(
    Q_table,
    edges,
    goal_edges,
    pSwap,
    pGen,
    max_age,
    num_steps,
    edr_window_size=100,
    burn_in=None,
    plot=True,
    nestedSwaps=False
):    
    
    
    nodes = set()
    for u, v in edges:
        nodes.add(u)
        nodes.add(v)
    nodes = sorted(list(nodes))

    master_edge_list = []
    for i in range(len(nodes)):
        for j in range(i+1, len(nodes)):
            master_edge_list.append((nodes[i], nodes[j]))

    if burn_in is None:
        burn_in = num_steps // 2

    raw = [(e, -1) for e in edges]
    current = get_augmented_state(raw, {g:0.0 for g in goal_edges}, goal_order=goal_edges)

    recent = {g: [] for g in goal_edges}
    edr_hist, jain_hist, tp_hist = {g:[] for g in goal_edges}, [], []
    valids, acts, qvals = [], [], []

    for t in range(num_steps):
        ent_state, _ = current
        acts_all = get_possible_multi_actions(ent_state, goal_edges, nestedSwaps=nestedSwaps)

        if ([], None) not in acts_all:
            acts_all.append(([], None))
        real = [a for a in acts_all if a != ([], None)]
        avail = len(real) > 0
        valids.append(1 if avail else 0)

        feats = featurize_state(current, goal_edges, master_edge_list)
        best_a, best_q = max(((a, Q_table.get_q_value(feats, a)) for a in acts_all), key=lambda x: x[1])
        qvals.append(best_q)
        acts.append(1.0 if (avail and best_a in real) else 0.0)

        nxt = performAction(best_a, current, pSwap=pSwap, nestedSwaps=nestedSwaps)
        nxt = ageEntanglements(nxt, max_age)
        nxt = generateEntanglement(nxt, pGen, edges)

        consumed_paths, goals = best_a
        if goals is not None:
            for g, path in zip(goals, consumed_paths):
                if path:
                    succ = random.random() < (pSwap ** (len(path) - 1))
                    recent[g].append(1 if succ else 0)
                else:
                    recent[g].append(0)
            for g in goal_edges:
                if g not in goals:
                    recent[g].append(0)
        else:
            for g in goal_edges:
                recent[g].append(0)

        if len(recent[g]) > edr_window_size:
            recent[g].pop(0)

        edrs = {g: sum(recent[g]) / len(recent[g]) for g in goal_edges}
        for g in goal_edges:
            edr_hist[g].append(edrs[g])

        total = sum(edrs.values())
        tp_hist.append(total)
        jain_hist.append(jains_index(edrs))

        current = get_augmented_state(nxt[0], edrs, goal_order=goal_edges)

    if plot:
        fig, axs = plt.subplots(1, 3, figsize=(18, 5))
        fig.suptitle(
            f"Policy Sim — pSwap={pSwap}, pGen={pGen}, maxAge={max_age}, steps={num_steps}, window={edr_window_size}",
            fontsize=14
        )

        # (0) EDR + Jain
        ax0 = axs[0]
        for g in goal_edges:
            ax0.plot(edr_hist[g], label=f"EDR {g}")
        ax0.plot(jain_hist, '--', label="Jain's", linewidth=2)
        ax0.set_title("EDR (solid) & Jain (dashed)")
        ax0.set_xlabel("Timestep")
        ax0.set_ylabel("Value")
        ax0.set_ylim(0, 1.05)
        ax0.legend()

        # (1) single Pareto point after burn-in
        ax1 = axs[1]
        avg_tp = np.mean(tp_hist[burn_in:])
        avg_jain = np.mean(jain_hist[burn_in:])
        ax1.scatter([avg_tp], [avg_jain], s=100, c='crimson')
        ax1.set_title("Final Pareto Point")
        ax1.set_xlabel("Avg Throughput")
        ax1.set_ylabel("Avg Jain")
        ax1.set_xlim(0, max(tp_hist) * 1.1)
        ax1.set_ylim(0, 1.05)
        ax1.text(avg_tp, avg_jain, f"  ({avg_tp:.3f}, {avg_jain:.3f})")

        # (2) best Q-value
        ax2 = axs[2]
        ax2.plot(qvals, color='slateblue')
        ax2.set_title("Best Q-Value Over Time")
        ax2.set_xlabel("Timestep")
        ax2.set_ylabel("Q-Value")

        plt.tight_layout(rect=[0, 0, 1, 0.95])
        plt.show()

    burn_in_idx = 5000
    final_edrs = {g: np.mean(edr_hist[g][burn_in_idx:]) for g in goal_edges}
    final_tp = sum(final_edrs.values())
    final_jain = jains_index(final_edrs)

    print("\nMetrics After Burn-in (first 5000 steps ignored):")
    print("Mean EDRs:", {g: f"{v:.4f}" for g, v in final_edrs.items()})
    print(f"Total Throughput (sum of EDRs): {final_tp:.4f}")
    print(f"Jain's Fairness Index: {final_jain:.4f}")

    return {
        "edr_history": edr_hist,
        "jain_history": jain_hist,
        "throughput_history": tp_hist,
        "q_values": qvals
    }


In [4]:
def multi_simulate_policy(
    Q_table,
    edges,
    goal_edges,
    pSwap,
    pGen,
    max_age,
    num_steps,
    edr_window_size=100,
    burn_in=None,
    seeds=[10, 20, 30, 40, 50],
    plot=True,
    nestedSwaps=False
):
    if burn_in is None:
        burn_in = num_steps // 2

    all_edrs = {g: [] for g in goal_edges}
    all_jains = []
    all_tp = []

    edr_time_series = {g: [] for g in goal_edges}
    jain_time_series = []

    goal_colors = {
        goal_edges[0]: "tab:blue",
        goal_edges[1]: "tab:green"
    }

    for seed in seeds:
        random.seed(seed)
        np.random.seed(seed)

        result = simulate_policy(
            Q_table=Q_table,
            edges=edges,
            goal_edges=goal_edges,
            pSwap=pSwap,
            pGen=pGen,
            max_age=max_age,
            num_steps=num_steps,
            edr_window_size=edr_window_size,
            burn_in=burn_in,
            plot=False,
            nestedSwaps=nestedSwaps
        )

        # Final metrics
        edr_hist = result["edr_history"]
        jain_hist = result["jain_history"]
        tp_hist = result["throughput_history"]

        for g in goal_edges:
            edr_time_series[g].append(edr_hist[g])
        jain_time_series.append(jain_hist)

        burn_in_idx = 5000
        edrs_final = {g: np.mean(edr_hist[g][burn_in_idx:]) for g in goal_edges}
        tp_final = sum(edrs_final.values())
        jain_final = jains_index(edrs_final)

        for g in goal_edges:
            all_edrs[g].append(edrs_final[g])
        all_tp.append(tp_final)
        all_jains.append(jain_final)

    if plot:
        fig, axs = plt.subplots(1, 3, figsize=(18, 5))
        fig.suptitle(
            f"Multi-Sim — {len(seeds)} Seeds | pSwap={pSwap}, pGen={pGen}, maxAge={max_age}, steps={num_steps}",
            fontsize=14
        )

        # --- EDR time series ---
        ax0 = axs[0]
        for g in goal_edges:
            for run_edr in edr_time_series[g]:
                ax0.plot(run_edr, color=goal_colors[g], alpha=0.3)
            ax0.plot(np.mean(edr_time_series[g], axis=0), color=goal_colors[g], linewidth=2, label=f"EDR {g}")
        ax0.set_title("EDRs Over Time")
        ax0.set_xlabel("Timestep")
        ax0.set_ylabel("EDR")
        ax0.set_ylim(0, 1.05)
        ax0.legend()

        # --- Jain's Index time series ---
        ax1 = axs[1]
        for run_jain in jain_time_series:
            ax1.plot(run_jain, color="gray", alpha=0.3)
        ax1.plot(np.mean(jain_time_series, axis=0), color="black", linewidth=2, label="Jain’s Index")
        ax1.set_title("Jain's Index Over Time")
        ax1.set_xlabel("Timestep")
        ax1.set_ylabel("Jain's Fairness")
        ax1.set_ylim(0, 1.05)
        ax1.legend()

        # --- Pareto scatter plot ---
        ax2 = axs[2]
        ax2.scatter(all_tp, all_jains, c='crimson', s=60, label="Runs")
        ax2.scatter(np.mean(all_tp), np.mean(all_jains), c='black', s=100, marker='x', label="Mean")
        ax2.set_title("Pareto Points Across Seeds")
        ax2.set_xlabel("Avg Throughput")
        ax2.set_ylabel("Avg Jain")
        ax2.set_xlim(0, max(all_tp) * 1.1)
        ax2.set_ylim(0, 1.05)
        for tp, jn in zip(all_tp, all_jains):
            ax2.text(tp + 0.002, jn, f"({tp:.2f}, {jn:.2f})", fontsize=8)
        ax2.legend()

        plt.tight_layout(rect=[0, 0, 1, 0.95])
        plt.show()

    return {
        "avg_edrs": {g: np.mean(all_edrs[g]) for g in goal_edges},
        "avg_tp": np.mean(all_tp),
        "avg_jain": np.mean(all_jains),
        "edrs_by_seed": all_edrs,
        "tp_by_seed": all_tp,
        "jain_by_seed": all_jains
    }


In [5]:
def compareOverParam(
    param_name, param_values,
    edges, goal_edges,
    pSwap, pGen, max_age,
    totalSteps, nLookahead,
    alpha, gamma,
    edr_window_size, reward_mode,
    initial_temperature, temperature_decay,
    seed,
    nestedSwaps,
    training_function,
    simulate_steps=50_000,
    burn_in_ratio=0.5,
    **extra_kwargs
):
    edr_results = {g: [] for g in goal_edges}
    jain_results = []
    pareto_points = []
    full_raw_data = {}

    for value in param_values:
        print(f"\n=== Training with {param_name} = {value} ===")

        if param_name == 'pGen':
            pGen = value
        elif param_name == 'pSwap':
            pSwap = value
        else:
            raise ValueError("param_name must be 'pGen' or 'pSwap'.")

        # --- Build kwargs dynamically ---
        train_kwargs = {
            'edges': edges,
            'goal_edges': goal_edges,
            'pSwap': pSwap,
            'pGen': pGen,
            'max_age': max_age,
            'seed': seed,
            'totalSteps': totalSteps,
            'alpha': alpha,
            'gamma': gamma,
            'edr_window_size': edr_window_size,
            'reward_mode': reward_mode,
            'nestedSwaps': nestedSwaps,
        }

        # --- Add nLookahead if needed ---
        if training_function.__name__ == 'train_sarsa_linear_policy':
            train_kwargs['nLookahead'] = nLookahead
            train_kwargs['initial_temperature'] = initial_temperature
            train_kwargs['temperature_decay'] = temperature_decay
            train_kwargs['log_interval'] = edr_window_size
        elif training_function.__name__ == 'train_q_learning_linear_policy':
            # rename properly for q_learning
            train_kwargs['temperature'] = initial_temperature  # <=== fix here
            train_kwargs['temperature_decay'] = temperature_decay

        # --- Add extra kwargs ---
        train_kwargs.update(extra_kwargs)

        # --- Now call training function cleanly ---
        Q = training_function(**train_kwargs)

        # Simulate
        sim_result = simulate_policy(
            Q_table=Q,
            edges=edges,
            goal_edges=goal_edges,
            pSwap=pSwap,
            pGen=pGen,
            max_age=max_age,
            num_steps=simulate_steps,
            edr_window_size=edr_window_size,
            plot=False,
            nestedSwaps=nestedSwaps
        )

        # Post-processing
        burn_in_steps = int(simulate_steps * burn_in_ratio)
        final_edrs = {g: np.mean(sim_result["edr_history"][g][burn_in_steps:]) for g in goal_edges}
        final_tp = sum(final_edrs.values())
        final_jain = jains_index(final_edrs)

        for g in goal_edges:
            edr_results[g].append(final_edrs[g])
        jain_results.append(final_jain)
        pareto_points.append((final_tp, final_jain))
        full_raw_data[value] = sim_result

    # --- Plotting (same as before) ---
    plt.figure(figsize=(8, 5))
    for g in goal_edges:
        plt.plot(param_values, edr_results[g], marker='o', label=f"EDR {g}")
    plt.xlabel(param_name)
    plt.ylabel("Mean EDR")
    plt.title(f"EDR vs {param_name}")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(8, 5))
    plt.plot(param_values, jain_results, marker='s', label="Jain's Fairness")
    plt.xlabel(param_name)
    plt.ylabel("Jain's Index")
    plt.title(f"Jain's Fairness vs {param_name}")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

    pareto_tp, pareto_jain = zip(*pareto_points)
    plt.figure(figsize=(8, 6))
    plt.scatter(pareto_tp, pareto_jain, c='crimson', s=80)
    for (tp, jain), val in zip(pareto_points, param_values):
        plt.text(tp, jain, f"{val:.2f}", fontsize=9)
    plt.xlabel("Throughput (sum EDRs)")
    plt.ylabel("Jain's Fairness Index")
    plt.title(f"Pareto Plot for different {param_name} values")
    plt.grid(True)
    plt.xlim(0, max(pareto_tp) * 1.1)
    plt.ylim(0, 1.05)
    plt.tight_layout()
    plt.show()

    return {
        "param_name": param_name,
        "param_values": param_values,
        "edr_results": edr_results,
        "jain_results": jain_results,
        "pareto_points": pareto_points,
        "full_raw_data": full_raw_data
    }


In [6]:
def compareOverParamRobust(
    param_name, param_values,
    edges, goal_edges,
    pSwap, pGen, max_age,
    totalSteps, nLookahead,
    alpha, gamma,
    edr_window_size, reward_mode,
    initial_temperature, temperature_decay,
    seed,
    nestedSwaps,
    training_function,
    simulate_steps=50_000,
    burn_in_ratio=0.5,
    trainCount=3,
    simulateCount=3,
    **extra_kwargs
):
    edr_results = {g: [] for g in goal_edges}
    jain_results = []
    pareto_points = []
    full_raw_data = {}

    for idx_value, value in enumerate(param_values):
        print(f"\n=== Training with {param_name} = {value} ===")

        # --- Dynamic parameter assignment ---
        if param_name == 'pGen':
            pGen = value
        elif param_name == 'pSwap':
            pSwap = value
        else:
            raise ValueError("param_name must be 'pGen' or 'pSwap'.")

        edrs_all_runs = {g: [] for g in goal_edges}
        jains_all_runs = []
        tp_all_runs = []

        for train_idx in range(trainCount):
            train_seed = seed + idx_value * 100 + train_idx * 10  # Safe separation
            np.random.seed(train_seed)
            random.seed(train_seed)

            # --- Build kwargs dynamically ---
            train_kwargs = {
                'edges': edges,
                'goal_edges': goal_edges,
                'pSwap': pSwap,
                'pGen': pGen,
                'max_age': max_age,
                'seed': train_seed,
                'totalSteps': totalSteps,
                'alpha': alpha,
                'gamma': gamma,
                'edr_window_size': edr_window_size,
                'reward_mode': reward_mode,
                'nestedSwaps': nestedSwaps,
            }

            if training_function.__name__ == 'train_sarsa_linear_policy':
                train_kwargs['nLookahead'] = nLookahead
                train_kwargs['initial_temperature'] = initial_temperature
                train_kwargs['temperature_decay'] = temperature_decay
                train_kwargs['log_interval'] = edr_window_size
            elif training_function.__name__ == 'train_q_learning_linear_policy':
                train_kwargs['temperature'] = initial_temperature
                train_kwargs['temperature_decay'] = temperature_decay

            train_kwargs.update(extra_kwargs)

            Q = training_function(**train_kwargs)

            # --- Simulate multiple times for each trained model ---
            sim_seeds = [train_seed + s * 1000 for s in range(simulateCount)]  # Different seeds for simulation
            multi_sim_result = multi_simulate_policy(
                Q_table=Q,
                edges=edges,
                goal_edges=goal_edges,
                pSwap=pSwap,
                pGen=pGen,
                max_age=max_age,
                num_steps=simulate_steps,
                edr_window_size=edr_window_size,
                burn_in=int(simulate_steps * burn_in_ratio),
                seeds=sim_seeds,
                plot=False,
                nestedSwaps=nestedSwaps
            )

            # --- Collect results ---
            for g in goal_edges:
                edrs_all_runs[g].append(multi_sim_result["avg_edrs"][g])
            jains_all_runs.append(multi_sim_result["avg_jain"])
            tp_all_runs.append(multi_sim_result["avg_tp"])

        # --- After all training runs ---
        final_avg_edrs = {g: np.mean(edrs_all_runs[g]) for g in goal_edges}
        final_avg_jain = np.mean(jains_all_runs)
        final_avg_tp = np.mean(tp_all_runs)

        for g in goal_edges:
            edr_results[g].append(final_avg_edrs[g])
        jain_results.append(final_avg_jain)
        pareto_points.append((final_avg_tp, final_avg_jain))
        full_raw_data[value] = {
            "edrs_by_run": edrs_all_runs,
            "jains_by_run": jains_all_runs,
            "tp_by_run": tp_all_runs
        }

    # --- Plotting ---
    plt.figure(figsize=(8, 5))
    for g in goal_edges:
        plt.plot(param_values, edr_results[g], marker='o', label=f"EDR {g}")
    plt.xlabel(param_name)
    plt.ylabel("Mean EDR")
    plt.title(f"EDR vs {param_name}")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(8, 5))
    plt.plot(param_values, jain_results, marker='s', label="Jain's Fairness")
    plt.xlabel(param_name)
    plt.ylabel("Jain's Index")
    plt.title(f"Jain's Fairness vs {param_name}")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

    pareto_tp, pareto_jain = zip(*pareto_points)
    plt.figure(figsize=(8, 6))
    plt.scatter(pareto_tp, pareto_jain, c='crimson', s=80)
    for (tp, jain), val in zip(pareto_points, param_values):
        plt.text(tp, jain, f"{val:.2f}", fontsize=9)
    plt.xlabel("Throughput (sum EDRs)")
    plt.ylabel("Jain's Fairness Index")
    plt.title(f"Pareto Plot for different {param_name} values")
    plt.grid(True)
    plt.xlim(0, max(pareto_tp) * 1.1)
    plt.ylim(0, 1.05)
    plt.tight_layout()
    plt.show()

    return {
        "param_name": param_name,
        "param_values": param_values,
        "edr_results": edr_results,
        "jain_results": jain_results,
        "pareto_points": pareto_points,
        "full_raw_data": full_raw_data
    }


In [7]:
def moving_average(x, w):
    return np.convolve(x, np.ones(w), 'valid') / w

def plot_training_results(
    q_diffs, q_diffs_per_goal, edr_steps, edr_hist,
    goal_edges, fairness_history, edge_creation_counter=None,
    method_name="",
    smoothing_window=1000  # <--- added
):
    import matplotlib.pyplot as plt
    import numpy as np

    # --- Plot Global Q-value Convergence (no smoothing) ---
    plt.figure(figsize=(10, 4))
    plt.plot(q_diffs, color='royalblue', alpha=0.8)
    plt.xlabel(f"{method_name} Updates")
    plt.ylabel("Global Q-value Difference")
    plt.grid(True)
    plt.title(f"{method_name}: Global Q-value Convergence")
    plt.tight_layout()
    plt.show()

    # --- Plot Per-goal Q-value Convergence + No-Op Together (smoothed) ---
    plt.figure(figsize=(12, 6))
    linestyles = ['-', '--', ':', '-.']
    colors = plt.cm.tab10(np.linspace(0, 1, len(goal_edges) + 1))

    for i, g in enumerate(goal_edges):
        diffs = np.array(q_diffs_per_goal[g])
        steps = np.arange(len(diffs))
        mask = ~np.isnan(diffs)

        if np.sum(mask) > smoothing_window:  # only smooth if enough points
            smooth_diffs = moving_average(diffs[mask], smoothing_window)
            smooth_steps = steps[mask][:len(smooth_diffs)]
            plt.plot(
                smooth_steps, smooth_diffs,
                label=f"Goal {g}",
                linestyle=linestyles[i % len(linestyles)],
                color=colors[i],
                alpha=0.9,
                linewidth=2
            )

    # Plot No-op line if available
    if 'noop' in q_diffs_per_goal:
        diffs = np.array(q_diffs_per_goal['noop'])
        steps = np.arange(len(diffs))
        mask = ~np.isnan(diffs)

        if np.sum(mask) > smoothing_window:
            smooth_diffs = moving_average(diffs[mask], smoothing_window)
            smooth_steps = steps[mask][:len(smooth_diffs)]
            plt.plot(
                smooth_steps, smooth_diffs,
                label="No-op Action",
                linestyle=':',
                color='black',
                alpha=0.7,
                linewidth=2
            )

    plt.xlabel(f"{method_name} Updates")
    plt.ylabel("Per-goal Q-value Difference (smoothed)")
    plt.legend()
    plt.grid(True)
    plt.title(f"{method_name}: Per-Goal Q-value Convergence + No-Op (smoothed)")
    plt.tight_layout()
    plt.show()

    # --- Plot EDR Evolution + Jain's Fairness ---
    plt.figure(figsize=(10, 5))
    for g in goal_edges:
        plt.plot(edr_steps, edr_hist[g], label=f"EDR {g}")
    plt.plot(edr_steps, fairness_history, label="Jain's Fairness", linestyle="--", color="black")
    plt.xlabel("Training Step")
    plt.ylabel("EDR Estimate / Jain's Fairness")
    plt.ylim(0, 1.05)
    plt.legend()
    plt.grid(True)
    plt.title(f"{method_name}: EDR + Fairness Over Training")
    plt.tight_layout()
    plt.show()

    # --- Plot Swapped Edge Creation Frequency ---
    if edge_creation_counter is not None:
        edges = list(edge_creation_counter.keys())
        counts = list(edge_creation_counter.values())
        edge_labels = [f"{u}-{v}" for u, v in edges]

        plt.figure(figsize=(10, 5))
        plt.bar(edge_labels, counts)
        plt.xticks(rotation=45)
        plt.ylabel("Number of Swapped Creations")
        plt.grid(True, axis='y')
        plt.title(f"{method_name}: Edge Creation Frequency")
        plt.tight_layout()
        plt.show()


#  **Q-LEARNING CODE** 

In [8]:
def run_q_learning_linear_policy(
    initialEdges, goalEdges, totalSteps,
    gamma, alpha, pGen, pSwap, maxAge,
    edr_window_size=100, reward_mode="basic",
    noop_penalty=0.00, log_interval=1000,
    softmax=True, temperature=1.0, temperature_decay=0.9999,
    epsilon=0.01, nestedSwaps=False
):
    q_value_diffs = []
    q_value_diffs_per_goal = {g: [] for g in goalEdges}

    goal_success_queues = {
        g: deque([1] * (edr_window_size // 2) + [0] * (edr_window_size // 2), maxlen=edr_window_size)
        for g in goalEdges
    }

    raw = [(e, -1) for e in initialEdges]
    edr_snap = {g: 0.0 for g in goalEdges}
    current = (tuple(raw), tuple(edr_snap[g] for g in goalEdges))

    # --- Build Master Edge List ---
    nodes = set()
    for u, v in initialEdges:
        nodes.add(u)
        nodes.add(v)
    nodes = sorted(list(nodes))

    master_edge_list = []
    for i in range(len(nodes)):
        for j in range(i+1, len(nodes)):
            master_edge_list.append((nodes[i], nodes[j]))

    feature_size = len(master_edge_list) + len(goalEdges)
    Q = LinearQApproximator(feature_size=feature_size)
    temperature_curr = temperature

    edr_tracking_steps = []
    fairness_history = []
    edge_creation_counter = Counter()
    edr_tracking_history = {g: [] for g in goalEdges}

    def select_action(state, temperature):
        feats = featurize_state(state, goalEdges, master_edge_list)
        acts = get_possible_multi_actions(state[0], goalEdges)
        if ([], None) not in acts:
            acts.append(([], None))

        if softmax:
            q_vals = np.array([Q.get_q_value(feats, a) for a in acts], dtype=np.float64)
            scaled_qs = q_vals / max(temperature, 1e-6)
            exp_qs = np.exp(scaled_qs - np.max(scaled_qs))
            probs = exp_qs / np.sum(exp_qs)
            chosen = acts[np.random.choice(len(acts), p=probs)]
        else:
            if random.random() < epsilon:
                chosen = random.choice(acts)
            else:
                chosen = max(acts, key=lambda a: Q.get_q_value(feats, a))

        return chosen

    state = current

    for t in range(totalSteps):
        temperature_curr = max(0.01, temperature_curr * temperature_decay)

        action = select_action(state, temperature_curr)
        next_state = performAction(action, state, pSwap=pSwap, nestedSwaps=nestedSwaps)
        # Track newly created swapped edges (not initial ones)
        new_edges = next_state[0]  # list of (edge, age)
        for (u, v), age in new_edges:
            if (u, v) not in initialEdges and (v, u) not in initialEdges:
                edge_creation_counter[(u, v)] += 1

        
        next_state = ageEntanglements(next_state, maxAge)
        next_state = generateEntanglement(next_state, pGen, initialEdges)

        r, succ = compute_reward(
            action, goal_success_queues, pSwap,
            mode=reward_mode, noop_penalty=noop_penalty
        )

        consumed_edges, goal_list = action
        successful_goals = set(goal_list) if succ else set()
        for gh in goalEdges:
            goal_success_queues[gh].append(1 if gh in successful_goals else 0)

        edr_snap = {g: sum(goal_success_queues[g]) / len(goal_success_queues[g]) for g in goalEdges}
        augmented_next_state = (next_state[0], tuple(edr_snap[g] for g in goalEdges))

        if t % log_interval == 0:
            edr_tracking_steps.append(t)
            for g in goalEdges:
                edr_tracking_history[g].append(edr_snap[g])
            fairness = jains_index(edr_snap)                
            fairness_history.append(fairness)

        feats = featurize_state(state, goalEdges, master_edge_list)
        feats_next = featurize_state(augmented_next_state, goalEdges, master_edge_list)
        next_actions = get_possible_multi_actions(augmented_next_state[0], goalEdges)
        max_q_next = max([Q.get_q_value(feats_next, a) for a in next_actions], default=0.0)

        target = r + gamma * max_q_next
        current_q = Q.get_q_value(feats, action)
        diff = abs(current_q - target)
        q_value_diffs.append(diff)

        for g in goalEdges:
            if action[1] is not None and g in action[1]:
                q_value_diffs_per_goal[g].append(diff)
            else:
                q_value_diffs_per_goal[g].append(np.nan)

        Q.update(feats, action, target, alpha)
        state = augmented_next_state

    return Q, q_value_diffs, q_value_diffs_per_goal, edr_tracking_steps, edr_tracking_history, fairness_history, edge_creation_counter

def train_q_learning_linear_policy(
    edges, goal_edges, pSwap, pGen, max_age,
    seed, totalSteps, alpha, gamma,
    edr_window_size, reward_mode,
    softmax, temperature, temperature_decay,
    nestedSwaps, noop_penalty=0.0, plotTraining=False

):
    random.seed(seed)
    np.random.seed(seed)
    log_interval = edr_window_size

    result = run_q_learning_linear_policy(
        initialEdges=edges,
        goalEdges=goal_edges,
        totalSteps=totalSteps,
        gamma=gamma,
        alpha=alpha,
        pGen=pGen,
        pSwap=pSwap,
        maxAge=max_age,
        edr_window_size=edr_window_size,
        reward_mode=reward_mode,
        noop_penalty=noop_penalty,
        log_interval=log_interval,
        softmax=softmax,
        temperature=temperature,
        temperature_decay=temperature_decay,
        nestedSwaps=nestedSwaps
    )

    if result is None:
        print("Error: Q-learning returned None.")
        return

    Q, q_diffs, q_diffs_per_goal, edr_steps, edr_hist, fairness_history, edge_creation_counter = result


    if plotTraining:
        plot_training_results(
        q_diffs,
        q_diffs_per_goal,
        edr_steps,
        edr_hist,
        goal_edges,
        fairness_history,      # <<< ADD THIS
        edge_creation_counter,
        method_name="Q-Learning"
    )


    return Q


#  **SARSA CODE**

In [9]:
def run_n_step_sarsa_linear_multi(
    initialEdges, goalEdges, totalSteps, nLookahead,
    gamma, alpha, pGen, pSwap, maxAge,
    edr_window_size=100, reward_mode="basic",
    noop_penalty=0.00, log_interval=1000,
    initial_temperature=1.0, temperature_decay=0.9999, nestedSwaps=False
):
    # --- Build Master Edge List ---
    nodes = set()
    for u, v in initialEdges:
        nodes.add(u)
        nodes.add(v)
    nodes = sorted(list(nodes))

    master_edge_list = []
    for i in range(len(nodes)):
        for j in range(i+1, len(nodes)):
            master_edge_list.append((nodes[i], nodes[j]))

    feature_size = len(master_edge_list) + len(goalEdges)
    Q = LinearQApproximator(feature_size=feature_size)

    q_value_diffs = []
    q_value_diffs_per_goal = {g: [] for g in goalEdges}
    q_value_diffs_per_goal['noop'] = []  # Add No-op as a "goal" for tracking

    edge_creation_counter = Counter()


    goal_success_queues = {
        g: deque([1] * (edr_window_size // 2) + [0] * (edr_window_size // 2), maxlen=edr_window_size)
        for g in goalEdges
    }

    raw = [(e, -1) for e in initialEdges]
    edr_snap = {g: 0.0 for g in goalEdges}
    current = (tuple(raw), tuple(edr_snap[g] for g in goalEdges))

    state_buffer = deque([current])
    reward_buffer = deque()
    temperature = initial_temperature

    edr_tracking_steps = []
    fairness_history = []
    edr_tracking_history = {g: [] for g in goalEdges}


    # --- Updated select_action ---
    def select_action(state, temperature):
        feats = featurize_state(state, goalEdges, master_edge_list)
        acts = get_possible_multi_actions(state[0], goalEdges, nestedSwaps=nestedSwaps)

        if ([], None) not in acts:
            acts.append(([], None))

        q_values = np.array([Q.get_q_value(feats, a) for a in acts], dtype=np.float64)
        scaled_qs = q_values / max(temperature, 1e-6)
        exp_qs = np.exp(scaled_qs - np.max(scaled_qs))
        probs = exp_qs / np.sum(exp_qs)

        idx = np.random.choice(len(acts), p=probs)
        chosen = acts[idx]
        
        return chosen

    # --- Start Training Loop ---
    action_buffer = deque([select_action(current, temperature)])

    for t in range(totalSteps):
        if (t+1) % 100_000 == 0:
            print(f"Step {t+1}")

        temperature = max(0.01, temperature * temperature_decay)
        S_t = state_buffer[-1]
        A_t = action_buffer[-1]
        
        # --- Debug print when exactly 3 actions are available ---
        
        acts = get_possible_multi_actions(S_t[0], goalEdges, nestedSwaps=nestedSwaps)
        if ([], None) not in acts:
            acts.append(([], None))

        # if len(acts) == 3:
        #     feats = featurize_state(S_t, goalEdges, master_edge_list)
            
        #     print(f"\n[DEBUG] Step {t} - 3 Actions available:")

        #     all_qs = []
        #     for a in acts:
        #         q_val = Q.get_q_value(feats, a)
        #         all_qs.append((a, q_val))
            
        #     # Sort actions by Q descending for easier reading
        #     all_qs = sorted(all_qs, key=lambda x: -x[1])

        #     for i, (a, q) in enumerate(all_qs):
        #         print(f"  Action {i+1}: {a}, Q={q:.4f}")

        #     chosen_action = A_t
        #     chosen_q = Q.get_q_value(feats, chosen_action)
        #     print(f"--> CHOSEN action: {chosen_action} | Q={chosen_q:.4f}")


        ns = performAction(A_t, S_t, pSwap=pSwap, nestedSwaps=nestedSwaps, system_goals=goalEdges)
        # Track newly created swapped edges (not original edges)
        new_edges = ns[0]  # list of (edge, age)
        for (u, v), age in new_edges:
            if (u, v) not in initialEdges and (v, u) not in initialEdges:
                edge_creation_counter[(u, v)] += 1



        ns = ageEntanglements(ns, maxAge)
        ns = generateEntanglement(ns, pGen, initial_edges=initialEdges)
        r, succ = compute_reward(
            A_t, goal_success_queues, pSwap,
            mode=reward_mode, noop_penalty=noop_penalty
        )

        consumed_edges, goal_list = A_t
        successful_goals = set(goal_list) if succ else set()
        
        for gh in goalEdges:
            goal_success_queues[gh].append(1 if gh in successful_goals else 0)

        reward_buffer.append(r)

        edr_snap = {g: sum(goal_success_queues[g]) / len(goal_success_queues[g]) for g in goalEdges}
        next_state = (ns[0], tuple(edr_snap[g] for g in goalEdges))
        if t % log_interval == 0:
            edr_tracking_steps.append(t)
            for g in goalEdges:
                edr_tracking_history[g].append(edr_snap[g])
                # Compute Jain's Index at each log point
            fairness = jains_index(edr_snap)
            fairness_history.append(fairness)


        A_next = select_action(next_state, temperature)

        state_buffer.append(next_state)
        action_buffer.append(A_next)

        if len(reward_buffer) >= nLookahead:
            G = sum((gamma**i) * reward_buffer[i] for i in range(nLookahead))
            s_n = state_buffer[nLookahead]
            a_n = action_buffer[nLookahead]
            feats_n = featurize_state(s_n, goalEdges, master_edge_list)
            G += (gamma**nLookahead) * Q.get_q_value(feats_n, a_n)

            s_tau = state_buffer[0]
            a_tau = action_buffer[0]
            feats_tau = featurize_state(s_tau, goalEdges, master_edge_list)
            old_q = Q.get_q_value(feats_tau, a_tau)
            diff = abs(G - old_q)
            q_value_diffs.append(diff)
            
            
            for gg in goalEdges:
                if a_tau[1] is not None and gg in a_tau[1]:
                    q_value_diffs_per_goal[gg].append(diff)
                else:
                    q_value_diffs_per_goal[gg].append(float('nan'))

            # Special handling for no-op
            if a_tau[1] is None:  # If it's no-op action
                q_value_diffs_per_goal['noop'].append(diff)
            else:
                q_value_diffs_per_goal['noop'].append(float('nan'))


            Q.update(feats_tau, a_tau, G, alpha)

            state_buffer.popleft()
            action_buffer.popleft()
            reward_buffer.popleft()

    while reward_buffer:
        n = len(reward_buffer)
        G = sum((gamma**i) * reward_buffer[i] for i in range(n))
        if n < len(state_buffer):
            s_n = state_buffer[n]
            a_n = action_buffer[n]
            feats_n = featurize_state(s_n, goalEdges, master_edge_list)
            G += (gamma**n) * Q.get_q_value(feats_n, a_n)

        s_tau = state_buffer[0]
        a_tau = action_buffer[0]
        feats_tau = featurize_state(s_tau, goalEdges, master_edge_list)
        old_q = Q.get_q_value(feats_tau, a_tau)
        diff = abs(G - old_q)
        q_value_diffs.append(diff)

        for gg in goalEdges:
            if a_tau[1] is not None and gg in a_tau[1]:
                q_value_diffs_per_goal[gg].append(diff)
            else:
                q_value_diffs_per_goal[gg].append(float('nan'))

        Q.update(feats_tau, a_tau, G, alpha)

        state_buffer.popleft()
        action_buffer.popleft()
        reward_buffer.popleft()

    return Q, q_value_diffs, q_value_diffs_per_goal, edr_tracking_steps, edr_tracking_history, fairness_history, edge_creation_counter


def train_sarsa_linear_policy(
    edges, goal_edges, pSwap, pGen, max_age,
    seed, totalSteps, nLookahead,
    alpha, gamma,
    edr_window_size, reward_mode,
    log_interval,
    initial_temperature, temperature_decay, nestedSwaps,
    noop_penalty=0.0, plotTraining=False
):
    random.seed(seed)
    np.random.seed(seed)
    log_interval = edr_window_size

    result = run_n_step_sarsa_linear_multi(
        initialEdges=edges,
        goalEdges=goal_edges,
        totalSteps=totalSteps,
        nLookahead=nLookahead,
        gamma=gamma,
        alpha=alpha,
        pGen=pGen,
        pSwap=pSwap,
        maxAge=max_age,
        edr_window_size=edr_window_size,
        reward_mode=reward_mode,
        noop_penalty=noop_penalty,
        log_interval=log_interval,
        initial_temperature=initial_temperature,
        temperature_decay=temperature_decay,
        nestedSwaps=nestedSwaps
    )

    if result is None:
        print("Error: run_n_step_sarsa_linear_multi returned None.")
        return

    Q, q_diffs, q_diffs_per_goal, edr_steps, edr_hist, fairness_history, edge_creation_counter = result

    if plotTraining:
        plot_training_results(
            q_diffs, q_diffs_per_goal, edr_steps, edr_hist,
            goal_edges, fairness_history, edge_creation_counter,
            method_name="SARSA"
        )

    return Q




# **RUNNING CODE**

In [None]:
seed = 30
random.seed(seed)
np.random.seed(seed)
#####
edges = [(0,1), (1,2), (2,3), (3,4), (2,5)]
goal_edges = [(0,4), (3,5)]

edges = [(0,1), (1,2), (2,3), (3,4)]
goal_edges = [(1,4),(0,2)]
pSwap       = 0.7
pGen        = 0.7
maxAge      =  3
totalSteps     = 4_000_000
nestedSwap = False
#####
nLookahead     = 5
gamma          = 0.995
alpha          = 0.01
windowSize     = 1000
reward_mode    = 'basic'
initial_temperature = 6.0
final_temperature   = 0.2
temperature_decay   = (final_temperature / initial_temperature) ** (1.0 / (totalSteps * 0.99))
temperature_decay = (final_temperature / initial_temperature) ** (1.0 / (totalSteps * 1.2))



# --- Train N-step SARSA policy (with temperature-based softmax) ---
Q1 = train_sarsa_linear_policy(
    edges=edges,
    goal_edges=goal_edges,
    pSwap=pSwap,
    pGen=pGen,
    max_age=maxAge,
    seed=seed,
    totalSteps=totalSteps,
    nLookahead=nLookahead,
    alpha=alpha,
    gamma=gamma,
    edr_window_size=windowSize,
    reward_mode=reward_mode,
    noop_penalty=0.0,
    log_interval=1000,
    initial_temperature=initial_temperature,
    temperature_decay=temperature_decay,
    nestedSwaps=nestedSwap,
    plotTraining=True
)
print('done training')
simulate_policy(
    Q_table=Q1,
    edges=edges,
    goal_edges=goal_edges,
    pSwap=pSwap,
    pGen=pGen,
    max_age=maxAge,
    num_steps=100_000,
    edr_window_size=windowSize,
    plot=True,
    nestedSwaps=nestedSwap
)




Step 100000
Step 200000
Step 300000
Step 400000
Step 500000
Step 600000
Step 700000
Step 800000
Step 900000
Step 1000000
Step 1100000
Step 1200000
Step 1300000
Step 1400000
Step 1500000
Step 1600000
Step 1700000


In [11]:
# --- Print top Q-values for each goal after training ---
print("\nTop Q-values per goal after training:")

master_edge_list = []
nodes = set()
for u, v in edges:
    nodes.add(u)
    nodes.add(v)
nodes = sorted(list(nodes))
for i in range(len(nodes)):
    for j in range(i+1, len(nodes)):
        master_edge_list.append((nodes[i], nodes[j]))

# Create an artificial "perfect state" where all initial edges are alive
ent_state = [(edge, 1) for edge in edges]
edr_vector = tuple(0.1 for _ in goal_edges)  # dummy EDRs
state = (tuple(ent_state), edr_vector)

feats = featurize_state(state, goal_edges, master_edge_list)
acts_all = get_possible_multi_actions(ent_state, goal_edges, nestedSwaps=nestedSwap)

goal_qs = {g: [] for g in goal_edges}
for a in acts_all:
    consumed_paths, goals = a
    if goals is not None:
        for g in goals:
            goal_qs[g].append(Q1.get_q_value(feats, a))

for g in goal_edges:
    if goal_qs[g]:
        print(f"Goal {g}: Best Q = {max(goal_qs[g]):.4f}")
    else:
        print(f"Goal {g}: No available actions")



Top Q-values per goal after training:
Goal (0, 4): Best Q = 89.1089
Goal (0, 2): Best Q = 91.3557


In [12]:
edges = [(0,1), (1,2), (2,3), (3,4), (2,5)]
goal_edges = [(0,4), (3,5)]
pSwap       = 0.6
pGen        = 0.6
maxAge      =  3
totalSteps     = 4_000_000
nestedSwap = False


gamma          = 0.99
alpha          = 0.01
windowSize     = 1000
reward_mode    = 'basic'
softmax= True
# Softmax temperature parameters
initial_temperature = 6.0
final_temperature   = 0.1
temperature_decay   = (final_temperature / initial_temperature) ** (1.0 / (totalSteps * 0.9))

# Seed
seed = 30
random.seed(seed)
np.random.seed(seed)

Q = train_q_learning_linear_policy(
    edges=edges,
    goal_edges=goal_edges,
    pSwap=pSwap,
    pGen=pGen,
    max_age=maxAge,
    seed=seed,
    totalSteps=totalSteps,
    alpha=alpha,
    gamma=gamma,
    edr_window_size=windowSize,
    reward_mode=reward_mode,
    noop_penalty=0.0,          
    softmax=softmax,
    temperature=initial_temperature,
    temperature_decay=temperature_decay,
    epsilon=0.05,       
    nestedSwaps=nestedSwap,
    plotTraining=True
)

    
print('done training')
# --- Evaluate learned policy ---
simulate_policy(
    Q_table=Q,
    edges=edges,
    goal_edges=goal_edges,
    pSwap=pSwap,
    pGen=pGen,
    max_age=maxAge,
    num_steps=100_000,
    edr_window_size=windowSize,
    plot=True,
    nestedSwaps=nestedSwap
)


TypeError: train_q_learning_linear_policy() got an unexpected keyword argument 'epsilon'

In [None]:
nestedSwap = True

# --- Train N-step SARSA policy (with temperature-based softmax) ---
Q1 = train_sarsa_linear_policy(
    edges=edges,
    goal_edges=goal_edges,
    pSwap=pSwap,
    pGen=pGen,
    max_age=maxAge,
    seed=seed,
    totalSteps=totalSteps,
    nLookahead=nLookahead,
    alpha=alpha,
    gamma=gamma,
    edr_window_size=windowSize,
    reward_mode=reward_mode,
    noop_penalty=0.0,
    log_interval=1000,
    initial_temperature=initial_temperature,
    temperature_decay=temperature_decay,
    nestedSwaps=nestedSwap,
    plotTraining=True
)
print('done training')
# --- Evaluate learned policy ---
simulate_policy(
    Q_table=Q1,
    edges=edges,
    goal_edges=goal_edges,
    pSwap=pSwap,
    pGen=pGen,
    max_age=maxAge,
    num_steps=100_000,
    edr_window_size=windowSize,
    plot=True,
    nestedSwaps=nestedSwap
)

In [None]:
paramValues = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
totalSteps = 500_000
seed = 11
final_temperature= 0.1
initial_temperature = 5
temperature_decay   = (final_temperature / initial_temperature) ** (1.0 / (totalSteps * 0.9))


In [None]:
compareOverParamRobust(
    param_name='pGen',
    param_values=paramValues,
    edges=[(0,1), (1,2), (2,3), (3,4), (0,5), (5,3)],
    goal_edges=[(0,4), (1,3)],
    pSwap=0.6,
    pGen=0.6,
    max_age=5,
    totalSteps=totalSteps,
    nLookahead=3,
    alpha=0.01,
    gamma=0.99,
    edr_window_size=1000,
    reward_mode='basic',
    initial_temperature=5.0,
    temperature_decay=temperature_decay,
    seed=30,
    nestedSwaps=False,
    training_function=train_q_learning_linear_policy,
    softmax=True,          # passed into training function
    temperature=5.0,       # passed into training function
    epsilon=0.01,          # passed into training function
    trainCount=3,          # new robust arg
    simulateCount=3        # new robust arg
)


In [None]:
compareOverParamRobust(
    param_name='pGen',
    param_values=paramValues,
    edges=[(0,1), (1,2), (2,3), (3,4), (0,5), (5,3)],
    goal_edges=[(0,4), (1,3)],
    pSwap=0.6,
    pGen=0.6,
    max_age=5,
    totalSteps=500000,
    nLookahead=3,
    alpha=0.01,
    gamma=0.99,
    edr_window_size=1000,
    reward_mode='basic',
    initial_temperature=5.0,
    temperature_decay=temperature_decay,
    seed=30,
    nestedSwaps=False,
    training_function=train_sarsa_linear_policy,
    trainCount=2,          # new robust arg
    simulateCount=2        # new robust arg
)
