In [33]:
import networkx as nx 
import random
import matplotlib.pyplot as plt
import numpy as np 
import math 
import copy 
import graphviz
from scipy.optimize import fsolve
from typing import List
import sympy as sp
from sympy import Symbol
from collections import deque
import pickle

In [34]:
class Arrival():
    """Randomly generated two-successor switching graph plus a token simulator.

    Every vertex ``v`` has an "even" successor ``s_0[v]`` (edge label ``'0'``)
    and an "odd" successor ``s_1[v]`` (edge label ``'1'``).  ``run_procedure``
    walks a token from ``start_node``: at each step it follows the vertex's
    current successor and then swaps that vertex's two successors, so
    consecutive visits to a vertex alternate between them.  The node ``-1`` is
    used as a sink.
    """

    def __init__(self,n:int,do_reachability:bool = False):
        """Build a random instance on ``n`` vertices.

        Args:
            n: number of vertices. ``n == 1`` raises ``IndexError`` because a
                vertex has no candidate successor other than itself.
            do_reachability: when True, vertices that cannot reach the target
                are redirected to the sink ``-1``, the successor lists are
                rebuilt from the reduced graph, and the visit-count equations
                are constructed immediately.
        """
        self.n = n # number of nodes
        self.vertices = list(range(n)) # 0 is the origin, n-1 is the destination
        # A vertex is never drawn as its own successor (i != v), but both of its
        # successors may still coincide.
        self.s_0 = np.array([random.choice([i for i in range(n) if i != v]) for v in self.vertices])  # even successors
        self.s_1 = np.array([random.choice([i for i in range(n) if i != v]) for v in self.vertices])  # odd successors

        self.start_node = 0 # node the token starts on
        self.target_node = n-1
        self.sink_node = -1

        self.function_calls = 0 # number of evaluate() invocations

        # Symbols/equations are built eagerly only on the reachability path;
        # otherwise evaluate() builds them on first use.
        self.X = None
        self.equations = None

        self.graph = self.get_network_graph()
        if do_reachability:
            self.graph = self.combine_unreachable_nodes()
            self.s_0, self.s_1 = self.update_successor_list()

            self.X,self.equations = self.get_equations()

    def __repr__(self):
        return f"Vertices: {self.vertices} \nEven Successors: {self.s_0}\nOdd Successors: {self.s_1}\n"

    def update_successor_list(self):
        """Rebuild the successor lists from the labelled edges of ``self.graph``.

        Returns:
            ``(new_s0, new_s1)``: plain lists of length ``self.n`` indexed by
            node id.  Entries default to ``-1``; the sink node itself is
            skipped, so the slot reached through index ``-1`` (the last one)
            keeps that default.
        """
        new_s0, new_s1 = [-1]*self.n, [-1]*self.n
        for node in self.graph.nodes:
            if node == -1:
                continue

            for source,target,attr in self.graph.out_edges(node,data=True):
                label = attr['label']

                if label == '1':
                    new_s1[node] = target
                elif label == '0':
                    new_s0[node] = target
        return new_s0,new_s1

    def get_network_graph(self):
        """Return a ``MultiDiGraph`` with one ``'0'`` and one ``'1'`` labelled
        out-edge per vertex, taken from ``s_0`` and ``s_1``."""
        G = nx.MultiDiGraph()
        G.add_nodes_from(self.vertices, d_dash=False)

        for v in self.vertices:
            G.add_edge(v,self.s_0[v],label='0')
            G.add_edge(v,self.s_1[v],label='1')

        return G

    def combine_unreachable_nodes(self):
        """Redirect every vertex that cannot reach the target to the sink ``-1``.

        Builds a new graph in which the sink has two self-loops, edges leaving a
        vertex with no path to ``target_node`` point to ``-1``, and all other
        edges are copied unchanged.

        Side effects: ``self.vertices`` becomes the sorted node list of the new
        graph (now including ``-1``), ``self.n`` its length, and
        ``self.target_node`` is set to ``self.n - 2``.

        Returns:
            The new ``MultiDiGraph``.
        """
        reachable_nodes = nx.ancestors(self.graph, self.target_node) | {self.target_node}

        new_G = nx.MultiDiGraph()
        new_G.add_edges_from([(-1, -1,{"label" : '1'}),(-1, -1,{"label" : '0'})])

        for s,t,attr in self.graph.edges(data=True):
            if s not in reachable_nodes:
                new_G.add_edge(s, -1,label=attr['label'])
            else:
                new_G.add_edge(s, t,label=attr['label'])

        ## Changing the vertices
        self.vertices = list(new_G.nodes)
        self.vertices.sort()

        self.n = len(self.vertices)
        self.target_node = self.n-2

        return new_G

    def get_equations(self):
        """Build one symbolic visit-count expression per node of ``self.graph``.

        For node ``v`` the expression is
        ``Min(sum(floor(X_p/2) for '1'-parents p) + sum(ceiling(X_p/2) for '0'-parents p) [+ 1 if v == 0], 2**n)``.

        Returns:
            ``(symbols, equations)``, both ordered by sorted node id.
        """
        symbol_nodes = list(self.graph.nodes)
        symbol_nodes.sort()

        symbols = sp.symbols(' '.join([f"X{i}" for i in symbol_nodes]),positive=True)
        s_mappings = {n:s for n,s in zip(symbol_nodes,symbols)}

        equations = []
        for v in symbol_nodes:
            odd_parents = []
            even_parents = []
            for s, t, attr in self.graph.in_edges(v,data=True):

                label = attr['label']
                if label == '1':
                    odd_parents.append(s)
                elif label == '0':
                    even_parents.append(s)

            parent_sum = sp.sympify(0)
            for p in odd_parents:
                parent_sum += sp.floor(s_mappings[p]/2)
            for p in even_parents:
                parent_sum += sp.ceiling(s_mappings[p]/2)
            total_sum = (parent_sum + 1) if v == 0 else parent_sum # origin is visited one more time
            # Cap every expression at 2**n.
            eq = sp.Min(total_sum,2**self.n)

            equations.append(eq)

        return list(s_mappings.values()),equations

    def evaluate(self,x):
        """Evaluate every equation at the point ``x``.

        Args:
            x: one value per symbol, in the order of ``self.X``.

        Returns:
            ``np.array`` holding each equation after substitution.

        If the equations have not been built yet (instance created without
        reachability processing) they are built from the current ``self.graph``
        on the first call and then reused.
        """
        if getattr(self, 'equations', None) is None:
            self.X, self.equations = self.get_equations()

        assert len(x) == len(self.equations)
        self.function_calls += 1
        # make sympy assignments of given values
        assignment = {self.X[i]: value_i for i, value_i in enumerate(x)}

        results = []
        for eq in self.equations:
            # substitute assignments in equations
            result = eq.subs(assignment)
            results.append(result)

        return np.array(results)

    def draw_graph(self):
        """Draw ``self.graph`` with a spring layout and edge labels.

        Labels are keyed by ``(source, target)``, so parallel edges between the
        same pair of nodes share one key and only one label is kept for them.
        """
        G = self.graph
        pos = nx.spring_layout(G)
        edge_labels = {(i, j): attr['label'] for i, j, attr in G.edges(data=True)}
        nx.draw(G,pos, with_labels=True)
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

        plt.show()

    def save_graph(self,filename):
        """Pickle the graph to ``filename`` and render it with graphviz.

        The graphviz DOT source is written to ``<filename>.gv`` (and the
        rendering next to it) so that it does not overwrite the pickle that was
        just written to ``filename``.
        """
        with open(filename, 'wb') as f:
            pickle.dump(self.graph, f, pickle.HIGHEST_PROTOCOL)

        # graphviz saves its DOT source at the Digraph's filename when viewing;
        # reusing `filename` here would clobber the pickle above.
        g = graphviz.Digraph('G', filename=f"{filename}.gv")
        # NOTE(review): range(self.n - 1) looks like it assumes the last slot of
        # the successor lists belongs to the sink -- confirm for instances built
        # without reachability processing.
        for v in range(self.n - 1):
            g.edge(str(v),str(self.s_0[v]),label='0')
            g.edge(str(v),str(self.s_1[v]),label='1')
        g.edge(str(-1),str(-1),label='0')
        g.edge(str(-1),str(-1),label='1')
        g.view()

    def run_procedure(self):
        """Simulate the token walk from ``start_node``.

        Returns:
            True if ``target_node`` is reached; False if the sink ``-1`` is
            reached or ``n * 2**n`` moves elapse first.
        """
        v = self.start_node
        s_curr = np.copy(self.s_0) # current switches for each node
        s_next = np.copy(self.s_1) # next switch for each node
        max_moves = self.n * 2**self.n # loop-invariant move budget
        counter = 0
        while counter < max_moves:
            if counter % 10**7 == 0:
                print(f'move {counter} :{v}')

            if v == -1:
                print(f"Sink Node reached at {counter}")
                return False
            elif v == self.target_node:
                print(f"Target Node reached at {counter}")
                return True

            # Follow the current successor, then swap the two successors of v.
            w = s_curr[v]
            s_curr[v] = s_next[v]
            s_next[v] = w
            v = w

            counter += 1

        return False

    def plot_successor_graph(self):
        """Plot ``self.graph`` with a fixed-seed spring layout, edge labels and
        node labels, with the axes hidden."""
        G = self.graph

        pos = nx.spring_layout(G, seed=42)

        # Draw nodes
        nx.draw_networkx_nodes(G, pos, node_size=300)

        # Draw edges and their labels
        nx.draw_networkx_edges(G, pos, style='solid', arrows=True)
        nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): d['label'] for u, v, d in G.edges(data=True)})

        # Draw labels
        nx.draw_networkx_labels(G, pos, font_size=12, font_family='sans-serif')

        # Show plot
        plt.axis('off')
        plt.show()

    
    

