Here we will create a basic n-step policy


In [26]:
def getImmediateReward(achieved_goals, goal_success_counts, total_timesteps):
    if achieved_goals is None:
        return 0

    total_reward = 0
    for goal_edge, success in achieved_goals:
        if success:
            start, end = goal_edge
            num_edges = abs(end - start)
            instant_rate = pSwap ** (num_edges - 1)

            #edr = max(0.0001, goal_success_counts[goal_edge] / max(1, total_timesteps))
            edr = goal_success_counts[goal_edge] / max(1, total_timesteps) + 0.001

            if instant_rate > 0 and edr > 0:
                total_reward += instant_rate / edr
    return total_reward

In [27]:
from collections import defaultdict, deque
from itertools import product, combinations
import random
import numpy as np
import matplotlib.pyplot as plt


def getPossibleStates(edges, max_age):
    sorted_edges = sorted(tuple(sorted(e)) for e in edges)
    possible_ages = [-1] + list(range(1, max_age + 1))
    return [tuple(zip(sorted_edges, age_combo)) for age_combo in product(possible_ages, repeat=len(sorted_edges))]

def getAgedStates(state, maxAge):
    new_state = []
    for edge, age in state:
        if age == -1:
            new_state.append((edge, -1))
        else:
            new_age = age + 1
            new_state.append((edge, new_age if new_age <= maxAge else -1))
    return tuple(sorted(new_state))

def generateAllOutcomes(state, pGen):
    empty_edges = [edge for edge, age in state if age == -1]
    outcomes = []
    for pattern in product([0, 1], repeat=len(empty_edges)):
        outcome_map = {edge: 1 if outcome else -1 for edge, outcome in zip(empty_edges, pattern)}
        new_state = [(edge, outcome_map.get(edge, age)) for edge, age in state]
        outcomes.append((tuple(sorted(new_state)), 1.0))
    return outcomes

def generateAllSwappingOutcomes(state, goalEdges, pSwap):
    def find_path(current, target, visited):
        if current == target:
            return [current]
        visited.add(current)
        for next_node in graph.get(current, []):
            if next_node not in visited:
                path = find_path(next_node, target, visited)
                if path:
                    return [current] + path
        return None

    entangled_edges = [(edge, age) for edge, age in state if age > 0]
    graph = defaultdict(list)
    for (edge, _) in entangled_edges:
        graph[edge[0]].append(edge[1])
        graph[edge[1]].append(edge[0])

    swap_attempts = []
    for goal_edge in goalEdges:
        path = find_path(goal_edge[0], goal_edge[1], set())
        if path and len(path) > 1:
            path_edges = list(zip(path[:-1], path[1:]))
            used_edges = [(e, age) for p_edge in path_edges for e, age in entangled_edges if set(p_edge) == set(e)]
            swap_attempts.append({
                'goal': goal_edge,
                'used_edges': used_edges,
                'num_swaps': len(path_edges) - 1
            })

    outcomes = [(state, 1.0, None)]
    for r in range(1, len(swap_attempts) + 1):
        for attempt_combo in combinations(swap_attempts, r):
            used = set()
            if any(e in used or used.add(e) for attempt in attempt_combo for e, _ in attempt['used_edges']):
                continue
            for pattern in product([True, False], repeat=r):
                new_state = []
                goals = []
                for attempt, success in zip(attempt_combo, pattern):
                    goals.append((attempt['goal'], success))
                for edge, age in state:
                    if any(not s and edge in [e for e, _ in a['used_edges']] for a, s in zip(attempt_combo, pattern)):
                        new_state.append((edge, -1))
                    else:
                        new_state.append((edge, age))
                outcomes.append((tuple(sorted(new_state)), 1.0, goals))
    return outcomes

def getAllTransitionProbabilities(state, goalEdges, pSwap, pGen, maxAge):
    transitions = []
    for swap_state, _, goals in generateAllSwappingOutcomes(state, goalEdges, pSwap):
        aged_state = getAgedStates(swap_state, maxAge)
        for final_state, _ in generateAllOutcomes(aged_state, pGen):
            transitions.append((final_state, 1.0, goals))
    return transitions

def generateAllStateTransitions(edges, goalEdges, pSwap, pGen, maxAge):
    return {s: getAllTransitionProbabilities(s, goalEdges, pSwap, pGen, maxAge) for s in getPossibleStates(edges, maxAge)}

def getImmediateReward(goals, counts, t):
    return sum(10 for g, success in goals or [] if success)

# Parameters
edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
goalEdges = [(0, 4)]
pSwap, pGen, maxAge = 1, 1, 3
alpha, gamma, epsilon, n = 0.5, 0.98, 0.2, 3
transitions = generateAllStateTransitions(edges, goalEdges, pSwap, pGen, maxAge)

