In [1]:
import pandas as pd
import random as rd
import numpy as np
import copy
import networkx as nx
import matplotlib.pyplot as plt
import itertools
import collections
from collections import deque  

In [2]:
class QuantumInternet:
    """Simulate link-level entanglement on a fixed network topology.

    Live links are stored per edge in ``currentEdges`` as a deque of link
    ages, newest on the left. ``globallyGenerateEntanglements`` tries to add
    a fresh link (age 0) on every edge of ``initialEdges`` with probability
    ``pGen``. ``ageEntanglements`` increments ages and drops links whose age
    has reached ``cutOffAge``. ``isTerminal`` reports which goal states are
    connected by a path of edges that currently hold a live link.
    """

    def __init__(self, initialEdges, pGen, cutOffAge, goalStates, goalWeights, maxLinks=1):
        """
        Args:
            initialEdges: iterable of ``(node1, node2)`` pairs on which links
                can be generated.
            pGen: per-edge, per-call generation probability (compared against
                ``rd.random()``).
            cutOffAge: a link whose age has reached this value is discarded on
                the next aging step.
            goalStates: sequence of goals; ``goal[0]`` and ``goal[-1]`` are the
                end nodes that must be connected.
            goalWeights: stored as-is; not used by any method in this class.
            maxLinks: maximum number of links kept per edge (default 1, the
                previously hard-coded value).
        """
        self.initialEdges = initialEdges
        # sorted (node, node) tuple -> deque of link ages, newest first
        self.currentEdges = {}
        self.pGen = pGen
        self.cutOffAge = cutOffAge
        self.goalStates = goalStates
        self.goalWeights = goalWeights
        self.maxLinks = maxLinks
        # NOTE(review): no method in this class updates these two counters yet.
        self.total_timesteps = 0
        self.successful_links = {i: 0.0 for i in range(len(goalStates))}

    def reset(self):
        """Discard all current links.

        Only ``currentEdges`` is cleared; ``total_timesteps`` and
        ``successful_links`` are left untouched.
        """
        self.currentEdges = {}

    def getState(self) -> dict:
        """Return the live ``currentEdges`` mapping (not a copy)."""
        return self.currentEdges

    def _generateEntanglement(self, node1, node2):
        """Add a fresh (age 0) link on the edge between ``node1`` and ``node2``.

        No link is added when the edge already holds ``maxLinks`` links.
        """
        edge = tuple(sorted((node1, node2)))  # orientation-independent key
        links = self.currentEdges.get(edge)
        if links is None:
            self.currentEdges[edge] = deque([0])
        elif len(links) < self.maxLinks:
            links.appendleft(0)  # newest link goes on the left

    def globallyGenerateEntanglements(self):
        """Attempt link generation once on every edge of ``initialEdges``.

        Each edge succeeds independently with probability ``pGen``.
        """
        for edge in self.initialEdges:
            # ``random`` is imported as ``rd`` in this module; the bare name
            # ``random`` is undefined here.
            if rd.random() < self.pGen:
                self._generateEntanglement(*edge)

    def _discardEntanglement(self, edge: tuple):
        """Remove the oldest link on ``edge``; drop the edge once it is empty.

        ``edge`` may be given in either node order. Edges that are not
        present are ignored.
        """
        edge = tuple(sorted(edge))  # match the key built by _generateEntanglement
        links = self.currentEdges.get(edge)
        if not links:
            # Missing edge (previously raised KeyError) or already-empty deque.
            self.currentEdges.pop(edge, None)
            return
        links.pop()  # oldest link is on the right, since new ones are appendleft-ed
        if not links:
            del self.currentEdges[edge]

    def ageEntanglements(self):
        """Advance every link by one time step.

        Links whose age is already ``>= cutOffAge`` are discarded; all others
        have their age increased by one. Edges left with no links are removed,
        so they no longer contribute connectivity in ``isTerminal``.
        """
        # Iterate over a snapshot of the keys because entries may be deleted.
        for edge in list(self.currentEdges):
            surviving = deque(
                age + 1 for age in self.currentEdges[edge] if age < self.cutOffAge
            )
            if surviving:
                self.currentEdges[edge] = surviving
            else:
                del self.currentEdges[edge]

    def isTerminal(self) -> tuple[bool, list]:
        """Report which goal states are currently connected.

        Returns:
            ``(any_connected, matching)``. ``matching`` lists, in
            ``goalStates`` order, the goals whose end nodes ``goal[0]`` and
            ``goal[-1]`` are joined by a path of edges holding at least one
            live link.
        """
        graph = collections.defaultdict(set)
        for (a, b), links in self.currentEdges.items():
            if not links:  # an edge with no live links provides no connectivity
                continue
            graph[a].add(b)
            graph[b].add(a)

        def has_path(start, end):
            """Iterative depth-first search from ``start`` to ``end``."""
            if start == end:
                return True
            visited = {start}
            stack = [start]
            while stack:
                current = stack.pop()
                for neighbour in graph.get(current, ()):
                    if neighbour == end:
                        return True
                    if neighbour not in visited:
                        visited.add(neighbour)
                        stack.append(neighbour)
            return False

        matching = [goal for goal in self.goalStates if has_path(goal[0], goal[-1])]
        return bool(matching), matching

    def rewardForAction(self, action):  # Returns reward function
        """Placeholder for the reward computation; currently returns ``None``."""
        # TODO: reward function not implemented yet.
        pass
                
        