In [179]:
import os
import random
import shutil
import subprocess
import tempfile

import networkx as nx

In [2]:
class ATREL_Model():
  """All-terminal reliability predictor that shells out to the ``reliability`` solver.

  A prediction writes ``G`` to three ``.dat`` files in ``path_data``, runs
  ``<path_prog>/reliability -allrel <connections> <grid> <reliability>`` and
  parses the probability from the solver's stdout. The temporary files are
  always removed afterwards. The process working directory is never changed;
  the solver is started with ``cwd=path_prog`` instead of ``os.chdir``.

  Args:
    path_data: folder for the temporary input files; relative paths are
      resolved against the user's home directory.
    path_prog: folder containing the ``reliability`` executable; relative
      paths are resolved against the user's home directory.
    seed: seeds a private RNG whose first draw (stored as ``self.seed``) tags
      the temporary file names.
  """

  def __init__(self,
               path_data:str,
               path_prog:str,
               seed:int):
    self.path_dat = path_data
    self.path_prog = path_prog
    # random.Random(seed).random() equals random.seed(seed); random.random(),
    # but does not reseed the caller's global `random` state.
    self.seed = random.Random(seed).random()

  @staticmethod
  def _parse_probability(proc) -> float:
    """Return the probability reported by a finished solver run.

    Raises:
      RuntimeError: if the solver exited with a non-zero status.
      ValueError: if no stdout line contains ``'prob'`` or it is not a float.
    """
    if proc.returncode != 0:
      raise RuntimeError(
          'reliability exited with code {}:\n{}'.format(proc.returncode, proc.stderr))
    # The value starts 7 characters after 'prob' -- presumably the solver
    # prints 'prob = <value>'; confirm against its output.
    for line in reversed(proc.stdout.splitlines()):
      if 'prob' in line:
        return float(line[line.index('prob') + 7:])
    raise ValueError("no 'prob' line in reliability output:\n" + proc.stdout)

  def _predict(self,G:nx.Graph,reliability=0)->float:
    """Run the external solver on ``G`` and return its reliability estimate.

    Edge reliabilities come from the ``'reliability'`` edge attribute. The
    grid file is written as ``0 <n_nodes-1>``, which presumably assumes nodes
    are labelled ``0..n-1`` (confirm with the solver's docs). The
    ``reliability`` argument is kept for backward compatibility and is unused.

    Raises:
      ValueError: if an edge lacks ``'reliability'`` or the output cannot be parsed.
      FileExistsError: if temporary files with this model's tag already exist
        (previously this surfaced as an UnboundLocalError).
      RuntimeError: if the solver exits with a non-zero status.
    """
    home_dir = os.path.expanduser("~")
    data_dir = os.path.join(home_dir, self.path_dat)
    run_dir = os.path.join(home_dir, self.path_prog)
    rel_file = os.path.join(data_dir, 'reliability_{}.dat'.format(self.seed))
    con_file = os.path.join(data_dir, 'connections_{}.dat'.format(self.seed))
    grid_file = os.path.join(data_dir, 'grid_{}.dat'.format(self.seed))
    tmp_files = (rel_file, con_file, grid_file)

    edge_rel = nx.get_edge_attributes(G, 'reliability')
    if len(edge_rel) != G.number_of_edges():
      # Edges without the attribute would silently vanish from the input.
      raise ValueError("every edge of G needs a 'reliability' attribute")
    if any(os.path.isfile(path) for path in tmp_files):
      # Do not clobber files that may belong to another run with the same tag.
      raise FileExistsError('temporary files for tag {} already exist in {}'.format(self.seed, data_dir))

    try:
      with open(rel_file, 'w') as rels, open(con_file, 'w') as cons:
        for (u, v), value in edge_rel.items():
          rels.write(f'{value} ')
          cons.write(f'{u} {v}\n')
      with open(grid_file, 'w') as ter:
        ter.write(f'0 {G.number_of_nodes() - 1}\n')
      # Absolute paths + cwd= replace chdir and the hard-coded '../tmp_files/'.
      proc = subprocess.run(
          [os.path.join(run_dir, 'reliability'), '-allrel', con_file, grid_file, rel_file],
          cwd=run_dir, text=True, capture_output=True)
    finally:
      for path in tmp_files:
        if os.path.isfile(path):
          os.remove(path)
    return self._parse_probability(proc)

  def predict(self,G:nx.Graph)->float:
    """Return the all-terminal reliability of ``G``, or 0.0 if ``G`` is disconnected."""
    if not nx.is_connected(G):
      return 0.0
    return self._predict(G)
  