class QuantumNetworkNStepSARSA:
    def __init__(self, edges, goalEdges, transitions, pSwap, pGen, maxAge, alpha, gamma, epsilon, n):
        self.edges, self.goalEdges, self.transitions = edges, goalEdges, transitions
        self.alpha, self.gamma, self.epsilon, self.n = alpha, gamma, epsilon, n
        self.Q = defaultdict(float)
        self.goal_success_counts = {goal: 0 for goal in goalEdges}
        self.reward_history, self.episode_avg_q_values, self.episode_avg_rewards = [], [], []
        self.edr_history = []
        self.per_goal_edr_history = {goal: [] for goal in goalEdges}

    def choose_action(self, state, train=True):
        ts = self.transitions[state]  # List of possible transitions
        if train and random.random() < self.epsilon:
            # Random action during training
            action = random.choice(ts)
        else:
            # Best action (max Q-value)
            action = max(ts, key=lambda t: self.Q[(state, tuple(t[0]))])  # Convert action to tuple
        return action  # Return the full action tuple (state, action, reward)

    def train(self, total_steps=10000):
        state = tuple((e, -1) for e in self.edges)
        self.goal_success_counts = {goal: 0 for goal in self.goalEdges}
        buffer = deque()
        action = self.choose_action(state)
        total_timesteps = 1

        for t in range(total_steps):
            next_state, _, goals = action  # Unpack correctly
            reward = getImmediateReward(goals, self.goal_success_counts, total_timesteps)
            if goals:
                for g, s in goals:
                    if s:
                        self.goal_success_counts[g] += 1

            print(f"Step {t}: State = {state}\n  Action = {action}\n  Reward = {reward}\n  Q-value before = {self.Q[(state, tuple(action[0]))]:.4f}")
            buffer.append((state, tuple(action), reward))
            state = next_state
            action = self.choose_action(state)
            total_timesteps += 1

            tau = t - self.n + 1
            if tau >= 0:
                G = sum((self.gamma ** (i - tau)) * buffer[i][2] for i in range(tau, min(tau + self.n, len(buffer))))
                if tau + self.n < len(buffer):
                    s_n, a_n, _ = buffer[tau + self.n]
                    G += (self.gamma ** self.n) * self.Q[(s_n, a_n)]
                s_tau, a_tau, _ = buffer.popleft()
                old_q = self.Q[(s_tau, a_tau)]
                self.Q[(s_tau, a_tau)] += self.alpha * (G - old_q)
                print(f"  G = {G:.4f}, Updated Q = {self.Q[(s_tau, a_tau)]:.4f}")

            if (t + 1) % 100 == 0:
                avg_q = np.mean(list(self.Q.values()))
                self.reward_history.append(reward)
                self.episode_avg_q_values.append(avg_q)
                if len(self.reward_history) >= 100:
                    self.episode_avg_rewards.append(np.mean(self.reward_history[-100:]))
                edrs = [self.goal_success_counts[g] / max(1, total_timesteps) for g in self.goalEdges]
                self.edr_history.append(np.mean(edrs))
                for g in self.goalEdges:
                    self.per_goal_edr_history[g].append(self.goal_success_counts[g] / max(1, total_timesteps))

    def plot_results(self):
        plt.figure(figsize=(18, 5))
        plt.subplot(1, 3, 1)
        plt.plot(self.reward_history, label='Reward')
        plt.plot(self.episode_avg_rewards, label='100-step avg')
        plt.title('Reward over Time')
        plt.xlabel('Steps (x100)')
        plt.ylabel('Reward')
        plt.legend()
        plt.grid()

        plt.subplot(1, 3, 2)
        plt.plot(self.episode_avg_q_values)
        plt.title('Average Q-value over Time')
        plt.xlabel('Steps (x100)')
        plt.ylabel('Avg Q-value')
        plt.grid()

        plt.subplot(1, 3, 3)
        for goal in sorted(self.goalEdges):
            plt.plot(self.per_goal_edr_history[goal], label=f"EDR {goal}")
        plt.title('Per-Goal Estimated Delivery Rate (EDR)')
        plt.xlabel('Steps (x100)')
        plt.ylabel('EDR')
        plt.legend()
        plt.grid()

        plt.tight_layout()
        plt.show()

# Running the agent training
agent = QuantumNetworkNStepSARSA(edges, goalEdges, transitions, pSwap, pGen, maxAge, alpha, gamma, epsilon, n)
agent.train(total_steps=10000)
agent.plot_results()


Step 0: State = (((0, 1), -1), ((1, 2), -1), ((2, 3), -1), ((3, 4), -1))
  Action = ((((0, 1), -1), ((1, 2), -1), ((2, 3), -1), ((3, 4), -1)), 1.0, None)
  Reward = 0
  Q-value before = 0.0000
Step 1: State = (((0, 1), -1), ((1, 2), -1), ((2, 3), -1), ((3, 4), -1))
  Action = ((((0, 1), -1), ((1, 2), -1), ((2, 3), -1), ((3, 4), -1)), 1.0, None)
  Reward = 0
  Q-value before = 0.0000
Step 2: State = (((0, 1), -1), ((1, 2), -1), ((2, 3), -1), ((3, 4), -1))
  Action = ((((0, 1), -1), ((1, 2), -1), ((2, 3), -1), ((3, 4), -1)), 1.0, None)
  Reward = 0
  Q-value before = 0.0000
  G = 0.0000, Updated Q = 0.0000
Step 3: State = (((0, 1), -1), ((1, 2), -1), ((2, 3), -1), ((3, 4), -1))
  Action = ((((0, 1), -1), ((1, 2), -1), ((2, 3), -1), ((3, 4), -1)), 1.0, None)
  Reward = 0
  Q-value before = 0.0000
  G = 0.0000, Updated Q = 0.0000
Step 4: State = (((0, 1), -1), ((1, 2), -1), ((2, 3), -1), ((3, 4), -1))
  Action = ((((0, 1), -1), ((1, 2), -1), ((2, 3), -1), ((3, 4), -1)), 1.0, None)
  Reward

TypeError: unhashable type: 'list'