Skip to content

Commit c64d1f6

Browse files
committed
Fix the new implementation of AVPR
1 parent be95e67 commit c64d1f6

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

core/scores.cpp

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ double average_cluster_reliability(const ugraph_t & graph,
8282
return acr;
8383
}
8484

85-
double avpr(const ugraph_t & graph,
85+
double average_vertex_pairwise_reliability_new(const ugraph_t & graph,
8686
const std::vector<ClusterVertex> & vinfo,
8787
CCSampler & sampler) {
8888
// Instead of looking at the connection probabilities of single
@@ -107,34 +107,46 @@ double avpr(const ugraph_t & graph,
107107
// The vectors that will hold the counts for the clusters, one for
108108
// each thread
109109
std::vector<std::vector<size_t>> t_cluster_counts(n_threads,
110-
std::vector<size_t>(n_clusters));
110+
std::vector<size_t>(n_clusters, 0));
111111
// The samples
112112
const std::vector<CCSampler::component_vector_t> &samples = sampler.get_samples();
113113
const size_t n_samples = samples.size();
114-
114+
115115
// For each sample in parallel, accumulate counts
116116
#pragma omp parallel for
117117
for(size_t sample_idx=0; sample_idx < n_samples; sample_idx++){
118118
const auto & sample = samples[sample_idx];
119119
REQUIRE(sample.size() == n, "Samples are of the wrong size!");
120-
const size_t num_connected_components =
121-
*(std::max_element(sample.cbegin(), sample.cend()));
122-
LOG_INFO("There are " << num_connected_components << " connected components");
120+
std::unordered_map<size_t, size_t> connected_components_ids;
121+
size_t cur_comp_id = 0;
122+
for(const auto cc_id : sample) {
123+
if (connected_components_ids.count(cc_id) == 0) {
124+
connected_components_ids[cc_id] = cur_comp_id++;
125+
}
126+
}
127+
const size_t num_connected_components = connected_components_ids.size();
128+
123129
const auto tid = omp_get_thread_num();
124130
auto& cluster_counts = t_cluster_counts[tid];
125131

126132
// A matrix of `num_clusters` x `num_components elements that
127133
// contains in element (i,j) the number of elements of cluster i
128134
// belonging to the connected component j.
129-
std::vector<std::vector<size_t>> intersection_sizes(
130-
n_clusters, std::vector<size_t>(num_connected_components));
135+
std::vector<std::vector<size_t>> intersection_sizes;
136+
for (size_t i=0; i<n_clusters; i++) {
137+
std::vector<size_t> vec;
138+
for(size_t j=0; j<num_connected_components; j++) {
139+
vec.push_back(0);
140+
}
141+
intersection_sizes.push_back(vec);
142+
}
131143

132144
for (size_t i=0; i<n; i++) {
133145
const size_t cluster_id = cluster_ids[vinfo[i].center()];
134-
const size_t component_id = sample[i];
135-
intersection_sizes[cluster_id][component_id]++;
146+
const size_t component_id = connected_components_ids[sample[i]];
147+
intersection_sizes.at(cluster_id).at(component_id)++;
136148
}
137-
149+
138150
for (size_t cluster_idx=0; cluster_idx<n_clusters; cluster_idx++) {
139151
size_t cnt=0;
140152
const auto & sizes = intersection_sizes[cluster_idx];
@@ -244,14 +256,17 @@ void add_scores(const ugraph_t & graph,
244256
sampler.min_probability(graph, min_p);
245257
LOG_INFO("Computing ACR");
246258
double acr = average_cluster_reliability(graph, clusters, sampler);
259+
LOG_INFO("Computing AVPR with new method");
260+
double avpr_new_method = average_vertex_pairwise_reliability_new(graph, vinfo, sampler);
247261
LOG_INFO("Computing AVPR");
248262
double avpr = average_vertex_pairwise_reliability(graph, clusters, sampler);
249-
263+
250264
LOG_INFO("Clustering with:" <<
251265
"\n\t# clusters = " << num_clusters <<
252266
"\n\tp_min = " << min_p <<
253267
"\n\taverage p = " << avg_p <<
254-
"\n\tavpr = " << avpr <<
268+
"\n\tavpr = " << avpr <<
269+
"\n\tavpr_new = " << avpr_new_method <<
255270
"\n\tacr = " << acr);
256271
EXPERIMENT_APPEND("scores", {{"acr", acr},
257272
{"p_min", min_p},

core/scores.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ double average_vertex_pairwise_reliability(const ugraph_t & graph,
2424
CCSampler & sampler);
2525

2626
/// Computes the Average Vertex Pairwise Reliability
27-
double avpr(const ugraph_t & graph,
27+
double average_vertex_pairwise_reliability_new(const ugraph_t & graph,
2828
const std::vector<ClusterVertex> & vinfo,
2929
CCSampler & sampler);
3030

0 commit comments

Comments
 (0)