Skip to content

Commit

Permalink
Allow networkx >3 dependency
Browse files Browse the repository at this point in the history
Signed-off-by: Ehud-Karavani <ehud.karavani@ibm.com>
  • Loading branch information
ehudkr committed Oct 25, 2023
1 parent 6601c9c commit ccb8b7b
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion causallib/simulation/CausalSimulator3.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def __init__(self, topology, var_types, prob_categories, link_types, snr, treatm

# Create a graph out of matrix topology:
self.topology = topology
self.graph_topology = nx.from_numpy_matrix(topology.transpose(), create_using=nx.DiGraph()) # type:nx.DiGraph
self.graph_topology = nx.from_numpy_array(topology.transpose(), create_using=nx.DiGraph()) # type:nx.DiGraph
self.graph_topology = nx.relabel_nodes(self.graph_topology,
dict(list(zip(list(range(self.m)), self.var_names))))

Expand Down
4 changes: 2 additions & 2 deletions causallib/tests/test_causal_simulator3.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def test_random_topology_generation(self):
np.testing.assert_array_equal(T.loc[X.columns, :].sum(axis="columns"), np.zeros(5))

# Test for DAGness:
from networkx import DiGraph, from_numpy_matrix, is_directed_acyclic_graph
from networkx import DiGraph, from_numpy_array, is_directed_acyclic_graph
NUM_TESTS = 50
for test in range(NUM_TESTS):
n_cov = np.random.randint(low=10, high=100)
Expand All @@ -317,7 +317,7 @@ def test_random_topology_generation(self):
n_cen = np.random.randint(low=0, high=n_tre_out)
T, _ = CS3m.generate_random_topology(n_covariates=n_cov, p=p, n_treatments=n_tre_out, n_outcomes=n_tre_out,
n_censoring=n_cen, given_vars=[], p_hidden=0)
G = from_numpy_matrix(T.values.transpose(), create_using=DiGraph())
G = from_numpy_array(T.values.transpose(), create_using=DiGraph())
res = is_directed_acyclic_graph(G)
self.assertTrue(res)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pandas>=0.25.2,<3
pandas>=0.25.2,<4
scipy>=0.19,<2
statsmodels>=0.9,<1
networkx>=1.1,<3
Expand Down

0 comments on commit ccb8b7b

Please sign in to comment.