From ccb8b7b18b3e33ecfeae401ffce40c8cdef44a65 Mon Sep 17 00:00:00 2001 From: Ehud-Karavani Date: Wed, 25 Oct 2023 12:02:27 +0300 Subject: [PATCH] Allow networkx >3 dependency Signed-off-by: Ehud-Karavani --- causallib/simulation/CausalSimulator3.py | 2 +- causallib/tests/test_causal_simulator3.py | 4 ++-- requirements.txt | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/causallib/simulation/CausalSimulator3.py b/causallib/simulation/CausalSimulator3.py index 07673d5..7757c40 100644 --- a/causallib/simulation/CausalSimulator3.py +++ b/causallib/simulation/CausalSimulator3.py @@ -186,7 +186,7 @@ def __init__(self, topology, var_types, prob_categories, link_types, snr, treatm # Create a graph out of matrix topology: self.topology = topology - self.graph_topology = nx.from_numpy_matrix(topology.transpose(), create_using=nx.DiGraph()) # type:nx.DiGraph + self.graph_topology = nx.from_numpy_array(topology.transpose(), create_using=nx.DiGraph()) # type:nx.DiGraph self.graph_topology = nx.relabel_nodes(self.graph_topology, dict(list(zip(list(range(self.m)), self.var_names)))) diff --git a/causallib/tests/test_causal_simulator3.py b/causallib/tests/test_causal_simulator3.py index ccab946..c02ec97 100644 --- a/causallib/tests/test_causal_simulator3.py +++ b/causallib/tests/test_causal_simulator3.py @@ -308,7 +308,7 @@ def test_random_topology_generation(self): np.testing.assert_array_equal(T.loc[X.columns, :].sum(axis="columns"), np.zeros(5)) # Test for DAGness: - from networkx import DiGraph, from_numpy_matrix, is_directed_acyclic_graph + from networkx import DiGraph, from_numpy_array, is_directed_acyclic_graph NUM_TESTS = 50 for test in range(NUM_TESTS): n_cov = np.random.randint(low=10, high=100) @@ -317,7 +317,7 @@ def test_random_topology_generation(self): n_cen = np.random.randint(low=0, high=n_tre_out) T, _ = CS3m.generate_random_topology(n_covariates=n_cov, p=p, n_treatments=n_tre_out, n_outcomes=n_tre_out, n_censoring=n_cen, given_vars=[], p_hidden=0) - G = from_numpy_matrix(T.values.transpose(), create_using=DiGraph()) + G = from_numpy_array(T.values.transpose(), create_using=DiGraph()) res = is_directed_acyclic_graph(G) self.assertTrue(res) diff --git a/requirements.txt b/requirements.txt index 218def9..6ca6aa9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pandas>=0.25.2,<3 +pandas>=0.25.2,<4 scipy>=0.19,<2 statsmodels>=0.9,<1 networkx>=1.1,<3