-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathsnapshot.rs
118 lines (102 loc) · 4.01 KB
/
snapshot.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#[cfg(test)]
mod tests {
use insta::assert_debug_snapshot;
use ndarray;
use ndarray::{Array, Array2, ArrayBase, Dim, Ix, OwnedRepr};
use ndarray_rand::rand::rngs::StdRng;
use ndarray_rand::rand::{RngCore, SeedableRng};
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use cleora::embedding::{MarkovType, NdArrayMatrix};
use cleora::sparse_matrix::SparseMatrix;
/// Converts a float matrix into an integer matrix by scaling every element
/// by 1000 and casting it to `i32`.
///
/// Despite the name, the `as` cast truncates toward zero rather than rounding
/// to nearest. Presumably this keeps the snapshots independent of low-order
/// float digits.
fn round(arr: Array2<f32>) -> Array2<i32> {
    let scale = 1000.;
    arr.mapv(|value| (value * scale) as i32)
}
/// Snapshot test: multiplies the reflexive fixture graph with its random
/// embeddings using `MarkovType::Left` and compares the scaled integer
/// result against the stored debug snapshot.
#[test]
fn test_markov_left_01() {
    let (sparse_matrix, initial) = create_graph_embeddings_complex_reflexive();
    let propagated =
        NdArrayMatrix::multiply(&sparse_matrix, initial.view(), MarkovType::Left, 8);
    // The asserted binding keeps its original name so the expression handed
    // to the snapshot macro is unchanged.
    let embedding_out = round(propagated);
    assert_debug_snapshot!(embedding_out);
}
/// Snapshot test: multiplies the two-column "complex" fixture graph with its
/// random embeddings using `MarkovType::Left` and compares the scaled integer
/// result against the stored debug snapshot.
#[test]
fn test_markov_left_02() {
    let (sparse_matrix, initial) = create_graph_embeddings_complex_complex();
    let propagated =
        NdArrayMatrix::multiply(&sparse_matrix, initial.view(), MarkovType::Left, 8);
    // The asserted binding keeps its original name so the expression handed
    // to the snapshot macro is unchanged.
    let embedding_out = round(propagated);
    assert_debug_snapshot!(embedding_out);
}
/// Snapshot test: multiplies the reflexive fixture graph with its random
/// embeddings using `MarkovType::Symmetric` and compares the scaled integer
/// result against the stored debug snapshot.
#[test]
fn test_markov_sym_01() {
    let (sparse_matrix, initial) = create_graph_embeddings_complex_reflexive();
    let propagated =
        NdArrayMatrix::multiply(&sparse_matrix, initial.view(), MarkovType::Symmetric, 8);
    // The asserted binding keeps its original name so the expression handed
    // to the snapshot macro is unchanged.
    let embedding_out = round(propagated);
    assert_debug_snapshot!(embedding_out);
}
/// Snapshot test: multiplies the two-column "complex" fixture graph with its
/// random embeddings using `MarkovType::Symmetric` and compares the scaled
/// integer result against the stored debug snapshot.
#[test]
fn test_markov_sym_02() {
    let (sparse_matrix, initial) = create_graph_embeddings_complex_complex();
    let propagated =
        NdArrayMatrix::multiply(&sparse_matrix, initial.view(), MarkovType::Symmetric, 8);
    // The asserted binding keeps its original name so the expression handed
    // to the snapshot macro is unchanged.
    let embedding_out = round(propagated);
    assert_debug_snapshot!(embedding_out);
}
/// Builds a deterministic fixture: a `SparseMatrix` parsed from 1000 random
/// edge lines plus a 100 x 32 matrix of random `f32` embeddings.
///
/// Each edge line has the form `"<a> <b>\t<c> <d>"`, with every node id drawn
/// from `0..100`, and is parsed under the column definition
/// `"complex::entity_a complex::entity_b"`. One `StdRng` seeded with `21_37`
/// drives both the edges and the embeddings, so the output is reproducible.
///
/// Panics if `SparseMatrix::from_rust_iterator` returns an error.
fn create_graph_embeddings_complex_complex(
) -> (SparseMatrix, ArrayBase<OwnedRepr<f32>, Dim<[Ix; 2]>>) {
    let num_embeddings: usize = 100;
    let feature_dim: usize = 32;
    // Node ids are reduced modulo the number of embedding rows.
    let node_bound = num_embeddings as u32;
    let mut rng: StdRng = SeedableRng::seed_from_u64(21_37);

    // The order of the RNG draws determines the generated data; keep it:
    // four ids per line, left to right, then the embeddings afterwards.
    let edges: Vec<String> = (0..1000)
        .map(|_| {
            let first_col_a = rng.next_u32() % node_bound;
            let first_col_b = rng.next_u32() % node_bound;
            let second_col_a = rng.next_u32() % node_bound;
            let second_col_b = rng.next_u32() % node_bound;
            format!(
                "{} {}\t{} {}",
                first_col_a, first_col_b, second_col_a, second_col_b
            )
        })
        .collect();
    let edge_lines: Vec<&str> = edges.iter().map(String::as_str).collect();

    let graph = SparseMatrix::from_rust_iterator(
        "complex::entity_a complex::entity_b",
        16,
        edge_lines.into_iter(),
        None,
    )
    .unwrap();

    let embeddings = Array::random_using(
        (num_embeddings, feature_dim),
        Uniform::new(0., 10.),
        &mut rng,
    );
    (graph, embeddings)
}
/// Builds a deterministic fixture: a `SparseMatrix` parsed from 1000 random
/// edge lines plus a 100 x 32 matrix of random `f32` embeddings.
///
/// Each edge line has the form `"<a> <b>"`, with both node ids drawn from
/// `0..100`, and is parsed under the column definition
/// `"reflexive::complex::entity_id"`. One `StdRng` seeded with `21_37` drives
/// both the edges and the embeddings, so the output is reproducible.
///
/// Panics if `SparseMatrix::from_rust_iterator` returns an error.
fn create_graph_embeddings_complex_reflexive(
) -> (SparseMatrix, ArrayBase<OwnedRepr<f32>, Dim<[Ix; 2]>>) {
    let num_embeddings: usize = 100;
    let feature_dim: usize = 32;
    // Node ids are reduced modulo the number of embedding rows.
    let node_bound = num_embeddings as u32;
    let mut rng: StdRng = SeedableRng::seed_from_u64(21_37);

    // The order of the RNG draws determines the generated data; keep it:
    // two ids per line, left to right, then the embeddings afterwards.
    let edges: Vec<String> = (0..1000)
        .map(|_| {
            let first_id = rng.next_u32() % node_bound;
            let second_id = rng.next_u32() % node_bound;
            format!("{} {}", first_id, second_id)
        })
        .collect();
    let edge_lines: Vec<&str> = edges.iter().map(String::as_str).collect();

    let graph = SparseMatrix::from_rust_iterator(
        "reflexive::complex::entity_id",
        16,
        edge_lines.into_iter(),
        None,
    )
    .unwrap();

    let embeddings = Array::random_using(
        (num_embeddings, feature_dim),
        Uniform::new(0., 10.),
        &mut rng,
    );
    (graph, embeddings)
}
}