In [1]:
import pandas as pd
import random
import numpy as np
import copy
import matplotlib.pyplot as plt
import collections
from collections import deque  # Add this import
import math
!which python

/Users/berat/Desktop/quantum_entanglement/.venv/bin/python


In [2]:
class QuantumInternet():
    def __init__(self, initialEdges, pGen, cutOffAge, goalStates, goalWeights):
        self.initialEdges = initialEdges  
        self.currentEdges = {} 
        self.pGen = pGen
        self.cutOffAge = cutOffAge
        self.goalStates = goalStates
        self.goalWeights = goalWeights
        self.maxLinks = 1
        self.total_timesteps = 0
        self.successful_links = {i: 0.0 for i in range(len(goalStates))}
    
    def get_edrs(self):
        """Returns current EDR for each goal state"""
        return {
            i: self.successful_links[i] / max(1, self.total_timesteps)
            for i in range(len(self.goalStates))
        }
    
    def reset(self):
        self.currentEdges = {}
        self.globallyGenerateEntanglements()
    
    def getState(self) -> dict:
        return self.currentEdges
                
    def generateEntanglement(self, node1, node2):
        edge = tuple(sorted([node1, node2]))
        if edge not in self.currentEdges:
            self.currentEdges[edge] = deque([0])
        else:
            if len(self.currentEdges[edge]) < self.maxLinks:
                self.currentEdges[edge].appendleft(0)

    def globallyGenerateEntanglements(self):
        for edge in self.initialEdges:
            if random.random() < self.pGen:
                self.generateEntanglement(*edge)
    
    def discardEntanglement(self, edge: tuple):
        if edge in self.currentEdges and len(self.currentEdges[edge]) > 0:
            self.currentEdges[edge].pop()
        if len(self.currentEdges[edge]) == 0:
            del self.currentEdges[edge]
                
    
    def ageEntanglements(self):
        edges_to_check = list(self.currentEdges.keys())
        for edge in edges_to_check:
            newAges = [age + 1 for age in self.currentEdges[edge] if age + 1 <= self.cutOffAge]
            self.currentEdges[edge] = deque(newAges)
            
            if len(self.currentEdges[edge]) == 0:
                self.discardEntanglement(edge)
                
        
    def isTerminal(self) -> tuple[bool, list]:
        graph = collections.defaultdict(set)
        for (a, b) in self.currentEdges:
            graph[a].add(b)
            graph[b].add(a)
        
        def has_path(start, end):
            if start == end:
                return True
            
            visited = set()
            stack = [start]
            
            while stack:
                current = stack.pop()
                if current not in visited:
                    visited.add(current)
                    
                    if current == end:
                        return True
                    
                    # Add unvisited neighbors to stack
                    stack.extend(
                        next_node for next_node in graph[current] 
                        if next_node not in visited
                    )
            
            return False
        
        matching = [goal for goal in self.goalStates if has_path(goal[0], goal[-1])]
        return bool(matching), matching
                
    def rewardForAction(self, action): #Returns reward function
        pass
                
        

In [3]:
random.seed(27)
initialEdges = [(1,3), (2,3), (3,4), (4,5)]
goalStates = [(1, 4), (2,4)]
goalWeights = [0.3, 0.7]
pGen = 0.7
n = 2
cutOffAge = 0
alpha = 0.3
gamma = 0.95  # discount factor
epsilon = 0.2  # exploration rate
num_episodes = 30000
myNetwork = QuantumInternet(initialEdges, pGen, cutOffAge, goalStates, goalWeights)
myNetwork.globallyGenerateEntanglements()
print(myNetwork.getState())
myNetwork.ageEntanglements() # All entanglements are discarded after 1 timestep
myNetwork.globallyGenerateEntanglements()
print(myNetwork.getState())

state = myNetwork.getState()
state_key = tuple(
            (edge, age[0])
            for edge, age in sorted(state.items())
        )
print(state)
print(state_key)


{(1, 3): deque([0]), (4, 5): deque([0])}
{(1, 3): deque([0]), (3, 4): deque([0]), (4, 5): deque([0])}
{(1, 3): deque([0]), (3, 4): deque([0]), (4, 5): deque([0])}
(((1, 3), 0), ((3, 4), 0), ((4, 5), 0))


In [4]:
def epsilon_greedy_policy(Q, state_key, epsilon):
    if np.random.rand() < epsilon:
        return random.choice([True, False]) # For swap or not to swap entire connection
    else:
        # Exploitation: choose action with highest Q-value
        if state_key not in Q:
            Q[state_key] = {True: 0, False: 0}  # Initialize both actions
        
        return max(Q[state_key].items(), key=lambda x: x[1])[0]
    
    Q = {}
    


In [5]:
Q = {}

