Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stochastic block model generator #1200

Merged
merged 6 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ ahash = "0.8.6"
fixedbitset = "0.4.2"
hashbrown = { version = ">=0.13, <0.15", features = ["rayon"] }
indexmap = { version = ">=1.9, <3", features = ["rayon"] }
ndarray = { version = "0.15.6", features = ["rayon"] }
num-traits = "0.2"
numpy = "0.21.0"
petgraph = "0.6.5"
Expand All @@ -44,6 +45,7 @@ ahash.workspace = true
fixedbitset.workspace = true
hashbrown.workspace = true
indexmap.workspace = true
ndarray.workspace = true
ndarray-stats = "0.5.1"
num-bigint = "0.4"
num-complex = "0.4"
Expand All @@ -63,10 +65,6 @@ rustworkx-core = { path = "rustworkx-core", version = "=0.15.0" }
version = "0.21.2"
features = ["abi3-py38", "extension-module", "hashbrown", "num-bigint", "num-complex", "indexmap"]

[dependencies.ndarray]
version = "^0.15.6"
features = ["rayon"]

[dependencies.sprs]
version = "^0.11"
features = ["multi_thread"]
Expand Down
2 changes: 2 additions & 0 deletions docs/source/api/random_graph_generator_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Random Graph Generator Functions
rustworkx.undirected_gnp_random_graph
rustworkx.directed_gnm_random_graph
rustworkx.undirected_gnm_random_graph
rustworkx.directed_sbm_random_graph
rustworkx.undirected_sbm_random_graph
rustworkx.random_geometric_graph
rustworkx.hyperbolic_random_graph
rustworkx.barabasi_albert_graph
Expand Down
9 changes: 9 additions & 0 deletions releasenotes/notes/sbm-random-graph-bf7ccd8e938f4218.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
features:
- |
Adds new random graph generators in rustworkx for the stochastic block model.
There is a generator for directed graphs, :func:`.directed_sbm_random_graph`,
and one for undirected graphs, :func:`.undirected_sbm_random_graph`.
- |
Adds new function ``sbm_random_graph`` to the rustworkx-core module
``rustworkx_core::generators`` that samples a graph from the stochastic
block model.
Comment on lines +7 to +9
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add drawings of some example graphs if you want, check

:func:`.directed_barabasi_albert_graph` and :func:`.barabasi_albert_graph`,
for an example

1 change: 1 addition & 0 deletions rustworkx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ ahash.workspace = true
fixedbitset.workspace = true
hashbrown.workspace = true
indexmap.workspace = true
ndarray.workspace = true
num-traits.workspace = true
petgraph.workspace = true
priority-queue = "2.0"
Expand Down
1 change: 1 addition & 0 deletions rustworkx-core/src/generators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ pub use random_graph::gnp_random_graph;
pub use random_graph::hyperbolic_random_graph;
pub use random_graph::random_bipartite_graph;
pub use random_graph::random_geometric_graph;
pub use random_graph::sbm_random_graph;
pub use star_graph::star_graph;
287 changes: 286 additions & 1 deletion rustworkx-core/src/generators/random_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::hash::Hash;

use ndarray::ArrayView2;
use petgraph::data::{Build, Create};
use petgraph::visit::{
Data, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdgesDirected,
Expand Down Expand Up @@ -305,6 +306,131 @@ where
Ok(graph)
}

