-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathfrequency_table_of_tree_distance.hpp
65 lines (61 loc) · 2.3 KB
/
frequency_table_of_tree_distance.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include "centroid_decomposition.hpp"
#include <algorithm>
#include <utility>
#include <vector>
// Frequency table of tree distances via centroid decomposition.
//
// Given a tree as adjacency lists and a weight per vertex, solve() returns a vector
// `ret` where, for every d >= 1, ret[d] is the sum of weight[u] * weight[v] over
// unordered vertex pairs {u, v} whose path has length d. With every weight equal
// to 1 this is the number of vertex pairs at each distance.
//
// Requirements (not checked here):
// - `to` is symmetric: each undirected edge appears in both endpoints' lists.
// - The graph is a tree. The decomposition is started from vertex 0
//   (NOTE(review): behaviour on forests depends on CentroidDecomposition).
struct frequency_table_of_tree_distance {
    std::vector<std::vector<int>> tos;     // adjacency lists; solve() does not modify them
    std::vector<int> cd;                   // vertex order returned by centroid_decomposition(0)
    std::vector<std::pair<int, int>> tmp;  // scratch: (vertex, depth) pairs filled by _dfs
    std::vector<int> alive;                // alive[v] == 0 once v has been processed as a centroid

    // Appends (vertex, depth) to `tmp` for every vertex reachable from `now` through
    // alive vertices without stepping back to `prv`. Vertices are listed in preorder,
    // `now` first at `depth`. Implemented with an explicit stack so that long, path-like
    // subtrees cannot overflow the call stack.
    void _dfs(int now, int prv, int depth) {
        struct Frame {
            int vertex, parent, depth;
        };
        std::vector<Frame> pending;
        pending.push_back(Frame{now, prv, depth});
        while (!pending.empty()) {
            const Frame f = pending.back();
            pending.pop_back();
            tmp.emplace_back(f.vertex, f.depth);
            const std::vector<int> &adj = tos[f.vertex];
            // Push neighbours in reverse so they are popped in adjacency order,
            // reproducing the visiting order of a recursive DFS.
            for (auto it = adj.rbegin(); it != adj.rend(); ++it) {
                if (alive[*it] and *it != f.parent) pending.push_back(Frame{*it, f.vertex, f.depth + 1});
            }
        }
    }

    // Returns the (vertex, depth) pairs of the alive component containing `root`,
    // with depths measured from `root` (depth 0). Overwrites `tmp`.
    std::vector<std::pair<int, int>> cnt_dfs(int root) {
        tmp.clear();
        _dfs(root, -1, 0);
        return tmp;
    }

    // Stores the tree and precomputes the centroid order.
    frequency_table_of_tree_distance(const std::vector<std::vector<int>> &to) : tos(to) {
        if (to.empty()) return;  // no vertices: cd stays empty and solve() returns an empty table
        CentroidDecomposition c(to.size());
        for (int i = 0; i < int(to.size()); i++) {
            for (int j : to[i]) {
                if (i < j) c.add_edge(i, j);  // register each undirected edge once
            }
        }
        cd = c.centroid_decomposition(0);
    }

    // Computes the weighted distance frequency table described above.
    //
    // Template parameters:
    //   S    - value type. Must support +=, *, and construction from 0.
    //   conv - linear convolution of two sequences. Trailing zero padding in its
    //          result is tolerated.
    // weight: weight[v] is read for every vertex v, so presumably
    //         weight.size() >= number of vertices.
    // Returns a vector of size n = number of vertices. ret[d] for d >= 1 is as
    // described on the struct. ret[0] is never written and stays S().
    // May be called repeatedly; the stored tree is left intact.
    template <class S, std::vector<S> (*conv)(const std::vector<S> &, const std::vector<S> &)>
    std::vector<S> solve(const std::vector<S> &weight) {
        const int n = int(tos.size());
        alive.assign(n, 1);
        std::vector<S> ret(n);
        for (int root : cd) {
            // vv[k][d] = total weight of vertices at depth d within one child subtree of root.
            std::vector<std::vector<S>> vv;
            alive[root] = 0;
            for (int nxt : tos[root]) {
                if (!alive[nxt]) continue;  // neighbour was an earlier centroid: outside this component
                tmp.clear();
                _dfs(nxt, -1, 0);
                std::vector<S> v;
                for (const auto &p : tmp) {
                    if (int(v.size()) <= p.second) v.resize(p.second + 1, S(0));
                    v[p.second] += weight[p.first];
                }
                // Pairs (root, x) with x at depth i below nxt are at distance i + 1.
                for (int i = 0; i < int(v.size()); i++) ret[i + 1] += v[i] * weight[root];
                vv.push_back(std::move(v));
            }
            // Merge shortest-first so each convolution costs about the size of the
            // larger accumulated part.
            std::sort(vv.begin(), vv.end(), [](const std::vector<S> &l, const std::vector<S> &r) {
                return l.size() < r.size();
            });
            for (size_t j = 1; j < vv.size(); j++) {
                // After the previous step, vv[j - 1] holds the sums of all earlier subtrees.
                // c[i] pairs depths a + b = i across subtrees, i.e. distance i + 2 via root.
                const std::vector<S> c = conv(vv[j - 1], vv[j]);
                // For a correct linear convolution, entries past ret's end are zero padding.
                for (size_t i = 0; i < c.size() && i + 2 < ret.size(); i++) ret[i + 2] += c[i];
                // vv[j] is at least as long as vv[j - 1] thanks to the sort.
                for (size_t i = 0; i < vv[j - 1].size(); i++) vv[j][i] += vv[j - 1][i];
            }
            // Note: tos[root] is deliberately NOT cleared. alive[] already excludes
            // processed centroids, and keeping tos intact lets solve() be called again.
        }
        return ret;
    }
};