def n_step_sarsa(env, n, alpha, gamma, epsilon, num_episodes):
    global Q
    episode_rewards = np.zeros(num_episodes)
    
    for episode in range(num_episodes):
        if episode % 100 == 0:
            edrs = env.get_edrs()
            print(f"Episode {episode} EDRs:")
            for goal_idx, edr in edrs.items():
                print(f"Goal {env.goalStates[goal_idx]}: EDR = {edr:.3f}")
        
        env.reset()
        state = env.getState()
        # Hashing key to be stored
        state_key = tuple(
            (edge, age)
            for edge, age in sorted(state.items())
        )
        
        action = epsilon_greedy_policy(Q, state_key, epsilon)
        
        T = float('inf')
        t = 0
        tau = 0
        
        # Store states, actions, rewards
        states = [state_key]
        actions = [action]
        rewards = []
        
        while tau < (T - 1):  # Add step limit check
            if t < T:
                # Take action and get reward
                reward = env.rewardForAction(action)
                rewards.append(reward)
                
                # Age entanglements and generate new ones
                env.ageEntanglements()
                env.globallyGenerateEntanglements()
                
                # Get next state
                next_state = env.getState()
                is_terminal, _ = env.isTerminal()
                
                # Convert next_state to hashable format
                state_key = tuple(
                    (edge, age)
                    for edge, age in sorted(state.items())
                )
                states.append(next_state_key)
                
                if is_terminal:
                    T = t + 1
                else:
                    next_action = epsilon_greedy_policy(Q, next_state_key, epsilon)
                    actions.append(next_action)
            
            tau = t - n + 1
            
            if tau >= 0:
                G = sum([gamma**(i - tau - 1) * rewards[i] for i in range(tau + 1, min(tau + n, T))])
                
                if tau + n < T:
                    if states[tau + n] not in Q:
                        Q[states[tau + n]] = {True: 0, False: 0}
                    G += gamma**n * Q[states[tau + n]][actions[tau + n]]
                
                # Update Q-value
                if states[tau] not in Q:
                    Q[states[tau]] = {True: 0, False: 0}
                Q[states[tau]][actions[tau]] += alpha * (G - Q[states[tau]][actions[tau]])
                print(Q[states[tau]][actions[tau]])
            t += 1
            state = next_state
            action = next_action
        
        episode_rewards[episode] = sum(rewards)
    
    return Q, episode_rewards


# env
random.seed(27)
initialEdges = [(1,3), (2,3), (3,4), (4,5)]
goalStates = [(1, 4), (2,4)]
goalWeights = [0.3, 0.7]
pGen = 0.7
cutOffAge = 1
# sarsa
n = 2
alpha = 0.3
gamma = 0.95  # discount factor
epsilon = 0.2  # exploration rate
num_episodes = 30000
myNetwork = QuantumInternet(initialEdges, pGen, cutOffAge, goalStates, goalWeights)
myQ, myEpisodeRewards = n_step_sarsa(myNetwork, n, alpha, gamma, epsilon, num_episodes)

Episode 0 EDRs:
Goal (1, 4): EDR = 0.000
Goal (2, 4): EDR = 0.000


NameError: name 'next_state_key' is not defined

In [50]:
# Get all Q-values in a flat list with their corresponding states and actions
q_values = []
for state, actions in myQ.items():
    for action, value in actions.items():
        q_values.append((value, state, action))

# Sort by Q-value in descending order and get top 10
top_10 = sorted(q_values, key=lambda x: x[0], reverse=True)[:10]

# Print the results
print("Top 10 highest Q-values:")
print("-" * 50)
for i, (value, state, action) in enumerate(top_10, 1):
    print(f"\n{i}. Q-value: {value:.4f}")
    print(f"Action: {action}")
    print("State:")
    for edge, ages in state:
        print(f"  Edge {edge}: Ages {ages}")

Top 10 highest Q-values:
--------------------------------------------------

1. Q-value: 0.0000
Action: True
State:
  Edge (1, 3): Ages (0,)
  Edge (2, 3): Ages (0,)
  Edge (3, 4): Ages (0,)

2. Q-value: 0.0000
Action: False
State:
  Edge (1, 3): Ages (0,)
  Edge (2, 3): Ages (0,)
  Edge (3, 4): Ages (0,)

3. Q-value: 0.0000
Action: True
State:
  Edge (1, 3): Ages (0,)
  Edge (2, 3): Ages (0,)
  Edge (3, 4): Ages (0,)
  Edge (4, 5): Ages (0,)
  Edge (6, 7): Ages (0,)
  Edge (6, 8): Ages (0,)

4. Q-value: 0.0000
Action: False
State:
  Edge (1, 3): Ages (0,)
  Edge (2, 3): Ages (0,)
  Edge (3, 4): Ages (0,)
  Edge (4, 5): Ages (0,)
  Edge (6, 7): Ages (0,)
  Edge (6, 8): Ages (0,)

5. Q-value: 0.0000
Action: True
State:
  Edge (1, 3): Ages (0,)
  Edge (3, 4): Ages (0,)
  Edge (6, 7): Ages (0,)

6. Q-value: 0.0000
Action: False
State:
  Edge (1, 3): Ages (0,)
  Edge (3, 4): Ages (0,)
  Edge (6, 7): Ages (0,)

7. Q-value: 0.0000
Action: True
State:
  Edge (1, 3): Ages (0,)
  Edge (2, 3): A