In [35]:
def branch_instance_old(size_of_instance:int , split_ratio:float = 0.0):
    """Older builder of a two-branch instance that overwrites a random ``Arrival``.

    Args:
        size_of_instance: number of vertices ``n`` of the game.
        split_ratio: fraction of ``n`` at which the second branch starts
            (``m = int(split_ratio * n)``); ``0.0`` means "draw a random ratio".

    Returns:
        The ``Arrival`` game with rewritten successors and a rebuilt graph.
        The graph is also drawn as a side effect.
    """
    if split_ratio == 0.0:
        split_ratio = random.random()
    ## first branch is 0 to m and second branch is m to m+n
    m = int(split_ratio*size_of_instance)
    n = size_of_instance
    game = Arrival(n)

    for v in game.vertices:
        if v == game.start_node:
            # The start node leads into either branch; a coin flip decides
            # which successor is the even one.
            r = random.random()
            game.s_0[v] = 1 if r > 0.5 else m
            game.s_1[v] = 1 if r <= 0.5 else m
        elif v in [game.target_node, game.sink_node]:
            # Terminal nodes loop on themselves.
            game.s_0[v] = v
            game.s_1[v] = v
        elif v == m -1 :
            # End of the first branch: back to the start or into the sink.
            r = random.random()
            game.s_0[v] = 0 if r > 0.5 else -1
            game.s_1[v] = 0 if r <= 0.5 else -1
        else:
            # Interior node: advance along the branch or fall back to the start.
            r = random.random()
            game.s_0[v] = v+1 if r > 0.5 else 0
            game.s_1[v] = v+1 if r <= 0.5 else 0

    game.graph = game.get_network_graph()
    # draw_graph() takes no arguments and draws game.graph itself.
    game.draw_graph()
    return game
    

