Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stochastic block model generator #1200

Merged
merged 6 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api/random_graph_generator_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Random Graph Generator Functions
rustworkx.undirected_gnp_random_graph
rustworkx.directed_gnm_random_graph
rustworkx.undirected_gnm_random_graph
rustworkx.directed_sbm_random_graph
rustworkx.undirected_sbm_random_graph
rustworkx.random_geometric_graph
rustworkx.hyperbolic_random_graph
rustworkx.barabasi_albert_graph
Expand Down
9 changes: 9 additions & 0 deletions releasenotes/notes/sbm-random-graph-bf7ccd8e938f4218.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
features:
- |
Adds new random graph generator in rustworkx for the stochastic block model.
There is a generator for directed :func:`.directed_sbm_random_graph` and
undirected graphs :func:`.undirected_sbm_random_graph`.
- |
Adds new function ``sbm_random_graph`` to the rustworkx-core module
``rustworkx_core::generators`` that samples a graph from the stochastic
block model.
Comment on lines +7 to +9
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add drawings of some example graphs if you want, check

:func:`.directed_barabasi_albert_graph` and :func:`.barabasi_albert_graph`,
for an example

1 change: 1 addition & 0 deletions rustworkx-core/src/generators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ pub use random_graph::gnp_random_graph;
pub use random_graph::hyperbolic_random_graph;
pub use random_graph::random_bipartite_graph;
pub use random_graph::random_geometric_graph;
pub use random_graph::sbm_random_graph;
pub use star_graph::star_graph;
308 changes: 307 additions & 1 deletion rustworkx-core/src/generators/random_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,133 @@ where
Ok(graph)
}

/// Generate a graph from the stochastic block model.
///
/// The stochastic block model is a generalization of the G<sub>np</sub> graph model
/// (see [rustworkx_core::generators::gnp_random_graph] ). The connection probability of
/// nodes `u` and `v` depends on their block (or community) and is given by
/// `probabilities[blocks[u]][blocks[v]]`. The number of nodes and the number of blocks
/// are inferred from `blocks`.
///
/// Arguments:
///
/// * `blocks` - Block membership (between 0 and B-1) of each node.
/// * `probabilities` - B x B matrix that contains the connection probability between
/// nodes of different blocks. Must be symmetric for undirected graphs.
/// * `loops` - Determines whether the graph can have loops or not.
/// * `seed` - An optional seed to use for the random number generator.
/// * `default_node_weight` - A callable that will return the weight to use
/// for newly created nodes.
/// * `default_edge_weight` - A callable that will return the weight object
/// to use for newly created edges.
///
/// # Example
/// ```rust
/// use rustworkx_core::petgraph;
/// use rustworkx_core::generators::sbm_random_graph;
///
/// let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
/// &vec![1, 0, 1],
/// &[vec![0., 1.], vec![0., 1.]],
/// true,
/// Some(10),
/// || (),
/// || (),
/// )
/// .unwrap();
/// assert_eq!(g.node_count(), 3);
/// assert_eq!(g.edge_count(), 6);
/// ```
pub fn sbm_random_graph<G, T, F, H, M>(
blocks: &[usize],
probabilities: &[Vec<f64>],
IvanIsCoding marked this conversation as resolved.
Show resolved Hide resolved
loops: bool,
seed: Option<u64>,
mut default_node_weight: F,
mut default_edge_weight: H,
) -> Result<G, InvalidInputError>
where
G: Build + Create + Data<NodeWeight = T, EdgeWeight = M> + NodeIndexable + GraphProp,
F: FnMut() -> T,
H: FnMut() -> M,
G::NodeId: Eq + Hash,
{
let num_nodes = blocks.len();
if num_nodes == 0 {
return Err(InvalidInputError {});
}
let num_communities = probabilities.len();
if probabilities
.iter()
.any(|xs| xs.len() != num_communities || xs.iter().any(|&x| !(0. ..=1.).contains(&x)))
{
return Err(InvalidInputError {});
}
if blocks.iter().max().unwrap_or(&usize::MAX) >= &num_communities {
return Err(InvalidInputError {});
}
if blocks.len() != num_nodes {
return Err(InvalidInputError {});
}

let mut graph = G::with_capacity(num_nodes, num_nodes);
let directed = graph.is_directed();
if !directed && !symmetric_matrix(probabilities) {
return Err(InvalidInputError {});
}

for _ in 0..num_nodes {
graph.add_node(default_node_weight());
}
let mut rng: Pcg64 = match seed {
Some(seed) => Pcg64::seed_from_u64(seed),
None => Pcg64::from_entropy(),
};
let between = Uniform::new(0.0, 1.0);

let mut block_partition: Vec<Vec<usize>> = (0..num_communities).map(|_| Vec::new()).collect();
for (node, block) in blocks.iter().enumerate() {
block_partition[*block].push(node);
}

for (v, &b_v) in blocks.iter().enumerate().take(if directed || loops {
num_nodes
} else {
num_nodes - 1
}) {
for (w, &b_w) in blocks
.iter()
.enumerate()
.skip(if directed { 0 } else { v })
.filter(|&(w, _)| w != v || loops)
{
if between.sample(&mut rng) < probabilities[b_v][b_w] {
graph.add_edge(
graph.from_index(v),
graph.from_index(w),
default_edge_weight(),
);
}
}
}
Ok(graph)
}