In [95]:
# Debug check of the kernel's working directory. ATREL_Model._predict,
# compute_rel and the %cd cells all change it, so the answer depends on
# which cells ran before this one.
!pwd

/Users/farid/Downloads/All-Terminal-Reliability/tmp_files


In [96]:
# Reset the kernel's working directory to $HOME (other cells and the
# reliability helpers change it via os.chdir / %cd).
%cd ~

/Users/farid


In [97]:
# Confirm the working directory after the %cd above.
os.getcwd()

'/Users/farid'

In [180]:
def _parse_reliability_stdout(stdout: str) -> float:
    """Return the probability printed by the ``reliability`` solver.

    Uses the last stdout line that contains ``'prob'``. It parses the text
    starting 7 characters after that token, up to the end of the line. The
    offset is kept from the original parser; presumably the solver prints
    ``prob = <value>`` -- confirm against its output.

    Raises:
        ValueError: if no line contains ``'prob'`` or the value is not a float.
    """
    for line in reversed(stdout.splitlines()):
        if 'prob' in line:
            return float(line[line.index('prob') + 7:])
    raise ValueError("no 'prob' line in reliability output:\n" + stdout)


def compute_rel(G:nx.Graph,
                folder_path:str = '/Downloads/All-Terminal-Reliability',
                run_path:str = '/Downloads/All-Terminal-Reliability/reliability_tdzdd',
                seed: int = None)->float:
    """Compute the all-terminal reliability of ``G`` with the external solver.

    Writes the edge list and per-edge ``'reliability'`` values of ``G`` into a
    fresh temporary folder under ``~ + folder_path``. It then runs
    ``reliability -allrel <connections> <grid> <reliability>`` from
    ``~ + run_path`` and parses the probability from its stdout. The
    temporary folder is removed even if writing or the solver fails. Neither
    the process working directory nor the global ``random`` state is modified.

    Args:
        G: graph whose edges all carry a ``'reliability'`` attribute. The grid
            file is written as ``0 <n_nodes-1>``, which presumably assumes
            nodes labelled ``0..n-1`` (confirm with the solver's docs).
        folder_path: path *appended* to the home directory by plain string
            concatenation (hence the leading ``/``); parent of the temp folder.
        run_path: path appended to the home directory; folder holding the
            ``reliability`` executable.
        seed: seeds a private RNG whose first draw tags the temporary folder
            and file names.

    Returns:
        The probability reported by the solver.

    Raises:
        ValueError: if an edge lacks ``'reliability'`` or the output has no
            parsable probability.
        RuntimeError: if the solver exits with a non-zero status.
    """
    # A private RNG gives the same tag as random.seed(seed); random.random()
    # without reseeding the notebook's global RNG (seed=None used to reseed
    # it from OS entropy).
    tag = random.Random(seed).random()
    home = os.path.expanduser("~")
    base_dir = home + folder_path
    run_dir = home + run_path

    edge_rel = nx.get_edge_attributes(G, 'reliability')
    if len(edge_rel) != G.number_of_edges():
        # Edges without the attribute would silently vanish from the input.
        raise ValueError("every edge of G needs a 'reliability' attribute")

    # mkdtemp always creates a new folder, even if two calls draw the same tag.
    work_dir = tempfile.mkdtemp(prefix='tmp_files_{}_'.format(tag), dir=base_dir)
    try:
        rel_file = os.path.join(work_dir, 'reliability_{}.dat'.format(tag))
        con_file = os.path.join(work_dir, 'connections_{}.dat'.format(tag))
        grid_file = os.path.join(work_dir, 'grid_{}.dat'.format(tag))
        with open(rel_file, 'w') as rels, open(con_file, 'w') as cons:
            for (u, v), value in edge_rel.items():
                rels.write(f'{value} ')
                cons.write(f'{u} {v}\n')
        with open(grid_file, 'w') as ter:
            ter.write(f'0 {G.number_of_nodes() - 1}\n')
        # Absolute paths + cwd= replace os.chdir and the '../tmp_files_*' paths,
        # so run_path no longer has to be a direct child of folder_path.
        result = subprocess.run(
            [os.path.join(run_dir, 'reliability'), '-allrel', con_file, grid_file, rel_file],
            cwd=run_dir, text=True, capture_output=True)
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)

    if result.returncode != 0:
        raise RuntimeError(
            'reliability exited with code {}:\n{}'.format(result.returncode, result.stderr))
    return _parse_reliability_stdout(result.stdout)