/// Generate a graph from the stochastic block model.
///
/// The stochastic block model is a generalization of the G<sub>np</sub> random graph
/// (see [gnp_random_graph]). Nodes are assigned to blocks in index order: the first
/// `sizes[0]` nodes belong to block `0`, the next `sizes[1]` nodes to block `1`, and
/// so on. The connection probability of nodes `u` and `v` depends on their blocks
/// and is given by `probabilities[blocks[u]][blocks[v]]`, where `blocks[u]` is the
/// block membership of node `u`. The number of nodes and the number of blocks are
/// inferred from `sizes`.
///
/// Arguments:
///
/// * `sizes` - Number of nodes in each block.
/// * `probabilities` - B x B array that contains the connection probability between
///     nodes of different blocks. Must be symmetric for undirected graphs.
/// * `loops` - Determines whether the graph can have loops or not.
/// * `seed` - An optional seed to use for the random number generator.
/// * `default_node_weight` - A callable that will return the weight to use
///     for newly created nodes.
/// * `default_edge_weight` - A callable that will return the weight object
///     to use for newly created edges.
///
/// # Errors
///
/// Returns [InvalidInputError] if
///
/// * the total number of nodes is zero or does not fit in a `usize`,
/// * `probabilities` is not a B x B array, where B is `sizes.len()`,
/// * any entry of `probabilities` lies outside `[0, 1]` (this includes `NaN`), or
/// * the graph type is undirected and `probabilities` is not symmetric.
///
/// # Example
/// ```rust
/// use ndarray::arr2;
/// use rustworkx_core::petgraph;
/// use rustworkx_core::generators::sbm_random_graph;
///
/// let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
///     &[1, 2],
///     &arr2(&[[0., 1.], [0., 1.]]).view(),
///     true,
///     Some(10),
///     || (),
///     || (),
/// )
/// .unwrap();
/// assert_eq!(g.node_count(), 3);
/// assert_eq!(g.edge_count(), 6);
/// ```
pub fn sbm_random_graph<G, T, F, H, M>(
    sizes: &[usize],
    probabilities: &ndarray::ArrayView2<f64>,
    loops: bool,
    seed: Option<u64>,
    mut default_node_weight: F,
    mut default_edge_weight: H,
) -> Result<G, InvalidInputError>
where
    G: Build + Create + Data<NodeWeight = T, EdgeWeight = M> + NodeIndexable + GraphProp,
    F: FnMut() -> T,
    H: FnMut() -> M,
    G::NodeId: Eq + Hash,
{
    // Checked sum: a plain `sum()` would panic in debug builds and silently wrap
    // in release builds when the block sizes overflow `usize`.
    let num_nodes = sizes
        .iter()
        .try_fold(0_usize, |total, &size| total.checked_add(size))
        .ok_or(InvalidInputError {})?;
    if num_nodes == 0 {
        return Err(InvalidInputError {});
    }
    let num_communities = sizes.len();
    // The negated `contains` also rejects NaN, which is in no range.
    if probabilities.nrows() != num_communities
        || probabilities.ncols() != num_communities
        || probabilities.iter().any(|&x| !(0. ..=1.).contains(&x))
    {
        return Err(InvalidInputError {});
    }

    // Directedness is a property of the graph value, so the graph has to be
    // created before the symmetry requirement can be checked.
    let mut graph = G::with_capacity(num_nodes, num_nodes);
    let directed = graph.is_directed();
    if !directed && !symmetric_array(probabilities) {
        return Err(InvalidInputError {});
    }

    for _ in 0..num_nodes {
        graph.add_node(default_node_weight());
    }
    let mut rng: Pcg64 = match seed {
        Some(seed) => Pcg64::seed_from_u64(seed),
        None => Pcg64::from_entropy(),
    };

    // blocks[v] is the block index of node v; empty blocks contribute no nodes.
    let mut blocks: Vec<usize> = Vec::with_capacity(num_nodes);
    for (block, &size) in sizes.iter().enumerate() {
        blocks.extend(std::iter::repeat(block).take(size));
    }

    let between = Uniform::new(0.0, 1.0);
    // Undirected graphs only visit pairs with v <= w. Without loops the last
    // node then has no candidate partner left, so it is skipped as a source.
    let source_count = if directed || loops {
        num_nodes
    } else {
        num_nodes - 1
    };
    for v in 0..source_count {
        let first_target = if directed { 0 } else { v };
        for w in (first_target..num_nodes).filter(|&w| w != v || loops) {
            // Block indices are always within the validated B x B shape, so the
            // zero fallback is never expected to be taken.
            let probability = probabilities
                .get((blocks[v], blocks[w]))
                .copied()
                .unwrap_or(0_f64);
            // Exactly one sample is drawn per candidate pair, in (v, w) order.
            if between.sample(&mut rng) < probability {
                graph.add_edge(
                    graph.from_index(v),
                    graph.from_index(w),
                    default_edge_weight(),
                );
            }
        }
    }
    Ok(graph)
}