fn symmetric_matrix<T: std::cmp::PartialEq>(mat: &[Vec<T>]) -> bool {
let n = mat.len();
for (i, row) in mat.iter().enumerate().take(n - 1) {
if row.len() != n {
return false;
}
for (j, m_ij) in row.iter().enumerate().skip(i + 1) {
if m_ij != &mat[j][i] {
return false;
}
}
}
true
}

#[inline]
fn pnorm(x: f64, p: f64) -> f64 {
if p == 1.0 || p == std::f64::INFINITY {
Expand Down Expand Up @@ -749,7 +876,7 @@ mod tests {
use crate::generators::InvalidInputError;
use crate::generators::{
barabasi_albert_graph, gnm_random_graph, gnp_random_graph, hyperbolic_random_graph,
path_graph, random_bipartite_graph, random_geometric_graph,
path_graph, random_bipartite_graph, random_geometric_graph, sbm_random_graph,
};
use crate::petgraph;

Expand Down Expand Up @@ -879,6 +1006,185 @@ mod tests {
};
}

// Test sbm_random_graph
#[test]
fn test_sbm_directed_complete_blocks() {
let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
&vec![1, 0, 1],
&[vec![0., 1.], vec![0., 1.]],
true,
Some(10),
|| (),
|| (),
)
.unwrap();
assert_eq!(g.node_count(), 3);
assert_eq!(g.edge_count(), 6);
for (u, v) in [(0, 0), (0, 2), (2, 0), (2, 2), (1, 0), (1, 2)] {
assert_eq!(g.contains_edge(u.into(), v.into()), true);
}
assert_eq!(g.contains_edge(0.into(), 1.into()), false);
assert_eq!(g.contains_edge(2.into(), 1.into()), false);
}

#[test]
fn test_sbm_directed_complete_blocks_loops() {
let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
&vec![1, 0, 1],
&[vec![0., 1.], vec![0., 1.]],
true,
Some(10),
|| (),
|| (),
)
.unwrap();
assert_eq!(g.node_count(), 3);
assert_eq!(g.edge_count(), 6);
for (u, v) in [(0, 0), (0, 2), (2, 0), (2, 2), (1, 0), (1, 2)] {
assert_eq!(g.contains_edge(u.into(), v.into()), true);
}
assert_eq!(g.contains_edge(0.into(), 1.into()), false);
assert_eq!(g.contains_edge(2.into(), 1.into()), false);
}

#[test]
fn test_sbm_undirected_complete_blocks_loops() {
let g = sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
&vec![1, 0, 1],
&[vec![0., 1.], vec![1., 1.]],
true,
Some(10),
|| (),
|| (),
)
.unwrap();
assert_eq!(g.node_count(), 3);
assert_eq!(g.edge_count(), 5);
for (u, v) in [(0, 0), (0, 2), (2, 2), (1, 0), (1, 2)] {
assert_eq!(g.contains_edge(u.into(), v.into()), true);
}
assert_eq!(g.contains_edge(1.into(), 1.into()), false);
}

#[test]
fn test_sbm_directed_complete_blocks_noloops() {
let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
&vec![1, 0, 1],
&[vec![0., 1.], vec![0., 1.]],
false,
Some(10),
|| (),
|| (),
)
.unwrap();
assert_eq!(g.node_count(), 3);
assert_eq!(g.edge_count(), 4);
for (u, v) in [(0, 2), (2, 0), (1, 0), (1, 2)] {
assert_eq!(g.contains_edge(u.into(), v.into()), true);
}
assert_eq!(g.contains_edge(0.into(), 1.into()), false);
assert_eq!(g.contains_edge(2.into(), 1.into()), false);
for u in 0..2 {
assert_eq!(g.contains_edge(u.into(), u.into()), false);
}
}