In [36]:
def get_branch_instance(size_of_instance:int , split_ratio:float = 0.0):
    """Build a two-branch instance with randomly oriented switches.

    Args:
        size_of_instance: number of regular vertices ``n``; the successor
            arrays get one extra slot, addressed through index ``-1``, for the
            sink.
        split_ratio: fraction of ``n`` where the second branch begins; the
            value ``0.0`` requests a randomly drawn ratio.

    Returns:
        An ``Arrival`` game whose successors, vertex list (``-1 .. n-1``) and
        graph have been replaced by the branch construction.
    """
    ratio = random.random() if split_ratio == 0.0 else split_ratio
    total = size_of_instance
    branch_start = int(ratio * total)  # first branch: 0..branch_start-1
    game = Arrival(total, do_reachability=False)

    even_succ = [0] * (total + 1)
    odd_succ = [0] * (total + 1)

    def coin_assign(node, heads, tails):
        # One coin flip per node decides which of the two targets is "even".
        if random.random() > 0.5:
            even_succ[node], odd_succ[node] = heads, tails
        else:
            even_succ[node], odd_succ[node] = tails, heads

    for node in range(-1, total):
        if node == game.start_node:
            # Start node: entry into the first branch or into the second.
            coin_assign(node, 1, branch_start)
        elif node in [game.target_node, game.sink_node]:
            # Terminal nodes loop on themselves.
            even_succ[node] = node
            odd_succ[node] = node
        elif node == branch_start - 1:
            # Last node of the first branch: restart or fall into the sink.
            coin_assign(node, 0, -1)
        else:
            # Interior node: step forward or restart.
            coin_assign(node, node + 1, 0)

    game.s_0 = np.array(even_succ)
    game.s_1 = np.array(odd_succ)
    game.vertices = list(range(-1, total))
    game.graph = game.get_network_graph()
    return game
    