/// Returns `true` if `mat` is square and equal to its transpose.
///
/// Only the entries above the diagonal are compared with their mirrored
/// counterparts. A matrix with zero rows and zero columns is considered
/// symmetric; any non-square matrix is reported as not symmetric.
fn symmetric_array<T: std::cmp::PartialEq>(mat: &ArrayView2<T>) -> bool {
    let n = mat.nrows();
    // A non-square matrix cannot be symmetric; checking this first also keeps
    // every mirrored index (j, i) below within bounds.
    if mat.ncols() != n {
        return false;
    }
    // `(i + 1)..n` is simply empty for the last row (and for n == 0), so no
    // `n - 1` subtraction that could underflow is needed.
    (0..n).all(|i| ((i + 1)..n).all(|j| mat[[i, j]] == mat[[j, i]]))
}

#[inline]
fn pnorm(x: f64, p: f64) -> f64 {
if p == 1.0 || p == std::f64::INFINITY {
Expand Down Expand Up @@ -749,7 +875,7 @@ mod tests {
use crate::generators::InvalidInputError;
use crate::generators::{
barabasi_albert_graph, gnm_random_graph, gnp_random_graph, hyperbolic_random_graph,
path_graph, random_bipartite_graph, random_geometric_graph,
path_graph, random_bipartite_graph, random_geometric_graph, sbm_random_graph,
};
use crate::petgraph;

Expand Down Expand Up @@ -879,6 +1005,165 @@ mod tests {
};
}

// Test sbm_random_graph
#[test]
fn test_sbm_directed_complete_blocks_loops() {
let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
&vec![1, 2],
&ndarray::arr2(&[[0., 1.], [0., 1.]]).view(),
true,
Some(10),
|| (),
|| (),
)
.unwrap();
assert_eq!(g.node_count(), 3);
assert_eq!(g.edge_count(), 6);
for (u, v) in [(1, 1), (1, 2), (2, 1), (2, 2), (0, 1), (0, 2)] {
assert_eq!(g.contains_edge(u.into(), v.into()), true);
}
assert_eq!(g.contains_edge(1.into(), 0.into()), false);
assert_eq!(g.contains_edge(2.into(), 0.into()), false);
}

#[test]
fn test_sbm_undirected_complete_blocks_loops() {
    // Symmetric probabilities: pairs inside block 0 never connect, every other
    // pair (including self loops in block 1) always connects.
    let probabilities = ndarray::arr2(&[[0., 1.], [1., 1.]]);
    let g = sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &probabilities.view(),
        true,
        Some(10),
        || (),
        || (),
    )
    .unwrap();
    assert_eq!(g.node_count(), 3);
    assert_eq!(g.edge_count(), 5);
    let expected_edges = [(1, 1), (1, 2), (2, 2), (0, 1), (0, 2)];
    for (source, target) in expected_edges {
        assert!(g.contains_edge(source.into(), target.into()));
    }
    // The only node of block 0 must not carry a self loop.
    assert!(!g.contains_edge(0.into(), 0.into()));
}

#[test]
fn test_sbm_directed_complete_blocks_noloops() {
    // Same setup as the directed loops test, but with `loops` disabled, so the
    // two self loops inside block 1 must disappear.
    let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(),
        false,
        Some(10),
        || (),
        || (),
    )
    .unwrap();
    assert_eq!(g.node_count(), 3);
    assert_eq!(g.edge_count(), 4);
    for (u, v) in [(1, 2), (2, 1), (0, 1), (0, 2)] {
        assert!(g.contains_edge(u.into(), v.into()));
    }
    assert!(!g.contains_edge(1.into(), 0.into()));
    assert!(!g.contains_edge(2.into(), 0.into()));
    // Check every one of the three nodes for self loops (the range must cover
    // node 2 as well).
    for u in 0..3 {
        assert!(!g.contains_edge(u.into(), u.into()));
    }
}