#[test]
fn test_sbm_undirected_complete_blocks_noloops() {
let g = sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
&vec![1, 0, 1],
&[vec![0., 1.], vec![1., 1.]],
false,
Some(10),
|| (),
|| (),
)
.unwrap();
assert_eq!(g.node_count(), 3);
assert_eq!(g.edge_count(), 3);
for (u, v) in [(0, 2), (1, 0), (1, 2)] {
assert_eq!(g.contains_edge(u.into(), v.into()), true);
}
for u in 0..2 {
assert_eq!(g.contains_edge(u.into(), u.into()), false);
}
}

#[test]
fn test_sbm_block_outofrange_error() {
match sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
&vec![1, 0, 2],
&[vec![0., 1.], vec![1., 1.]],
true,
Some(10),
|| (),
|| (),
) {
Ok(_) => panic!("Returned a non-error"),
Err(e) => assert_eq!(e, InvalidInputError),
};
}

#[test]
fn test_sbm_invalid_matrix_error() {
match sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
&vec![1, 0, 1],
&[vec![0., 1.], vec![1.]],
true,
Some(10),
|| (),
|| (),
) {
Ok(_) => panic!("Returned a non-error"),
Err(e) => assert_eq!(e, InvalidInputError),
};
}

#[test]
fn test_sbm_asymmetric_matrix_error() {
match sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
&vec![1, 0, 1],
&[vec![0., 1.], vec![0., 1.]],
true,
Some(10),
|| (),
|| (),
) {
Ok(_) => panic!("Returned a non-error"),
Err(e) => assert_eq!(e, InvalidInputError),
};
}

#[test]
fn test_sbm_invalid_probability_error() {
match sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
&vec![1, 0, 1],
&[vec![0., 1.], vec![0., -1.]],
true,
Some(10),
|| (),
|| (),
) {
Ok(_) => panic!("Returned a non-error"),
Err(e) => assert_eq!(e, InvalidInputError),
};
}

#[test]
fn test_sbm_empty_error() {
match sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
&vec![],
&[],
true,
Some(10),
|| (),
|| (),
) {
Ok(_) => panic!("Returned a non-error"),
Err(e) => assert_eq!(e, InvalidInputError),
};
}

// Test random_geometric_graph

#[test]
Expand Down
2 changes: 2 additions & 0 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ from .rustworkx import directed_gnm_random_graph as directed_gnm_random_graph
from .rustworkx import undirected_gnm_random_graph as undirected_gnm_random_graph
from .rustworkx import directed_gnp_random_graph as directed_gnp_random_graph
from .rustworkx import undirected_gnp_random_graph as undirected_gnp_random_graph
from .rustworkx import directed_sbm_random_graph as directed_sbm_random_graph
from .rustworkx import undirected_sbm_random_graph as undirected_sbm_random_graph
from .rustworkx import random_geometric_graph as random_geometric_graph
from .rustworkx import hyperbolic_random_graph as hyperbolic_random_graph
from .rustworkx import barabasi_albert_graph as barabasi_albert_graph
Expand Down
14 changes: 14 additions & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,20 @@ def undirected_gnp_random_graph(
/,
seed: int | None = ...,
) -> PyGraph: ...
def directed_sbm_random_graph(
blocks: list[int],
probabilities: list[list[float]],
loops: bool,
/,
seed: int | None = ...,
) -> PyDiGraph: ...
def undirected_sbm_random_graph(
blocks: list[int],
probabilities: list[list[float]],
loops: bool,
/,
seed: int | None = ...,
) -> PyGraph: ...
def random_geometric_graph(
num_nodes: int,
radius: float,
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,8 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(undirected_gnp_random_graph))?;
m.add_wrapped(wrap_pyfunction!(directed_gnm_random_graph))?;
m.add_wrapped(wrap_pyfunction!(undirected_gnm_random_graph))?;
m.add_wrapped(wrap_pyfunction!(undirected_sbm_random_graph))?;
m.add_wrapped(wrap_pyfunction!(directed_sbm_random_graph))?;
m.add_wrapped(wrap_pyfunction!(random_geometric_graph))?;
m.add_wrapped(wrap_pyfunction!(hyperbolic_random_graph))?;
m.add_wrapped(wrap_pyfunction!(barabasi_albert_graph))?;
Expand Down
Loading