In [None]:
import os
os.environ['CASTLE_BACKEND'] = 'pytorch'

from collections import OrderedDict
import warnings

import numpy as np
import networkx as nx

from scipy import linalg 

from sklearn.linear_model import LinearRegression

import castle
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation 

from castle.algorithms import PC, GES
from castle.algorithms import ANMNonlinear, ICALiNGAM, DirectLiNGAM
from castle.algorithms import Notears, NotearsNonlinear, GOLEM

from castle.common.priori_knowledge import PrioriKnowledge

from castle.common.independence_tests import hsic_test

import matplotlib.pyplot as plt

from dowhy import gcm
import pandas as pd

In [None]:
# Ground-truth causal graph: every tuple is a directed edge (cause, effect).
# Keep the edges in this order: nodes are registered on first appearance, and
# the resulting node order is what the adjacency matrix is later built from.
causal_graph = nx.DiGraph()
causal_graph.add_edges_from([
    ('d', 'messageReceived'),
    ('d', 'graft'),
    ('d', 'prune'),
    ('dlo', 'graft'),
    ('dhi', 'prune'),
    ('dout', 'prune'),
    ('prune', 'graft'),
    ('messageReceived', 'messageOverhead'),
    ('graft', 'messageReceived'),
    ('topics', 'messageReceived'),
    ('d', 'messageOverhead'),
    ('topicSize', 'messageReceived'),
    ('topics', 'messageOverhead'),
])
# Candidate edges currently left out of the model:
#   ('interval', 'graft'), ('interval', 'prune'), ('topicSize', 'prune')

# Render the ground-truth graph.
G = gcm.util.plot(causal_graph, figure_size=[6, 8])

In [None]:
# Convert the ground-truth DiGraph into a dense adjacency matrix.
# Rows and columns follow the graph's own node order, which is printed below
# so the matrix indices can be read against the variable names.
true_causal_matrix = nx.to_numpy_array(
    causal_graph,
    nodelist=list(causal_graph.nodes),
)

print(true_causal_matrix)
print(causal_graph.nodes)

In [None]:
# Load and prepare the dataset.

# Variables handed to the discovery algorithms, in the column order of X.
# Defined once so that X and `consolidated` can never drift apart.
# ('interval' is not included; its edges are also disabled in the causal graph.)
feature_columns = ['d', 'messageReceived', 'graft', 'prune', 'dlo', 'dhi', 'dout',
                   'messageOverhead', 'topics', 'topicSize']

# The estimated and the true adjacency matrices are compared index by index,
# so column i of X has to be node i of the ground-truth graph. Fail fast if the
# two orders ever diverge instead of silently reporting meaningless metrics.
assert feature_columns == list(causal_graph.nodes), (
    "Column order of X does not match the node order of causal_graph: "
    f"{feature_columns} != {list(causal_graph.nodes)}"
)

# Import data
data = pd.read_csv('../../Datasets/Networkwise/consolidated_5s.csv', header=0, index_col=0)

# Drop the 'experiment' and 'topology' columns, then discard every row that has
# a missing value in any of the remaining columns (not only the selected ones).
consolidateddiff = data.drop(columns=['experiment', 'topology']).dropna()

# Keep only the modelled variables: as a DataFrame for inspection and as a
# plain NumPy array for the causal discovery algorithms.
consolidated = consolidateddiff[feature_columns]
X = consolidated.to_numpy()

consolidated.head(500)

In [None]:
# Fit DirectLiNGAM on the data matrix X.
g = DirectLiNGAM()
g.learn(X)

# Draw the estimated adjacency matrix next to the ground-truth one.
GraphDAG(g.causal_matrix, true_dag=true_causal_matrix)

# Score the estimate against the ground truth and print the metrics.
met = MetricsDAG(g.causal_matrix, true_causal_matrix)
print(met.metrics)

In [None]:
# Build a directed graph from the adjacency matrix learned by DirectLiNGAM (`g`).
learned_causal_graph = nx.from_numpy_array(g.causal_matrix, create_using=nx.DiGraph)

# Matrix index -> variable name, in the column order used to build X.
mapping = dict(enumerate([
    'd',
    'messageReceived',
    'graft',
    'prune',
    'dlo',
    'dhi',
    'dout',
    'messageOverhead',
    'topics',
    'topicSize',
]))

H = nx.relabel_nodes(learned_causal_graph, mapping)

# Inline rendering of the learned graph.
G = gcm.util.plot(H, figure_size=[6, 8])

# Persist the learned structure as an adjacency list.
nx.write_adjlist(H, "./AdjLists/DirectLiNGAM.adjlist")

# Convert to a pygraphviz graph and lay it out with "dot" before styling,
# so the draw call below reuses this layout.
A = nx.nx_agraph.to_agraph(H)
A.layout("dot")

# These nodes are drawn with a bold outline.
for node_name in ('d', 'dlo', 'dhi', 'dout', 'topicSize', 'topics'):
    A.get_node(node_name).attr['style'] = 'bold'

# 'messageOverhead' is drawn filled with a thicker border.
overhead_node = A.get_node('messageOverhead')
overhead_node.attr['style'] = 'filled'
overhead_node.attr['penwidth'] = 2

A.draw("./Graphs/DirectLiNGAM.png")

In [None]:
# Fit ICA-LiNGAM on the same data matrix X.
g2 = ICALiNGAM()
g2.learn(X)

# Draw the estimated adjacency matrix next to the ground-truth one.
GraphDAG(g2.causal_matrix, true_dag=true_causal_matrix)

# Score the estimate against the ground truth and print the metrics.
met = MetricsDAG(g2.causal_matrix, true_causal_matrix)
print(met.metrics)

In [None]:
# Build a directed graph from the adjacency matrix learned by ICA-LiNGAM.
# This must read from `g2` (the ICALiNGAM model); reading from `g` would
# export the DirectLiNGAM result under the ICALiNGAM file names.
learned_causal_graph = nx.from_numpy_array(g2.causal_matrix, create_using=nx.DiGraph)

# Matrix index -> variable name, in the column order used to build X.
mapping = {0: 'd',
           1: 'messageReceived',
           2: 'graft',
           3: 'prune',
           4: 'dlo',
           5: 'dhi',
           6: 'dout',
           7: 'messageOverhead',
           8: 'topics',
           9: 'topicSize',
           }

H = nx.relabel_nodes(learned_causal_graph, mapping)

# Inline rendering of the learned graph.
G = gcm.util.plot(H, figure_size=[6, 8])

# Persist the learned structure as an adjacency list.
nx.write_adjlist(H, "./AdjLists/ICALiNGAM.adjlist")

# Convert to a pygraphviz graph and lay it out with "dot" before styling,
# so the draw call below reuses this layout.
A = nx.nx_agraph.to_agraph(H)
A.layout("dot")

# These nodes are drawn with a bold outline.
for node_name in ('d', 'dlo', 'dhi', 'dout', 'topicSize', 'topics'):
    A.get_node(node_name).attr['style'] = 'bold'

# 'messageOverhead' is drawn filled with a thicker border.
A.get_node('messageOverhead').attr['style'] = 'filled'
A.get_node('messageOverhead').attr['penwidth'] = 2

A.draw("./Graphs/ICALiNGAM.png")