In [37]:
def get_branch_instance_without_random(size_of_instance:int , split_ratio:float = 0.0):
    """Build a two-branch instance with a fixed switch orientation.

    Same layout as the randomised branch builder, but every node's even/odd
    successors are assigned deterministically, with no per-node coin flips.

    Args:
        size_of_instance: number of regular vertices ``n``; the successor
            arrays get one extra slot, addressed through index ``-1``, for the
            sink.
        split_ratio: fraction of ``n`` where the second branch begins
            (``m = int(split_ratio * n)``).  ``0.0`` still means "draw a random
            ratio", so pass an explicit value for a fully deterministic layout.

    Returns:
        An ``Arrival`` game whose successors, vertex list (``-1 .. n-1``) and
        graph have been replaced by the branch construction.
    """
    if split_ratio == 0.0:
        split_ratio = random.random()
    ## first branch is 0 to m and second branch is m to m+n
    m = int(split_ratio*size_of_instance)
    n = size_of_instance
    # The constructor's random successors are discarded below.
    game = Arrival(n,do_reachability=False)
    s_0 = [0]*(n+1)
    s_1 = [0]*(n+1)

    for v in range(-1,n):
        if v == game.start_node:
            # Even edge enters the first branch, odd edge the second.
            s_0[v] = 1
            s_1[v] = m
        elif v in [game.target_node, game.sink_node]:
            # Terminal nodes loop on themselves.
            s_0[v] = v
            s_1[v] = v
        elif v == m -1 :
            # End of the first branch: even edge restarts, odd edge sinks.
            s_0[v] = 0
            s_1[v] = -1
        else:
            # Interior node: even edge restarts, odd edge advances.
            s_0[v] = 0
            s_1[v] = v+1

    game.s_0 , game.s_1 = np.array(s_0), np.array(s_1)
    game.vertices = list(range(-1,n))

    game.graph = game.get_network_graph()
    return game
    

In [47]:
def get_branch_instance_with_numbers(sink_side:int,target_side:int):
    """Build a two-branch instance whose switch orientations come from the
    binary digits of two integers.

    Args:
        sink_side: integer whose bits orient the first branch (nodes
            ``0 .. m-1``, the branch that ends at the sink).
        target_side: integer whose bits orient the second branch (nodes
            ``m .. n-2``, the branch that ends at the target).

    Both numbers are written least-significant-bit first and zero-padded to a
    common width ``m``; the game has ``n = 2*m`` regular vertices plus the sink
    slot addressed through index ``-1``.  The padded bit strings are printed
    most-significant-bit first.

    Returns:
        An ``Arrival`` game whose successors, vertex list (``-1 .. n-1``) and
        graph have been replaced by the construction.
    """
    # LSB-first digit strings, padded on the high end to a shared width.
    sink_bits = bin(sink_side)[2:][::-1]
    target_bits = bin(target_side)[2:][::-1]
    width = max(len(sink_bits), len(target_bits))
    sink_bits = sink_bits.ljust(width, '0')
    target_bits = target_bits.ljust(width, '0')
    assert len(sink_bits) == len(target_bits)

    print(sink_bits[::-1],target_bits[::-1])

    m = len(sink_bits)
    n = len(sink_bits)+len(target_bits)

    game = Arrival(n,do_reachability=False)
    even_succ = [0]*(n+1)
    odd_succ = [0]*(n+1)

    def orient(node, bit, chosen, other):
        # `chosen` goes on the even edge for bit 0 and on the odd edge for bit 1;
        # the remaining edge gets `other`.
        even_succ[node] = chosen if bit == 0 else other
        odd_succ[node] = chosen if bit == 1 else other

    # First branch (plus the sink slot at index -1).
    for node in range(-1,m):
        if node == game.start_node:
            orient(node, int(sink_bits[node]), 1, m)
        elif node in [game.target_node, game.sink_node]:
            even_succ[node] = node
            odd_succ[node] = node
        elif node == m - 1:
            # NOTE(review): here bit 1 puts the restart edge on the even side,
            # the opposite polarity to the other nodes -- confirm this is intended.
            orient(node, int(sink_bits[node]), -1, 0)
        else:
            orient(node, int(sink_bits[node]), node+1, 0)

    # Second branch, ending at the target.
    for node in range(m,n):
        if node == game.target_node:
            even_succ[node] = node
            odd_succ[node] = node
        else:
            # NOTE(review): the offset of +1 means target_bits[0] is never read
            # -- confirm this is intended.
            orient(node, int(target_bits[node-m+1]), node+1, 0)

    game.s_0 = np.array(even_succ)
    game.s_1 = np.array(odd_succ)
    game.vertices = list(range(-1,n))

    game.graph = game.get_network_graph()

    return game
    
    
    