In [181]:
import networkx as nx
import random
import itertools
import ast
import logging
import pandas as pd
import multiprocessing
import argparse
from reliability_comp import reliability

In [182]:
def get_args(argv=None):
    """Parse the options that control random graph generation.

    ``parse_known_args`` is used so that unrecognised arguments are returned
    instead of raising an error. Presumably this is because the notebook
    kernel's own command-line arguments are present in ``sys.argv``.

    Args:
        argv: argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        The ``(namespace, unknown_args)`` tuple from ``parse_known_args``;
        the options are in ``get_args()[0]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", default=1871, type=int, help="seed number")
    parser.add_argument("--n_graph", default=10, type=int, help="number of graphs")
    parser.add_argument("--n_node", default=8, type=int, help="number of nodes")
    parser.add_argument("--l_bound", default=0.3, type=float, help="edge addition lower bound")
    parser.add_argument("--u_bound", default=0.6, type=float, help="edge addition upper bound")
    return parser.parse_known_args(argv)

In [183]:
# get_args() returns the (namespace, unknown_args) tuple from
# parse_known_args, so the options are args[0]. Note: `args` is rebound to a
# list of (graph, seed) tuples in a later cell.
args = get_args()

In [184]:
def create_graph(args):
    """Generate ``args.n_graph`` random connected graphs with edge reliabilities.

    Seeds the global ``random`` module with ``args.seed`` (a process-wide side
    effect). It then repeatedly builds a graph on nodes ``0..n_node-1``. Each
    pair ``i < j`` becomes an edge when a uniform draw from ``[0, 1)`` falls
    in ``[l_bound, u_bound]``. Each edge gets a ``'reliability'`` attribute
    chosen uniformly from 0.80/0.85/0.90/0.95/0.99. Disconnected graphs are
    discarded and redrawn.

    Args:
        args: namespace with ``seed``, ``n_graph``, ``n_node``, ``l_bound``
            and ``u_bound`` (e.g. ``get_args()[0]``).

    Returns:
        A list of ``args.n_graph`` connected ``nx.Graph`` objects.

    Raises:
        ValueError: if ``n_node >= 2`` and no draw from ``[0, 1)`` can land in
            ``[l_bound, u_bound]``; no edge could ever be added, so the loop
            would otherwise never terminate.
    """
    reliability_levels = [0.80, 0.85, 0.90, 0.95, 0.99]
    n_nodes = args.n_node
    lo, hi = args.l_bound, args.u_bound
    if n_nodes >= 2 and (lo > hi or hi < 0 or lo >= 1):
        raise ValueError(
            "[l_bound, u_bound] = [{}, {}] admits no draw from [0, 1); "
            "no connected graph can be generated".format(lo, hi))

    random.seed(args.seed)
    dataset = []
    while len(dataset) < args.n_graph:
        g = nx.Graph()
        g.add_nodes_from(range(n_nodes))
        # combinations() yields (i, j) with i < j in the same order as the
        # nested node loop, so the sequence of random draws is unchanged.
        for i, j in itertools.combinations(range(n_nodes), 2):
            # Add the edge when R falls in [l_bound, u_bound].
            if lo <= random.random() <= hi:
                g.add_edge(i, j, reliability=random.choice(reliability_levels))
        if nx.is_connected(g):
            dataset.append(g)
    print('Graphs created: ', len(dataset))
    return dataset

In [185]:
# args[0] is the argparse namespace; args[1] holds unrecognised arguments.
dataset = create_graph(args[0])

Graphs created:  10


In [186]:
# Display the generated graphs (their reprs only show object addresses).
dataset

[<networkx.classes.graph.Graph at 0x13fc9db70>,
 <networkx.classes.graph.Graph at 0x13fc9fbe0>,
 <networkx.classes.graph.Graph at 0x13fc9de70>,
 <networkx.classes.graph.Graph at 0x13fc9c400>,
 <networkx.classes.graph.Graph at 0x13fc9ee30>,
 <networkx.classes.graph.Graph at 0x13fc9d3f0>,
 <networkx.classes.graph.Graph at 0x13fc9c040>,
 <networkx.classes.graph.Graph at 0x13fc9f2b0>,
 <networkx.classes.graph.Graph at 0x13fc9e4d0>,
 <networkx.classes.graph.Graph at 0x13fc9f8b0>]

In [190]:
# Sanity check on one graph; the random seed determines the temporary
# folder/file names used by compute_rel.
compute_rel(G=dataset[3], seed=random.randint(5000, 10000))

0.678391831

In [121]:
# Move the kernel into the project folder. Presumably this is also what makes
# `reliability_comp` importable here (the cwd is on sys.path in IPython) --
# confirm before reordering cells.
%cd ~/Downloads/All-Terminal-Reliability

/Users/farid/Downloads/All-Terminal-Reliability


In [191]:
# One (graph, seed) task per graph for starmap(compute_rel, ...).
# NOTE: this rebinds `args`, replacing the get_args() result, so re-running
# create_graph(args[0]) after this cell no longer works.
args = [(graph, random.randint(5000, 10000)) for graph in dataset]

In [192]:
# Inspect one (graph, seed) task tuple.
args[1]

(<networkx.classes.graph.Graph at 0x13fc9fbe0>, 5749)

In [194]:
# NOTE: this cell failed with "Can't get attribute 'compute_rel' on <module
# '__main__'>". The pool uses the "spawn" start method here (SpawnPoolWorker
# in the output), and spawned workers cannot look up functions defined in
# notebook cells. Move compute_rel into an importable module (as was done
# for reliability_comp.reliability) before re-running.
# Never start more workers than there are tasks.
n_workers = max(1, min(32, len(args)))
with multiprocessing.Pool(processes=n_workers) as pool:
    results = pool.starmap(compute_rel, args)

Process SpawnPoolWorker-329:
Process SpawnPoolWorker-331:
Process SpawnPoolWorker-328:
Process SpawnPoolWorker-330:
Process SpawnPoolWorker-332:
Process SpawnPoolWorker-333:
Traceback (most recent call last):
  File "/Users/farid/miniconda3/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Users/farid/miniconda3/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/farid/miniconda3/lib/python3.10/multiprocessing/pool.py", line 114, in worker
    task = get()
Traceback (most recent call last):
  File "/Users/farid/miniconda3/lib/python3.10/multiprocessing/queues.py", line 367, in get
    return _ForkingPickler.loads(res)
AttributeError: Can't get attribute 'compute_rel' on <module '__main__' (built-in)>
  File "/Users/farid/miniconda3/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Users/farid/miniconda3/lib/python3.10/multiprocessing/

KeyboardInterrupt: 

In [195]:
# `results` is undefined because the Pool cell above was interrupted.
results

NameError: name 'results' is not defined

In [74]:
from reliability_comp import reliability

In [75]:
def multi_process(dataset, processes=32):
    """Compute the reliability of every graph in ``dataset`` in parallel.

    Each graph is passed to ``reliability`` (from ``reliability_comp``)
    together with its position in ``dataset``. What that second argument
    means is defined in ``reliability_comp``; presumably it is a per-task id.

    Args:
        dataset: sized sequence of graphs.
        processes: upper bound on worker processes (default 32, as before).
            No more workers than tasks are started.

    Returns:
        A list of results in the same order as ``dataset``.
    """
    tasks = [(graph, idx) for idx, graph in enumerate(dataset)]
    # Avoid spawning idle workers (32 processes for 10 graphs before);
    # Pool requires at least one process even when there are no tasks.
    n_workers = max(1, min(processes, len(tasks)))
    with multiprocessing.Pool(processes=n_workers) as pool:
        results = pool.starmap(reliability, tasks)
    return results

In [76]:
# Parallel reliability of all generated graphs via reliability_comp.
multi_process(dataset)

[0.7044042672468,
 0.5629071194999999,
 0.8411593600626164,
 0.6783918309900001,
 0.9706029275096438,
 0.6260682780000001,
 0.9038100346015527,
 0.8067499831147501,
 0.4913080920000001,
 0.9126962728908751]

In [77]:
# Debug check: the kernel's working directory at this point depends on which
# of the cwd-changing cells/helpers ran before.
!pwd

/Users/farid/Downloads/All-Terminal-Reliability/tmp_files