#[test]
fn test_sbm_undirected_complete_blocks_noloops() {
    // Same setup as the undirected loops test, but with `loops` disabled, so
    // only the three edges between distinct nodes remain.
    let g = sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &ndarray::arr2(&[[0., 1.], [1., 1.]]).view(),
        false,
        Some(10),
        || (),
        || (),
    )
    .unwrap();
    assert_eq!(g.node_count(), 3);
    assert_eq!(g.edge_count(), 3);
    for (u, v) in [(1, 2), (0, 1), (0, 2)] {
        assert!(g.contains_edge(u.into(), v.into()));
    }
    // Check every one of the three nodes for self loops (the range must cover
    // node 2 as well).
    for u in 0..3 {
        assert!(!g.contains_edge(u.into(), u.into()));
    }
}

#[test]
fn test_sbm_bad_array_rows_error() {
    // Two blocks but three probability rows: the shape check must reject this.
    let too_many_rows = ndarray::arr2(&[[0., 1.], [1., 1.], [1., 1.]]);
    let result = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &too_many_rows.view(),
        true,
        Some(10),
        || (),
        || (),
    );
    if let Err(e) = result {
        assert_eq!(e, InvalidInputError);
    } else {
        panic!("Returned a non-error");
    }
}
#[test]
fn test_sbm_bad_array_cols_error() {
    // Two blocks but three probability columns: the shape check must reject this.
    let too_many_cols = ndarray::arr2(&[[0., 1., 1.], [1., 1., 1.]]);
    let result = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &too_many_cols.view(),
        true,
        Some(10),
        || (),
        || (),
    );
    if let Err(e) = result {
        assert_eq!(e, InvalidInputError);
    } else {
        panic!("Returned a non-error");
    }
}

#[test]
fn test_sbm_asymmetric_array_error() {
    // An undirected graph requires a symmetric probability matrix; here the
    // entries (0, 1) and (1, 0) differ.
    let asymmetric = ndarray::arr2(&[[0., 1.], [0., 1.]]);
    let result = sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &asymmetric.view(),
        true,
        Some(10),
        || (),
        || (),
    );
    if let Err(e) = result {
        assert_eq!(e, InvalidInputError);
    } else {
        panic!("Returned a non-error");
    }
}

#[test]
fn test_sbm_invalid_probability_error() {
    // The matrix is deliberately symmetric so that the only reason to reject it
    // is the out-of-range entry -1; an asymmetric matrix on an undirected graph
    // would produce the same error even without the range check.
    match sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &ndarray::arr2(&[[0., 1.], [1., -1.]]).view(),
        true,
        Some(10),
        || (),
        || (),
    ) {
        Ok(_) => panic!("Returned a non-error"),
        Err(e) => assert_eq!(e, InvalidInputError),
    };
}

#[test]
fn test_sbm_empty_error() {
    // No blocks means no nodes, which the generator reports as invalid input.
    let result = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[],
        &ndarray::arr2(&[[]]).view(),
        true,
        Some(10),
        || (),
        || (),
    );
    if let Err(e) = result {
        assert_eq!(e, InvalidInputError);
    } else {
        panic!("Returned a non-error");
    }
}

// Test random_geometric_graph

#[test]
Expand Down
2 changes: 2 additions & 0 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ from .rustworkx import directed_gnm_random_graph as directed_gnm_random_graph
from .rustworkx import undirected_gnm_random_graph as undirected_gnm_random_graph
from .rustworkx import directed_gnp_random_graph as directed_gnp_random_graph
from .rustworkx import undirected_gnp_random_graph as undirected_gnp_random_graph
from .rustworkx import directed_sbm_random_graph as directed_sbm_random_graph
from .rustworkx import undirected_sbm_random_graph as undirected_sbm_random_graph
from .rustworkx import random_geometric_graph as random_geometric_graph
from .rustworkx import hyperbolic_random_graph as hyperbolic_random_graph
from .rustworkx import barabasi_albert_graph as barabasi_albert_graph
Expand Down
Loading